utils: handle zstd compressed tarfiles

This commit is contained in:
InsanePrawn 2023-04-17 21:57:30 +02:00
parent 30c3fa77fd
commit c576dc8a51

View file

@ -11,6 +11,7 @@ import subprocess
import tarfile
from dateutil.parser import parse as parsedate
from io import BytesIO
from shutil import which
from typing import Any, Generator, IO, Optional, Union, Sequence
@ -151,6 +152,63 @@ def is_zstd(data):
return True
class BackwardsReadableStream:
    """Wrap a forward-only byte stream so already-read data can be re-read.

    Every byte pulled from the wrapped stream is kept in an in-memory buffer.
    ``seek()`` only moves within that buffer (it is clamped to the range of
    data read so far), and ``read()`` serves bytes from the buffer first,
    fetching from the wrapped stream only for the part that is missing.
    """

    def __init__(self, stream):
        """Store the wrapped ``stream``; it must provide ``read([size])``."""
        self.stream = stream
        self.buffer = bytearray()
        self.position = 0

    def read(self, size=-1):
        """Return up to ``size`` bytes starting at the current position.

        A negative ``size`` (or ``None``) drains the wrapped stream and
        returns everything from the current position to the end.
        Fewer than ``size`` bytes are returned only when the wrapped stream
        is exhausted. The result is an immutable ``bytes`` object.
        """
        if size is None or size < 0:
            # Pull in whatever is left so the buffer holds the full stream.
            self.buffer.extend(self.stream.read())
            end = len(self.buffer)
        else:
            end = self.position + size
            missing = end - len(self.buffer)
            # Only fetch the part of the request the buffer cannot satisfy;
            # loop because a single read() may legally return fewer bytes.
            while missing > 0:
                chunk = self.stream.read(missing)
                if not chunk:
                    break
                self.buffer.extend(chunk)
                missing -= len(chunk)
            # Stream may have hit EOF before the request was satisfied.
            end = min(end, len(self.buffer))

        start = self.position
        self.position = end
        return bytes(self.buffer[start:end])

    def seek(self, offset, whence=0):
        """Move the read position within the buffered data.

        ``whence`` follows the usual convention: 0 = from the start,
        1 = from the current position, 2 = from the end *of the buffer*
        (not the end of the wrapped stream). The resulting position is
        clamped to ``[0, len(buffer)]`` and returned.

        Raises:
            ValueError: if ``whence`` is not 0, 1 or 2.
        """
        if whence == 0:
            # seek from beginning of buffer
            target = offset
        elif whence == 1:
            # seek from current position
            target = self.position + offset
        elif whence == 2:
            # seek from end of buffer
            target = len(self.buffer) + offset
        else:
            raise ValueError("Invalid whence value")
        # adjust position to be within buffer bounds
        self.position = max(0, min(target, len(self.buffer)))
        return self.position

    def tell(self):
        """Return the current read position."""
        return self.position

    def readable(self):
        """Report that the wrapper supports ``read()``."""
        return True

    def seekable(self):
        """Report that the wrapper supports ``seek()`` (within the buffer)."""
        return True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Delegate cleanup to the wrapped stream's own context manager.
        self.stream.__exit__(exc_type, exc_value, traceback)
def decompress_if_zstd(stream):
"""
Given a byte stream, returns either the original stream or the decompressed stream
@ -165,7 +223,7 @@ def decompress_if_zstd(stream):
logging.debug(f"Decompressing {stream=}")
import zstandard as zstd
dctx = zstd.ZstdDecompressor()
return tarfile.open(fileobj=dctx.stream_reader(stream, read_size=4096), mode='r|')
return tarfile.open(fileobj=BytesIO(dctx.stream_reader(stream).read()), mode='r:tar')
def open_tar(tar_file: str) -> tarfile.TarFile:
@ -190,13 +248,21 @@ def read_files_from_tar_recursive(tar_file: str, paths: Sequence[str], append_sl
paths = [f"{p.strip('/')}/" for p in paths]
with open_tar(tar_file) as index:
for member in index.getmembers():
for path in paths:
if member.isfile() and member.path.startswith(path):
fd = index.extractfile(member)
assert fd
yield member.path, fd
break
continue
file_path = member.path
if member.isfile() and check_file_matches(file_path, paths):
logging.debug(f"tar: Returning {file_path}")
fd = index.extractfile(member)
assert fd
yield file_path, fd
else:
logging.debug(f'tar: unmatched {file_path} for query {paths}')
def check_file_matches(file_path: str, queries: list[str]) -> bool:
    """Return True when ``file_path`` begins with at least one of ``queries``."""
    return any(file_path.startswith(prefix) for prefix in queries)
def extract_files_from_tar_generator(
@ -205,7 +271,6 @@ def extract_files_from_tar_generator(
remove_prefix: str = '',
append_slash: bool = True,
):
assert os.path.exists(output_dir)
remove_prefix = remove_prefix.strip('/')
if append_slash and remove_prefix:
remove_prefix += '/'
@ -214,6 +279,7 @@ def extract_files_from_tar_generator(
output_path = os.path.join(output_dir, file_path[len(remove_prefix):].lstrip('/'))
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'wb') as f:
logging.debug(f"Extracting {file_path}")
f.write(fd.read())