diff --git a/dataclass.py b/dataclass.py index facaacd..4dd9948 100644 --- a/dataclass.py +++ b/dataclass.py @@ -2,23 +2,27 @@ from __future__ import annotations from dataclasses import dataclass from munch import Munch -from typing import ClassVar, Optional, Union, Mapping, Any, get_type_hints, get_origin, get_args, Iterable +from typing import ClassVar, Optional, Union, Mapping, Any, get_type_hints, get_origin, get_args, GenericAlias, Iterable from types import UnionType +NoneType = type(None) + def munchclass(*args, init=False, **kwargs): return dataclass(*args, init=init, slots=True, **kwargs) -def resolve_type_hint(hint: type) -> Iterable[type]: +def resolve_type_hint(hint: type, ignore_origins: list[type] = []) -> Iterable[type]: origin = get_origin(hint) args: Iterable[type] = get_args(hint) + if origin in ignore_origins: + return [hint] if origin is Optional: - args = set(list(args) + [type(None)]) + args = set(list(args) + [NoneType]) if origin in [Union, UnionType, Optional]: results: list[type] = [] for arg in args: - results += resolve_type_hint(arg) + results += resolve_type_hint(arg, ignore_origins=ignore_origins) return results return [origin or hint] @@ -26,6 +30,8 @@ def resolve_type_hint(hint: type) -> Iterable[type]: class DataClass(Munch): _type_hints: ClassVar[dict[str, Any]] + _strip_hidden: ClassVar[bool] = False + _sparse: ClassVar[bool] = False def __init__(self, d: dict = {}, validate: bool = True, **kwargs): self.update(d | kwargs, validate=validate) @@ -39,7 +45,7 @@ class DataClass(Munch): type_hints = cls._type_hints if key in type_hints: _classes = tuple[type](resolve_type_hint(type_hints[key])) - optional = type(None) in _classes + optional = NoneType in _classes if issubclass(_classes[0], dict): assert isinstance(value, dict) or optional target_class = _classes[0] @@ -89,7 +95,20 @@ class DataClass(Munch): @classmethod def fromDict(cls, values: Mapping[str, Any], validate: bool = True): - return cls(**cls.transform(values, validate)) + return cls(d=values, validate=validate) + + def toDict( + self, + strip_hidden: Optional[bool] = None, + sparse: Optional[bool] = None, + ): + return strip_dict( + self, + hints=self._type_hints, + strip_hidden=self._strip_hidden if strip_hidden is None else strip_hidden, + sparse=self._sparse if sparse is None else sparse, + recursive=True, + ) def update(self, d: Mapping[str, Any], validate: bool = True): Munch.update(self, type(self).transform(d, validate)) @@ -100,3 +119,92 @@ class DataClass(Munch): def __repr__(self): return f'{type(self)}{dict.__repr__(self.toDict())}' + + def toYaml(self, strip_hidden: bool = False, sparse: bool = False, **yaml_args) -> str: + import yaml + return yaml.dump( + self.toDict(strip_hidden=strip_hidden, sparse=sparse), + **yaml_args, + ) + + def toToml(self, strip_hidden: bool = False, sparse: bool = False, **toml_args) -> str: + import toml + return toml.dumps( + self.toDict(strip_hidden=strip_hidden, sparse=sparse), + **toml_args, + ) + + +def flatten_hints(hints: Any) -> list[Any]: + if not isinstance(hints, (list, tuple)): + yield hints + return + for i in hints: + yield from flatten_hints(i) + + +def strip_dict( + d: dict[Any, Any], + hints: dict[str, Any], + strip_hidden: bool = False, + sparse: bool = False, + recursive: bool = True, +) -> dict[Any, Any]: + result = dict(d) + if not (strip_hidden or sparse or result): + print(f"shortcircuiting {d=}") + return result + print(f"Stripping {result} with hints: {hints}") + for k, v in d.items(): + if not isinstance(k, str): + print(f"skipping unknown key type {k=}") + continue + if strip_hidden and k.startswith('_'): + result.pop(k) + continue + if sparse and (v is None and NoneType in resolve_type_hint(hints.get(k, "abc"))): + print(f"popping empty {k}") + result.pop(k) + continue + if recursive and isinstance(v, dict): + if not v: + result[k] = {} + continue + if isinstance(v, DataClass): + print(f"Dataclass detected in {k=}") + result[k] = v.toDict(strip_hidden=strip_hidden, sparse=sparse) + continue + if isinstance(v, Munch): + print(f"Converting munch {k=}") + result[k] = v.toDict() + if k not in hints: + print(f"skipping unknown {k=}") + continue + print(f"STRIPPING RECURSIVELY: {k}: {v}, parent hints: {hints[k]}") + _subhints = {} + _hints = resolve_type_hint(hints[k], [dict]) + hints_flat = list(flatten_hints(_hints)) + print(f"going over hints for {k}: {_hints=} {hints_flat=}") + + for hint in hints_flat: + print(f"working on hint: {hint}") + if get_origin(hint) == dict: + _valtype = get_args(hint)[1] + _subhints = {n: _valtype for n in v.keys()} + print(f"generated {_subhints=} from {_valtype=}") + break + if isinstance(hint, type) and issubclass(hint, DataClass): + _subhints = hint._type_hints + print(f"found subhints: {_subhints}") + break + else: + print(f"ignoring {hint=}") + print(f"STRIPPING SUBDICT {k=} WITH {_subhints=}") + result[k] = strip_dict( + v, + hints=_subhints, + sparse=sparse, + strip_hidden=strip_hidden, + recursive=recursive, + ) + return result