diff --git a/chroot/device.py b/chroot/device.py index 24bda89..02897a2 100644 --- a/chroot/device.py +++ b/chroot/device.py @@ -5,6 +5,7 @@ from typing import Optional from config import config from constants import Arch, BASE_PACKAGES +from distro.repo import RepoInfo from distro.distro import get_kupfer_local, get_kupfer_https from exec.file import get_temp_dir, makedir, root_makedir from utils import check_findmnt @@ -56,11 +57,12 @@ def get_device_chroot( arch: Arch, packages: list[str] = BASE_PACKAGES, use_local_repos: bool = True, - extra_repos: Optional[dict] = None, + extra_repos: Optional[dict[str, RepoInfo]] = None, **kwargs, ) -> DeviceChroot: name = f'rootfs_{device}-{flavour}' - repos = dict(get_kupfer_local(arch).repos if use_local_repos else get_kupfer_https(arch).repos) + repos: dict[str, RepoInfo] = get_kupfer_local(arch).repos if use_local_repos else get_kupfer_https(arch).repos # type: ignore + repos.update(extra_repos or {}) default = DeviceChroot(name, arch, initialize=False, copy_base=False, base_packages=packages, extra_repos=repos) diff --git a/distro/distro.py b/distro/distro.py index 9651d65..d4cb009 100644 --- a/distro/distro.py +++ b/distro/distro.py @@ -1,23 +1,24 @@ -from typing import Optional, Mapping +from typing import Generic, Mapping, Optional, TypeVar from constants import Arch, ARCHES, BASE_DISTROS, REPOSITORIES, KUPFER_HTTPS, CHROOT_PATHS from generator import generate_pacman_conf_body from config import config -from .package import BinaryPackage -from .repo import RepoInfo, Repo +from .repo import BinaryPackageType, RepoInfo, Repo, LocalRepo, RemoteRepo + +RepoType = TypeVar('RepoType', bound=Repo) -class Distro: - repos: Mapping[str, Repo] +class Distro(Generic[RepoType]): + repos: Mapping[str, RepoType] arch: str def __init__(self, arch: Arch, repo_infos: dict[str, RepoInfo], scan=False): assert (arch in ARCHES) self.arch = arch - self.repos = dict[str, Repo]() + self.repos = dict[str, RepoType]() for repo_name, repo_info in repo_infos.items(): - self.repos[repo_name] = Repo( + self.repos[repo_name] = self._create_repo( name=repo_name, arch=arch, url_template=repo_info.url_template, @@ -25,16 +26,22 @@ class Distro: scan=scan, ) - def get_packages(self) -> dict[str, BinaryPackage]: + def _create_repo(self, **kwargs) -> RepoType: + raise NotImplementedError() + Repo(**kwargs) + + def get_packages(self) -> dict[str, BinaryPackageType]: """ get packages from all repos, semantically overlaying them""" - results = dict[str, BinaryPackage]() + results = dict[str, BinaryPackageType]() for repo in list(self.repos.values())[::-1]: assert repo.packages is not None results.update(repo.packages) return results def repos_config_snippet(self, extra_repos: Mapping[str, RepoInfo] = {}) -> str: - extras = [Repo(name, url_template=info.url_template, arch=self.arch, options=info.options, scan=False) for name, info in extra_repos.items()] + extras: list[Repo] = [ + Repo(name, url_template=info.url_template, arch=self.arch, options=info.options, scan=False) for name, info in extra_repos.items() + ] return '\n\n'.join(repo.config_snippet() for repo in (extras + list(self.repos.values()))) def get_pacman_conf(self, extra_repos: Mapping[str, RepoInfo] = {}, check_space: bool = True, in_chroot: bool = True): @@ -53,43 +60,63 @@ class Distro: return True -def get_base_distro(arch: str) -> Distro: +class LocalDistro(Distro[LocalRepo]): + + def _create_repo(self, **kwargs) -> LocalRepo: + return LocalRepo(**kwargs) + + +class RemoteDistro(Distro[RemoteRepo]): + + def _create_repo(self, **kwargs) -> RemoteRepo: + return RemoteRepo(**kwargs) + + +def get_base_distro(arch: str) -> RemoteDistro: repos = {name: RepoInfo(url_template=url) for name, url in BASE_DISTROS[arch]['repos'].items()} - return Distro(arch=arch, repo_infos=repos, scan=False) + return RemoteDistro(arch=arch, repo_infos=repos, scan=False) def get_kupfer(arch: str, url_template: str, scan: bool = False) -> Distro: repos = {name: RepoInfo(url_template=url_template, options={'SigLevel': 'Never'}) for name in REPOSITORIES} - return Distro( + remote = not url_template.startswith('file://') + clss = RemoteDistro if remote else LocalDistro + distro = clss( arch=arch, repo_infos=repos, scan=scan, ) + assert isinstance(distro, (LocalDistro, RemoteDistro)) + return distro -_kupfer_https = dict[Arch, Distro]() -_kupfer_local = dict[Arch, Distro]() -_kupfer_local_chroots = dict[Arch, Distro]() +_kupfer_https = dict[Arch, RemoteDistro]() +_kupfer_local = dict[Arch, LocalDistro]() +_kupfer_local_chroots = dict[Arch, LocalDistro]() -def get_kupfer_https(arch: Arch, scan: bool = False) -> Distro: +def get_kupfer_https(arch: Arch, scan: bool = False) -> RemoteDistro: global _kupfer_https if arch not in _kupfer_https or not _kupfer_https[arch]: - _kupfer_https[arch] = get_kupfer(arch, KUPFER_HTTPS.replace('%branch%', config.file.pacman.repo_branch), scan) + kupfer = get_kupfer(arch, KUPFER_HTTPS.replace('%branch%', config.file.pacman.repo_branch), scan) + assert isinstance(kupfer, RemoteDistro) + _kupfer_https[arch] = kupfer item = _kupfer_https[arch] if scan and not item.is_scanned(): item.scan() return item -def get_kupfer_local(arch: Optional[Arch] = None, in_chroot: bool = True, scan: bool = False) -> Distro: +def get_kupfer_local(arch: Optional[Arch] = None, in_chroot: bool = True, scan: bool = False) -> LocalDistro: global _kupfer_local, _kupfer_local_chroots cache = _kupfer_local_chroots if in_chroot else _kupfer_local arch = arch or config.runtime.arch assert arch if arch not in cache or not cache[arch]: dir = CHROOT_PATHS['packages'] if in_chroot else config.get_path('packages') - cache[arch] = get_kupfer(arch, f"file://{dir}/$arch/$repo") + kupfer = get_kupfer(arch, f"file://{dir}/$arch/$repo") + assert isinstance(kupfer, LocalDistro) + cache[arch] = kupfer item = cache[arch] if scan and not item.is_scanned(): item.scan() diff --git a/distro/package.py b/distro/package.py index 4dfb7a4..e08502b 100644 --- a/distro/package.py +++ b/distro/package.py @@ -34,3 +34,14 @@ class BinaryPackage(PackageInfo): for key, value in zip(pruned_lines[0::2], pruned_lines[1::2]): desc[key.strip()] = value.strip() return clss(desc['NAME'], desc['VERSION'], desc['FILENAME'], resolved_url='/'.join([resolved_repo_url, desc['FILENAME']])) + + def acquire(self) -> str: + raise NotImplementedError() + + +class LocalPackage(BinaryPackage): + pass + + +class RemotePackage(BinaryPackage): + pass diff --git a/distro/repo.py b/distro/repo.py index 64b53bb..c96a617 100644 --- a/distro/repo.py +++ b/distro/repo.py @@ -5,7 +5,11 @@ import tarfile import tempfile import urllib.request -from .package import BinaryPackage +from typing import Generic, TypeVar + +from .package import BinaryPackage, LocalPackage, RemotePackage + +BinaryPackageType = TypeVar('BinaryPackageType', bound=BinaryPackage) def resolve_url(url_template, repo_name: str, arch: str): @@ -24,11 +28,11 @@ class RepoInfo: self.options.update(options) -class Repo(RepoInfo): +class Repo(RepoInfo, Generic[BinaryPackageType]): name: str resolved_url: str arch: str - packages: dict[str, BinaryPackage] + packages: dict[str, BinaryPackageType] remote: bool scanned: bool = False @@ -38,26 +42,28 @@ class Repo(RepoInfo): def scan(self): self.resolved_url = self.resolve_url() self.remote = not self.resolved_url.startswith('file://') - uri = f'{self.resolved_url}/{self.name}.db' - path = '' - if self.remote: - logging.info(f'Downloading repo file from {uri}') - with urllib.request.urlopen(uri) as request: - fd, path = tempfile.mkstemp() - with open(fd, 'wb') as writable: - writable.write(request.read()) - else: - path = uri.split('file://')[1] + path = self.acquire_db_file() logging.debug(f'Parsing repo file at {path}') with tarfile.open(path) as index: for node in index.getmembers(): if os.path.basename(node.name) == 'desc': logging.debug(f'Parsing desc file for {os.path.dirname(node.name)}') - pkg = BinaryPackage.parse_desc(index.extractfile(node).read().decode(), self.resolved_url) + fd = index.extractfile(node) + assert fd + pkg = self._parse_desc(fd.read().decode()) self.packages[pkg.name] = pkg self.scanned = True + def _parse_desc(self, desc_text: str): # can't annotate the type properly :( + raise NotImplementedError() + + def parse_desc(self, desc_text: str) -> BinaryPackageType: + return self._parse_desc(desc_text) + + def acquire_db_file(self) -> str: + raise NotImplementedError + def __init__(self, name: str, url_template: str, arch: str, options={}, scan=False): self.packages = {} self.name = name @@ -76,3 +82,27 @@ class Repo(RepoInfo): def get_RepoInfo(self): return RepoInfo(url_template=self.url_template, options=self.options) + + +class LocalRepo(Repo[LocalPackage]): + + def _parse_desc(self, desc_text: str) -> LocalPackage: + return LocalPackage.parse_desc(desc_text, resolved_repo_url=self.resolved_url) + + def acquire_db_file(self) -> str: + return f'{self.resolved_url}/{self.name}.db'.split('file://')[1] + + +class RemoteRepo(Repo[RemotePackage]): + + def _parse_desc(self, desc_text: str) -> RemotePackage: + return RemotePackage.parse_desc(desc_text, resolved_repo_url=self.resolved_url) + + def acquire_db_file(self) -> str: + uri = f'{self.resolved_url}/{self.name}.db' + logging.info(f'Downloading repo file from {uri}') + with urllib.request.urlopen(uri) as request: + fd, path = tempfile.mkstemp() + with open(fd, 'wb') as writable: + writable.write(request.read()) + return path diff --git a/packages/cli.py b/packages/cli.py index 39953dc..893d9af 100644 --- a/packages/cli.py +++ b/packages/cli.py @@ -9,6 +9,7 @@ from config import config from constants import Arch, ARCHES, REPOSITORIES from exec.file import remove_file from distro.distro import get_kupfer_local +from distro.package import LocalPackage from ssh import run_ssh_command, scp_put_files from utils import git from wrapper import check_programs_wrap, enforce_wrap @@ -85,11 +86,8 @@ def cmd_sideload(paths: Iterable[str], arch: Optional[Arch] = None, no_build: bo arch = arch or get_profile_device(hint_or_set_arch=True).arch if not no_build: build(paths, False, arch=arch, try_download=True) - files = [ - pkg.resolved_url.split('file://')[1] - for pkg in get_kupfer_local(arch=arch, scan=True, in_chroot=False).get_packages().values() - if pkg.resolved_url and pkg.name in paths - ] + repo: dict[str, LocalPackage] = get_kupfer_local(arch=arch, scan=True, in_chroot=False).get_packages() + files = [pkg.resolved_url.split('file://')[1] for pkg in repo.values() if pkg.resolved_url and pkg.name in paths] logging.debug(f"Sideload: Found package files: {files}") if not files: logging.fatal("No packages matched")