From c576dc8a517908c694c0ed19b430f927ae70c889 Mon Sep 17 00:00:00 2001 From: InsanePrawn Date: Mon, 17 Apr 2023 21:57:30 +0200 Subject: [PATCH] utils: handle zstd compressed tarfiles --- utils.py | 84 ++++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 75 insertions(+), 9 deletions(-) diff --git a/utils.py b/utils.py index 0cfdae8..26a65ed 100644 --- a/utils.py +++ b/utils.py @@ -11,6 +11,7 @@ import subprocess import tarfile from dateutil.parser import parse as parsedate +from io import BytesIO from shutil import which from typing import Any, Generator, IO, Optional, Union, Sequence @@ -151,6 +152,63 @@ def is_zstd(data): return True +class BackwardsReadableStream: + def __init__(self, stream): + self.stream = stream + self.buffer = bytearray() + self.position = 0 + + def read(self, size=-1): + data = b'' + if size == -1: + # read all remaining data in stream + data = self.stream.read() + else: + not_read = (self.position + size) - len(self.buffer) + if not_read > 0: + # read up to size bytes from stream + data = self.stream.read(not_read) + else: + data = self.buffer[self.position:self.position+size+1] + + old_position = self.position + new_position = self.position + len(data) + self.buffer.extend(data) + self.position = new_position + return self.buffer[old_position:new_position+1] + + def seek(self, offset, whence=0): + if whence == 0: + # seek from beginning of buffer + self.position = offset + elif whence == 1: + # seek from current position + self.position += offset + elif whence == 2: + # seek from end of buffer + self.position = len(self.buffer) + offset + else: + raise ValueError("Invalid whence value") + + # adjust position to be within buffer bounds + self.position = max(0, min(self.position, len(self.buffer))) + + def tell(self): + return self.position + + def readable(self): + return True + + def seekable(self): + return True + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.stream.__exit__(exc_type, exc_value, traceback) + + def decompress_if_zstd(stream): """ Given a byte stream, returns either the original stream or the decompressed stream @@ -165,7 +223,7 @@ def decompress_if_zstd(stream): logging.debug(f"Decompressing {stream=}") import zstandard as zstd dctx = zstd.ZstdDecompressor() - return tarfile.open(fileobj=dctx.stream_reader(stream, read_size=4096), mode='r|') + return tarfile.open(fileobj=BytesIO(dctx.stream_reader(stream).read()), mode='r:tar') def open_tar(tar_file: str) -> tarfile.TarFile: @@ -190,13 +248,21 @@ def read_files_from_tar_recursive(tar_file: str, paths: Sequence[str], append_sl paths = [f"{p.strip('/')}/" for p in paths] with open_tar(tar_file) as index: for member in index.getmembers(): - for path in paths: - if member.isfile() and member.path.startswith(path): - fd = index.extractfile(member) - assert fd - yield member.path, fd - break - continue + file_path = member.path + if member.isfile() and check_file_matches(file_path, paths): + logging.debug(f"tar: Returning {file_path}") + fd = index.extractfile(member) + assert fd + yield file_path, fd + else: + logging.debug(f'tar: unmatched {file_path} for query {paths}') + + +def check_file_matches(file_path: str, queries: list[str]) -> bool: + for query in queries: + if file_path.startswith(query): + return True + return False def extract_files_from_tar_generator( @@ -205,7 +271,6 @@ def extract_files_from_tar_generator( remove_prefix: str = '', append_slash: bool = True, ): - assert os.path.exists(output_dir) remove_prefix = remove_prefix.strip('/') if append_slash and remove_prefix: remove_prefix += '/' @@ -214,6 +279,7 @@ def extract_files_from_tar_generator( output_path = os.path.join(output_dir, file_path[len(remove_prefix):].lstrip('/')) os.makedirs(os.path.dirname(output_path), exist_ok=True) with open(output_path, 'wb') as f: + logging.debug(f"Extracting {file_path}") f.write(fd.read())