utils: handle zstd compressed tarfiles
parent 30c3fa77fd
commit c576dc8a51

1 changed file with 75 additions and 9 deletions
utils.py (+75 −9)
@@ -11,6 +11,7 @@ import subprocess
 import tarfile
 
 from dateutil.parser import parse as parsedate
+from io import BytesIO
 from shutil import which
 from typing import Any, Generator, IO, Optional, Union, Sequence
 
@@ -151,6 +152,63 @@ def is_zstd(data):
     return True
 
 
+class BackwardsReadableStream:
+
+    def __init__(self, stream):
+        self.stream = stream
+        self.buffer = bytearray()
+        self.position = 0
+
+    def read(self, size=-1):
+        data = b''
+        if size == -1:
+            # read all remaining data in stream
+            data = self.stream.read()
+        else:
+            not_read = (self.position + size) - len(self.buffer)
+            if not_read > 0:
+                # read up to size bytes from stream
+                data = self.stream.read(not_read)
+            else:
+                data = self.buffer[self.position:self.position+size+1]
+
+        old_position = self.position
+        new_position = self.position + len(data)
+        self.buffer.extend(data)
+        self.position = new_position
+        return self.buffer[old_position:new_position+1]
+
+    def seek(self, offset, whence=0):
+        if whence == 0:
+            # seek from beginning of buffer
+            self.position = offset
+        elif whence == 1:
+            # seek from current position
+            self.position += offset
+        elif whence == 2:
+            # seek from end of buffer
+            self.position = len(self.buffer) + offset
+        else:
+            raise ValueError("Invalid whence value")
+
+        # adjust position to be within buffer bounds
+        self.position = max(0, min(self.position, len(self.buffer)))
+
+    def tell(self):
+        return self.position
+
+    def readable(self):
+        return True
+
+    def seekable(self):
+        return True
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.stream.__exit__(exc_type, exc_value, traceback)
+
+
 def decompress_if_zstd(stream):
     """
     Given a byte stream, returns either the original stream or the decompressed stream
@@ -165,7 +223,7 @@ def decompress_if_zstd(stream):
     logging.debug(f"Decompressing {stream=}")
     import zstandard as zstd
     dctx = zstd.ZstdDecompressor()
-    return tarfile.open(fileobj=dctx.stream_reader(stream, read_size=4096), mode='r|')
+    return tarfile.open(fileobj=BytesIO(dctx.stream_reader(stream).read()), mode='r:tar')
 
 
 def open_tar(tar_file: str) -> tarfile.TarFile:
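For readers skimming the diff, here is a minimal standalone sketch of the pattern the new return line follows: fully decompress the zstd data into memory so tarfile receives a seekable file object instead of a forward-only stream. The helper name, the magic-byte check, and the file handling below are illustrative assumptions, not code from utils.py.

import tarfile
from io import BytesIO

import zstandard as zstd  # third-party "zstandard" package, as used above

ZSTD_MAGIC = b"\x28\xb5\x2f\xfd"  # little-endian zstd frame magic number


def open_possibly_zstd_tar(path: str) -> tarfile.TarFile:
    # Illustrative helper (not part of this commit): open a .tar or .tar.zst.
    stream = open(path, "rb")
    if stream.read(4) == ZSTD_MAGIC:
        stream.seek(0)
        # Decompress everything into memory so tarfile can seek freely,
        # mirroring the BytesIO + mode='r:tar' change in the hunk above.
        dctx = zstd.ZstdDecompressor()
        return tarfile.open(fileobj=BytesIO(dctx.stream_reader(stream).read()), mode="r:tar")
    stream.seek(0)
    return tarfile.open(fileobj=stream, mode="r:")

A possible motivation, not stated in the commit itself: tarfile's streaming 'r|' mode only allows a single forward pass, so extracting members after a full getmembers() scan needs a seekable object; decompressing into a BytesIO trades memory for that random access.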
@@ -190,13 +248,21 @@ def read_files_from_tar_recursive(tar_file: str, paths: Sequence[str], append_sl
     paths = [f"{p.strip('/')}/" for p in paths]
     with open_tar(tar_file) as index:
         for member in index.getmembers():
-            for path in paths:
-                if member.isfile() and member.path.startswith(path):
-                    fd = index.extractfile(member)
-                    assert fd
-                    yield member.path, fd
-                    break
-                continue
+            file_path = member.path
+            if member.isfile() and check_file_matches(file_path, paths):
+                logging.debug(f"tar: Returning {file_path}")
+                fd = index.extractfile(member)
+                assert fd
+                yield file_path, fd
+            else:
+                logging.debug(f'tar: unmatched {file_path} for query {paths}')
+
+
+def check_file_matches(file_path: str, queries: list[str]) -> bool:
+    for query in queries:
+        if file_path.startswith(query):
+            return True
+    return False
 
 
 def extract_files_from_tar_generator(
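A short, hedged usage sketch of the rewritten matching logic: query paths are normalized to end in '/', and a member is yielded when its path starts with one of those prefixes. The import path, archive name, and query strings are made-up examples.

from utils import check_file_matches, read_files_from_tar_recursive  # assuming utils.py is importable as "utils"

# check_file_matches is a plain prefix test over the normalized queries:
assert check_file_matches("etc/fstab", ["etc/", "usr/share/"])
assert not check_file_matches("var/log/last.log", ["etc/", "usr/share/"])

# Hypothetical call: stream all files under etc/ and usr/share/ out of an
# archive (plain or zstd-compressed) without extracting it to disk first.
for file_path, fd in read_files_from_tar_recursive("some-package.tar.zst", ["etc", "usr/share"]):
    print(file_path, len(fd.read()), "bytes")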
@@ -205,7 +271,6 @@ def extract_files_from_tar_generator(
     remove_prefix: str = '',
     append_slash: bool = True,
 ):
-    assert os.path.exists(output_dir)
     remove_prefix = remove_prefix.strip('/')
     if append_slash and remove_prefix:
         remove_prefix += '/'
@@ -214,6 +279,7 @@ def extract_files_from_tar_generator(
         output_path = os.path.join(output_dir, file_path[len(remove_prefix):].lstrip('/'))
         os.makedirs(os.path.dirname(output_path), exist_ok=True)
         with open(output_path, 'wb') as f:
+            logging.debug(f"Extracting {file_path}")
             f.write(fd.read())
 
 
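Finally, a condensed standalone sketch of the extraction loop from the last two hunks, with an end-to-end call. The parameter names of the sketch and the example archive path are assumptions; the leading parameters of the real extract_files_from_tar_generator fall outside the diff context shown above.

import os

from utils import read_files_from_tar_recursive  # assuming utils.py is importable as "utils"


def extract_members(file_generator, output_dir: str, remove_prefix: str = ""):
    # Illustrative restatement of the loop in extract_files_from_tar_generator;
    # parameter names are assumed, not taken from the real signature.
    remove_prefix = remove_prefix.strip("/")
    if remove_prefix:
        remove_prefix += "/"
    for file_path, fd in file_generator:
        output_path = os.path.join(output_dir, file_path[len(remove_prefix):].lstrip("/"))
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with open(output_path, "wb") as f:
            f.write(fd.read())


# Hypothetical end-to-end use: dump everything under etc/ from an archive.
extract_members(read_files_from_tar_recursive("some-package.tar.zst", ["etc"]), "/tmp/etc-dump")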