utils: add decompress_if_zstd

This commit is contained in:
InsanePrawn 2023-04-17 18:56:11 +02:00
parent 38edce080f
commit a982f8c966

View file

@ -129,9 +129,52 @@ def get_gid(group: Union[int, str]) -> int:
return grp.getgrnam(group).gr_gid
def is_zstd(data):
"""
Returns True if the given byte stream is compressed with the zstd algorithm,
False otherwise. This function performs a simplified version of the actual zstd
header validation, using hardcoded values.
"""
# Check for the magic number at the beginning of the stream
if len(data) < 4 or data[:4] != b"\x28\xb5\x2f\xfd":
logging.debug("zstd header not found")
return False
# Check the frame descriptor block size
if len(data) < 8:
return False
frame_size = data[4] & 0x7F | (data[5] & 0x7F) << 7 | (data[6] & 0x7F) << 14 | (data[7] & 0x07) << 21
if frame_size < 1 or frame_size > 1 << 31:
return False
# Check the frame descriptor block for the checksum
if len(data) < 18:
return False
return True
def decompress_if_zstd(stream):
"""
Given a byte stream, returns either the original stream or the decompressed stream
if it is compressed with the zstd algorithm.
"""
if isinstance(stream, str):
stream = open(stream, 'rb')
data = stream.peek(18)[:18]
if not is_zstd(data):
logging.debug(f"{data=} Not zstd, skipping")
return tarfile.open(fileobj=stream)
logging.debug(f"Decompressing {stream=}")
import zstandard as zstd
dctx = zstd.ZstdDecompressor()
return tarfile.open(fileobj=dctx.stream_reader(stream, read_size=4096), mode='r|')
def open_tar(tar_file: str) -> tarfile.TarFile:
return decompress_if_zstd(tar_file)
def read_files_from_tar(tar_file: str, files: Sequence[str]) -> Generator[tuple[str, IO], None, None]:
assert os.path.exists(tar_file)
with tarfile.open(tar_file) as index:
with open_tar(tar_file) as index:
for path in files:
fd = index.extractfile(index.getmember(path))
assert fd
@ -145,7 +188,7 @@ def read_files_from_tar_recursive(tar_file: str, paths: Sequence[str], append_sl
"""
assert os.path.exists(tar_file)
paths = [f"{p.strip('/')}/" for p in paths]
with tarfile.open(tar_file) as index:
with open_tar(tar_file) as index:
for member in index.getmembers():
for path in paths:
if member.isfile() and member.path.startswith(path):