utils: handle zstd compressed tarfiles
This commit is contained in:
parent
30c3fa77fd
commit
c576dc8a51
1 changed files with 75 additions and 9 deletions
80
utils.py
80
utils.py
|
@ -11,6 +11,7 @@ import subprocess
|
||||||
import tarfile
|
import tarfile
|
||||||
|
|
||||||
from dateutil.parser import parse as parsedate
|
from dateutil.parser import parse as parsedate
|
||||||
|
from io import BytesIO
|
||||||
from shutil import which
|
from shutil import which
|
||||||
from typing import Any, Generator, IO, Optional, Union, Sequence
|
from typing import Any, Generator, IO, Optional, Union, Sequence
|
||||||
|
|
||||||
|
@ -151,6 +152,63 @@ def is_zstd(data):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class BackwardsReadableStream:
    """Wrap a forward-only byte stream so earlier data can be re-read.

    Every byte read from the underlying stream is retained in an internal
    buffer, which makes this wrapper seekable backwards even though the
    source (e.g. a zstd ``stream_reader``) only supports forward reads.
    """

    def __init__(self, stream):
        # Underlying forward-only stream.
        self.stream = stream
        # Everything read from the stream so far.
        self.buffer = bytearray()
        # Current logical read position within the buffer.
        self.position = 0

    def read(self, size=-1):
        """Return up to ``size`` bytes; all remaining data if ``size`` is -1.

        Bytes already buffered (after a backwards seek) are served from the
        buffer; only the shortfall is pulled from the underlying stream.
        """
        if size == -1:
            # Drain the underlying stream entirely.
            self.buffer.extend(self.stream.read())
            end = len(self.buffer)
        else:
            end = self.position + size
            # Fetch only the bytes not already buffered.
            missing = end - len(self.buffer)
            if missing > 0:
                self.buffer.extend(self.stream.read(missing))
            # The stream may have returned fewer bytes than requested (EOF).
            end = min(end, len(self.buffer))
        data = bytes(self.buffer[self.position:end])
        self.position = end
        return data

    def seek(self, offset, whence=0):
        """Move the read position; only buffered data can be revisited.

        Returns the new position, matching the ``io`` seek convention.
        """
        if whence == 0:
            # Absolute position from the start of the buffer.
            self.position = offset
        elif whence == 1:
            # Relative to the current position.
            self.position += offset
        elif whence == 2:
            # Relative to the end of the buffered data.
            self.position = len(self.buffer) + offset
        else:
            raise ValueError("Invalid whence value")
        # Clamp the position to the buffered range.
        self.position = max(0, min(self.position, len(self.buffer)))
        return self.position

    def tell(self):
        """Return the current read position."""
        return self.position

    def readable(self):
        return True

    def seekable(self):
        return True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Delegate cleanup to the wrapped stream's context-manager exit.
        self.stream.__exit__(exc_type, exc_value, traceback)
|
||||||
|
|
||||||
|
|
||||||
def decompress_if_zstd(stream):
|
def decompress_if_zstd(stream):
|
||||||
"""
|
"""
|
||||||
Given a byte stream, returns either the original stream or the decompressed stream
|
Given a byte stream, returns either the original stream or the decompressed stream
|
||||||
|
@ -165,7 +223,7 @@ def decompress_if_zstd(stream):
|
||||||
logging.debug(f"Decompressing {stream=}")
|
logging.debug(f"Decompressing {stream=}")
|
||||||
import zstandard as zstd
|
import zstandard as zstd
|
||||||
dctx = zstd.ZstdDecompressor()
|
dctx = zstd.ZstdDecompressor()
|
||||||
return tarfile.open(fileobj=dctx.stream_reader(stream, read_size=4096), mode='r|')
|
return tarfile.open(fileobj=BytesIO(dctx.stream_reader(stream).read()), mode='r:tar')
|
||||||
|
|
||||||
|
|
||||||
def open_tar(tar_file: str) -> tarfile.TarFile:
|
def open_tar(tar_file: str) -> tarfile.TarFile:
|
||||||
|
@ -190,13 +248,21 @@ def read_files_from_tar_recursive(tar_file: str, paths: Sequence[str], append_sl
|
||||||
paths = [f"{p.strip('/')}/" for p in paths]
|
paths = [f"{p.strip('/')}/" for p in paths]
|
||||||
with open_tar(tar_file) as index:
|
with open_tar(tar_file) as index:
|
||||||
for member in index.getmembers():
|
for member in index.getmembers():
|
||||||
for path in paths:
|
file_path = member.path
|
||||||
if member.isfile() and member.path.startswith(path):
|
if member.isfile() and check_file_matches(file_path, paths):
|
||||||
|
logging.debug(f"tar: Returning {file_path}")
|
||||||
fd = index.extractfile(member)
|
fd = index.extractfile(member)
|
||||||
assert fd
|
assert fd
|
||||||
yield member.path, fd
|
yield file_path, fd
|
||||||
break
|
else:
|
||||||
continue
|
logging.debug(f'tar: unmatched {file_path} for query {paths}')
|
||||||
|
|
||||||
|
|
||||||
|
def check_file_matches(file_path: str, queries: list[str]) -> bool:
    """Return True if ``file_path`` starts with any prefix in ``queries``.

    An empty ``queries`` list matches nothing.
    """
    # any() short-circuits on the first matching prefix, same as the
    # explicit loop-and-return form, but in one idiomatic expression.
    return any(file_path.startswith(query) for query in queries)
|
||||||
|
|
||||||
|
|
||||||
def extract_files_from_tar_generator(
|
def extract_files_from_tar_generator(
|
||||||
|
@ -205,7 +271,6 @@ def extract_files_from_tar_generator(
|
||||||
remove_prefix: str = '',
|
remove_prefix: str = '',
|
||||||
append_slash: bool = True,
|
append_slash: bool = True,
|
||||||
):
|
):
|
||||||
assert os.path.exists(output_dir)
|
|
||||||
remove_prefix = remove_prefix.strip('/')
|
remove_prefix = remove_prefix.strip('/')
|
||||||
if append_slash and remove_prefix:
|
if append_slash and remove_prefix:
|
||||||
remove_prefix += '/'
|
remove_prefix += '/'
|
||||||
|
@ -214,6 +279,7 @@ def extract_files_from_tar_generator(
|
||||||
output_path = os.path.join(output_dir, file_path[len(remove_prefix):].lstrip('/'))
|
output_path = os.path.join(output_dir, file_path[len(remove_prefix):].lstrip('/'))
|
||||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||||
with open(output_path, 'wb') as f:
|
with open(output_path, 'wb') as f:
|
||||||
|
logging.debug(f"Extracting {file_path}")
|
||||||
f.write(fd.read())
|
f.write(fd.read())
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue