utils: handle zstd compressed tarfiles

This commit is contained in:
InsanePrawn 2023-04-17 21:57:30 +02:00
parent 30c3fa77fd
commit c576dc8a51

View file

@ -11,6 +11,7 @@ import subprocess
import tarfile
from dateutil.parser import parse as parsedate
from io import BytesIO
from shutil import which
from typing import Any, Generator, IO, Optional, Union, Sequence
@ -151,6 +152,63 @@ def is_zstd(data):
    return True
class BackwardsReadableStream:
    """Wrap a forward-only byte stream so already-read data can be re-read.

    Every byte pulled from the underlying stream is retained in an internal
    buffer, which makes the wrapper seekable within the data read so far.
    Seek positions are clamped to the buffered range.
    """

    def __init__(self, stream):
        self.stream = stream        # underlying forward-only byte stream
        self.buffer = bytearray()   # every byte fetched from `stream` so far
        self.position = 0           # current logical read position

    def read(self, size=-1):
        """Read up to ``size`` bytes (all remaining data if ``size == -1``).

        Bytes are served from the internal buffer first; only data beyond
        what is already buffered is fetched from the underlying stream.
        Returns ``bytes`` (may be shorter than ``size`` at end of stream).
        """
        if size == -1:
            # Drain whatever is left in the underlying stream into the buffer.
            fetched = self.stream.read()
        else:
            # Fetch only the bytes not yet buffered; note: do NOT re-append
            # buffered data, that would duplicate bytes in the buffer.
            missing = (self.position + size) - len(self.buffer)
            fetched = self.stream.read(missing) if missing > 0 else b''
        if fetched:
            self.buffer.extend(fetched)
        start = self.position
        if size == -1:
            end = len(self.buffer)
        else:
            # Slice is [start:start+size); clamp to the data actually available.
            end = min(self.position + size, len(self.buffer))
        self.position = end
        return bytes(self.buffer[start:end])

    def seek(self, offset, whence=0):
        """Reposition within the buffered data (0=start, 1=current, 2=end)."""
        if whence == 0:
            # seek from beginning of buffer
            self.position = offset
        elif whence == 1:
            # seek from current position
            self.position += offset
        elif whence == 2:
            # seek from end of buffered (i.e. already-read) data
            self.position = len(self.buffer) + offset
        else:
            raise ValueError("Invalid whence value")
        # adjust position to be within buffer bounds
        self.position = max(0, min(self.position, len(self.buffer)))

    def tell(self):
        """Return the current read position."""
        return self.position

    def readable(self):
        return True

    def seekable(self):
        return True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Delegate cleanup to the wrapped stream's context-manager exit.
        # NOTE(review): assumes the wrapped stream is a context manager.
        self.stream.__exit__(exc_type, exc_value, traceback)
def decompress_if_zstd(stream):
    """
    Given a byte stream, returns either the original stream or the decompressed stream
@ -165,7 +223,7 @@ def decompress_if_zstd(stream):
    logging.debug(f"Decompressing {stream=}")
    import zstandard as zstd
    dctx = zstd.ZstdDecompressor()
    return tarfile.open(fileobj=BytesIO(dctx.stream_reader(stream).read()), mode='r:tar')
def open_tar(tar_file: str) -> tarfile.TarFile:
@ -190,13 +248,21 @@ def read_files_from_tar_recursive(tar_file: str, paths: Sequence[str], append_sl
    paths = [f"{p.strip('/')}/" for p in paths]
    with open_tar(tar_file) as index:
        for member in index.getmembers():
            file_path = member.path
            if member.isfile() and check_file_matches(file_path, paths):
                logging.debug(f"tar: Returning {file_path}")
                fd = index.extractfile(member)
                assert fd
                yield file_path, fd
            else:
                logging.debug(f'tar: unmatched {file_path} for query {paths}')
def check_file_matches(file_path: str, queries: list[str]) -> bool:
    """Return True if ``file_path`` starts with any prefix in ``queries``."""
    return any(file_path.startswith(prefix) for prefix in queries)
def extract_files_from_tar_generator( def extract_files_from_tar_generator(
@ -205,7 +271,6 @@ def extract_files_from_tar_generator(
    remove_prefix: str = '',
    append_slash: bool = True,
):
assert os.path.exists(output_dir)
    remove_prefix = remove_prefix.strip('/')
    if append_slash and remove_prefix:
        remove_prefix += '/'
@ -214,6 +279,7 @@ def extract_files_from_tar_generator(
        output_path = os.path.join(output_dir, file_path[len(remove_prefix):].lstrip('/'))
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with open(output_path, 'wb') as f:
            logging.debug(f"Extracting {file_path}")
            f.write(fd.read())