TEMP: override DataClass.ToDict()
This commit is contained in:
parent
53ef22d6b8
commit
72f4d4948e
1 changed files with 114 additions and 6 deletions
120
dataclass.py
120
dataclass.py
|
@ -2,23 +2,27 @@ from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from munch import Munch
|
from munch import Munch
|
||||||
from typing import ClassVar, Optional, Union, Mapping, Any, get_type_hints, get_origin, get_args, Iterable
|
from typing import ClassVar, Optional, Union, Mapping, Any, get_type_hints, get_origin, get_args, GenericAlias, Iterable
|
||||||
from types import UnionType
|
from types import UnionType
|
||||||
|
|
||||||
|
NoneType = type(None)
|
||||||
|
|
||||||
|
|
||||||
def munchclass(*args, init=False, **kwargs):
|
def munchclass(*args, init=False, **kwargs):
|
||||||
return dataclass(*args, init=init, slots=True, **kwargs)
|
return dataclass(*args, init=init, slots=True, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def resolve_type_hint(hint: type) -> Iterable[type]:
|
def resolve_type_hint(hint: type, ignore_origins: list[type] = []) -> Iterable[type]:
|
||||||
origin = get_origin(hint)
|
origin = get_origin(hint)
|
||||||
args: Iterable[type] = get_args(hint)
|
args: Iterable[type] = get_args(hint)
|
||||||
|
if origin in ignore_origins:
|
||||||
|
return [hint]
|
||||||
if origin is Optional:
|
if origin is Optional:
|
||||||
args = set(list(args) + [type(None)])
|
args = set(list(args) + [NoneType])
|
||||||
if origin in [Union, UnionType, Optional]:
|
if origin in [Union, UnionType, Optional]:
|
||||||
results: list[type] = []
|
results: list[type] = []
|
||||||
for arg in args:
|
for arg in args:
|
||||||
results += resolve_type_hint(arg)
|
results += resolve_type_hint(arg, ignore_origins=ignore_origins)
|
||||||
return results
|
return results
|
||||||
return [origin or hint]
|
return [origin or hint]
|
||||||
|
|
||||||
|
@ -26,6 +30,8 @@ def resolve_type_hint(hint: type) -> Iterable[type]:
|
||||||
class DataClass(Munch):
|
class DataClass(Munch):
|
||||||
|
|
||||||
_type_hints: ClassVar[dict[str, Any]]
|
_type_hints: ClassVar[dict[str, Any]]
|
||||||
|
_strip_hidden: ClassVar[bool] = False
|
||||||
|
_sparse: ClassVar[bool] = False
|
||||||
|
|
||||||
def __init__(self, d: dict = {}, validate: bool = True, **kwargs):
|
def __init__(self, d: dict = {}, validate: bool = True, **kwargs):
|
||||||
self.update(d | kwargs, validate=validate)
|
self.update(d | kwargs, validate=validate)
|
||||||
|
@ -39,7 +45,7 @@ class DataClass(Munch):
|
||||||
type_hints = cls._type_hints
|
type_hints = cls._type_hints
|
||||||
if key in type_hints:
|
if key in type_hints:
|
||||||
_classes = tuple[type](resolve_type_hint(type_hints[key]))
|
_classes = tuple[type](resolve_type_hint(type_hints[key]))
|
||||||
optional = type(None) in _classes
|
optional = NoneType in _classes
|
||||||
if issubclass(_classes[0], dict):
|
if issubclass(_classes[0], dict):
|
||||||
assert isinstance(value, dict) or optional
|
assert isinstance(value, dict) or optional
|
||||||
target_class = _classes[0]
|
target_class = _classes[0]
|
||||||
|
@ -89,7 +95,20 @@ class DataClass(Munch):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fromDict(cls, values: Mapping[str, Any], validate: bool = True):
|
def fromDict(cls, values: Mapping[str, Any], validate: bool = True):
|
||||||
return cls(**cls.transform(values, validate))
|
return cls(d=values, validate=validate)
|
||||||
|
|
||||||
|
def toDict(
|
||||||
|
self,
|
||||||
|
strip_hidden: Optional[bool] = None,
|
||||||
|
sparse: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
return strip_dict(
|
||||||
|
self,
|
||||||
|
hints=self._type_hints,
|
||||||
|
strip_hidden=self._strip_hidden if strip_hidden is None else strip_hidden,
|
||||||
|
sparse=self._sparse if sparse is None else sparse,
|
||||||
|
recursive=True,
|
||||||
|
)
|
||||||
|
|
||||||
def update(self, d: Mapping[str, Any], validate: bool = True):
|
def update(self, d: Mapping[str, Any], validate: bool = True):
|
||||||
Munch.update(self, type(self).transform(d, validate))
|
Munch.update(self, type(self).transform(d, validate))
|
||||||
|
@ -100,3 +119,92 @@ class DataClass(Munch):
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{type(self)}{dict.__repr__(self.toDict())}'
|
return f'{type(self)}{dict.__repr__(self.toDict())}'
|
||||||
|
|
||||||
|
def toYaml(self, strip_hidden: bool = False, sparse: bool = False, **yaml_args) -> str:
|
||||||
|
import yaml
|
||||||
|
return yaml.dump(
|
||||||
|
self.toDict(strip_hidden=strip_hidden, sparse=sparse),
|
||||||
|
**yaml_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
def toToml(self, strip_hidden: bool = False, sparse: bool = False, **toml_args) -> str:
|
||||||
|
import toml
|
||||||
|
return toml.dumps(
|
||||||
|
self.toDict(strip_hidden=strip_hidden, sparse=sparse),
|
||||||
|
**toml_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_hints(hints: Any) -> list[Any]:
|
||||||
|
if not isinstance(hints, (list, tuple)):
|
||||||
|
yield hints
|
||||||
|
return
|
||||||
|
for i in hints:
|
||||||
|
yield from flatten_hints(i)
|
||||||
|
|
||||||
|
|
||||||
|
def strip_dict(
|
||||||
|
d: dict[Any, Any],
|
||||||
|
hints: dict[str, Any],
|
||||||
|
strip_hidden: bool = False,
|
||||||
|
sparse: bool = False,
|
||||||
|
recursive: bool = True,
|
||||||
|
) -> dict[Any, Any]:
|
||||||
|
result = dict(d)
|
||||||
|
if not (strip_hidden or sparse or result):
|
||||||
|
print(f"shortcircuiting {d=}")
|
||||||
|
return result
|
||||||
|
print(f"Stripping {result} with hints: {hints}")
|
||||||
|
for k, v in d.items():
|
||||||
|
if not isinstance(k, str):
|
||||||
|
print(f"skipping unknown key type {k=}")
|
||||||
|
continue
|
||||||
|
if strip_hidden and k.startswith('_'):
|
||||||
|
result.pop(k)
|
||||||
|
continue
|
||||||
|
if sparse and (v is None and NoneType in resolve_type_hint(hints.get(k, "abc"))):
|
||||||
|
print(f"popping empty {k}")
|
||||||
|
result.pop(k)
|
||||||
|
continue
|
||||||
|
if recursive and isinstance(v, dict):
|
||||||
|
if not v:
|
||||||
|
result[k] = {}
|
||||||
|
continue
|
||||||
|
if isinstance(v, DataClass):
|
||||||
|
print(f"Dataclass detected in {k=}")
|
||||||
|
result[k] = v.toDict(strip_hidden=strip_hidden, sparse=sparse)
|
||||||
|
continue
|
||||||
|
if isinstance(v, Munch):
|
||||||
|
print(f"Converting munch {k=}")
|
||||||
|
result[k] = v.toDict()
|
||||||
|
if k not in hints:
|
||||||
|
print(f"skipping unknown {k=}")
|
||||||
|
continue
|
||||||
|
print(f"STRIPPING RECURSIVELY: {k}: {v}, parent hints: {hints[k]}")
|
||||||
|
_subhints = {}
|
||||||
|
_hints = resolve_type_hint(hints[k], [dict])
|
||||||
|
hints_flat = list(flatten_hints(_hints))
|
||||||
|
print(f"going over hints for {k}: {_hints=} {hints_flat=}")
|
||||||
|
|
||||||
|
for hint in hints_flat:
|
||||||
|
print(f"working on hint: {hint}")
|
||||||
|
if get_origin(hint) == dict:
|
||||||
|
_valtype = get_args(hint)[1]
|
||||||
|
_subhints = {n: _valtype for n in v.keys()}
|
||||||
|
print(f"generated {_subhints=} from {_valtype=}")
|
||||||
|
break
|
||||||
|
if isinstance(hint, type) and issubclass(hint, DataClass):
|
||||||
|
_subhints = hint._type_hints
|
||||||
|
print(f"found subhints: {_subhints}")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print(f"ignoring {hint=}")
|
||||||
|
print(f"STRIPPING SUBDICT {k=} WITH {_subhints=}")
|
||||||
|
result[k] = strip_dict(
|
||||||
|
v,
|
||||||
|
hints=_subhints,
|
||||||
|
sparse=sparse,
|
||||||
|
strip_hidden=strip_hidden,
|
||||||
|
recursive=recursive,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue