utils: add decompress_if_zstd
This commit is contained in:
parent
38edce080f
commit
a982f8c966
1 changed files with 45 additions and 2 deletions
47
utils.py
47
utils.py
|
@ -129,9 +129,52 @@ def get_gid(group: Union[int, str]) -> int:
|
|||
return grp.getgrnam(group).gr_gid
|
||||
|
||||
|
||||
def is_zstd(data):
|
||||
"""
|
||||
Returns True if the given byte stream is compressed with the zstd algorithm,
|
||||
False otherwise. This function performs a simplified version of the actual zstd
|
||||
header validation, using hardcoded values.
|
||||
"""
|
||||
# Check for the magic number at the beginning of the stream
|
||||
if len(data) < 4 or data[:4] != b"\x28\xb5\x2f\xfd":
|
||||
logging.debug("zstd header not found")
|
||||
return False
|
||||
# Check the frame descriptor block size
|
||||
if len(data) < 8:
|
||||
return False
|
||||
frame_size = data[4] & 0x7F | (data[5] & 0x7F) << 7 | (data[6] & 0x7F) << 14 | (data[7] & 0x07) << 21
|
||||
if frame_size < 1 or frame_size > 1 << 31:
|
||||
return False
|
||||
# Check the frame descriptor block for the checksum
|
||||
if len(data) < 18:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def decompress_if_zstd(stream):
|
||||
"""
|
||||
Given a byte stream, returns either the original stream or the decompressed stream
|
||||
if it is compressed with the zstd algorithm.
|
||||
"""
|
||||
if isinstance(stream, str):
|
||||
stream = open(stream, 'rb')
|
||||
data = stream.peek(18)[:18]
|
||||
if not is_zstd(data):
|
||||
logging.debug(f"{data=} Not zstd, skipping")
|
||||
return tarfile.open(fileobj=stream)
|
||||
logging.debug(f"Decompressing {stream=}")
|
||||
import zstandard as zstd
|
||||
dctx = zstd.ZstdDecompressor()
|
||||
return tarfile.open(fileobj=dctx.stream_reader(stream, read_size=4096), mode='r|')
|
||||
|
||||
|
||||
def open_tar(tar_file: str) -> tarfile.TarFile:
|
||||
return decompress_if_zstd(tar_file)
|
||||
|
||||
|
||||
def read_files_from_tar(tar_file: str, files: Sequence[str]) -> Generator[tuple[str, IO], None, None]:
|
||||
assert os.path.exists(tar_file)
|
||||
with tarfile.open(tar_file) as index:
|
||||
with open_tar(tar_file) as index:
|
||||
for path in files:
|
||||
fd = index.extractfile(index.getmember(path))
|
||||
assert fd
|
||||
|
@ -145,7 +188,7 @@ def read_files_from_tar_recursive(tar_file: str, paths: Sequence[str], append_sl
|
|||
"""
|
||||
assert os.path.exists(tar_file)
|
||||
paths = [f"{p.strip('/')}/" for p in paths]
|
||||
with tarfile.open(tar_file) as index:
|
||||
with open_tar(tar_file) as index:
|
||||
for member in index.getmembers():
|
||||
for path in paths:
|
||||
if member.isfile() and member.path.startswith(path):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue