diff --git a/utils.py b/utils.py index 02048ac..0cfdae8 100644 --- a/utils.py +++ b/utils.py @@ -129,9 +129,52 @@ def get_gid(group: Union[int, str]) -> int: return grp.getgrnam(group).gr_gid +def is_zstd(data): + """ + Returns True if the given byte stream is compressed with the zstd algorithm, + False otherwise. This function performs a simplified version of the actual zstd + header validation, using hardcoded values. + """ + # Check for the magic number at the beginning of the stream + if len(data) < 4 or data[:4] != b"\x28\xb5\x2f\xfd": + logging.debug("zstd header not found") + return False + # Check the frame descriptor block size + if len(data) < 8: + return False + frame_size = data[4] & 0x7F | (data[5] & 0x7F) << 7 | (data[6] & 0x7F) << 14 | (data[7] & 0x07) << 21 + if frame_size < 1 or frame_size > 1 << 31: + return False + # Check the frame descriptor block for the checksum + if len(data) < 18: + return False + return True + + +def decompress_if_zstd(stream): + """ + Given a byte stream, returns either the original stream or the decompressed stream + if it is compressed with the zstd algorithm. + """ + if isinstance(stream, str): + stream = open(stream, 'rb') + data = stream.peek(18)[:18] + if not is_zstd(data): + logging.debug(f"{data=} Not zstd, skipping") + return tarfile.open(fileobj=stream) + logging.debug(f"Decompressing {stream=}") + import zstandard as zstd + dctx = zstd.ZstdDecompressor() + return tarfile.open(fileobj=dctx.stream_reader(stream, read_size=4096), mode='r|') + + +def open_tar(tar_file: str) -> tarfile.TarFile: + return decompress_if_zstd(tar_file) + + def read_files_from_tar(tar_file: str, files: Sequence[str]) -> Generator[tuple[str, IO], None, None]: assert os.path.exists(tar_file) - with tarfile.open(tar_file) as index: + with open_tar(tar_file) as index: for path in files: fd = index.extractfile(index.getmember(path)) assert fd @@ -145,7 +188,7 @@ def read_files_from_tar_recursive(tar_file: str, paths: Sequence[str], append_sl """ assert os.path.exists(tar_file) paths = [f"{p.strip('/')}/" for p in paths] - with tarfile.open(tar_file) as index: + with open_tar(tar_file) as index: for member in index.getmembers(): for path in paths: if member.isfile() and member.path.startswith(path):