Skip to content

Fix/freeze dict keys #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

73 changes: 23 additions & 50 deletions pycardano/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ class RawCBOR:
CBORTag,
set,
frozenset,
frozendict,
FrozenList,
IndefiniteFrozenList
]

PRIMITIVE_TYPES = (
Expand All @@ -117,6 +120,9 @@ class RawCBOR:
CBORTag,
set,
frozenset,
frozendict,
FrozenList,
IndefiniteFrozenList
)
"""
A list of types that could be encoded by
Expand Down Expand Up @@ -155,7 +161,7 @@ def default_encoder(
encoder: CBOREncoder, value: Union[CBORSerializable, IndefiniteList]
):
"""A fallback function that encodes CBORSerializable to CBOR"""
assert isinstance(value, (CBORSerializable, IndefiniteList, RawCBOR, FrozenList)), (
assert isinstance(value, (CBORSerializable, IndefiniteList, RawCBOR, FrozenList, frozendict)), (
f"Type of input value is not CBORSerializable, " f"got {type(value)} instead."
)
if isinstance(value, IndefiniteList):
Expand All @@ -170,6 +176,8 @@ def default_encoder(
encoder.write(value.cbor)
elif isinstance(value, FrozenList):
encoder.encode(list(value))
elif isinstance(value, frozendict):
encoder.encode(dict(value))
else:
encoder.encode(value.to_validated_primitive())

Expand Down Expand Up @@ -222,66 +230,31 @@ def to_primitive(self) -> Primitive:
CBOR primitive types.
"""
result = self.to_shallow_primitive()
container_types = (
dict,
OrderedDict,
defaultdict,
set,
frozenset,
tuple,
list,
CBORTag,
IndefiniteList,
)

def _helper(value):
def _dfs(value):
if isinstance(value, CBORSerializable):
return value.to_primitive()
elif isinstance(value, container_types):
return _dfs(value)
else:
return value

def _freeze(value):
if isinstance(value, (dict, OrderedDict, defaultdict)):
return frozendict({k: _freeze(v) for k, v in value.items()})
elif isinstance(value, frozenset) or isinstance(value, set):
return frozenset(value)
elif isinstance(value, tuple):
return tuple([_freeze(k) for k in value])
elif isinstance(value, IndefiniteList):
fl = IndefiniteFrozenList([_freeze(k) for k in value])
fl.freeze()
return fl
elif isinstance(value, list):
fl = FrozenList([_freeze(k) for k in value])
fl.freeze()
return fl
elif isinstance(value, CBORTag):
return CBORTag(value.tag, _freeze(value.value))
else:
return value

def _dfs(value):
if isinstance(value, (dict, OrderedDict, defaultdict)):
elif isinstance(value, (dict, OrderedDict, defaultdict)):
new_result = type(value)()
if hasattr(value, "default_factory"):
new_result.setdefault(value.default_factory)
for k, v in value.items():
new_result[_freeze(_helper(k))] = _helper(v)
return new_result
new_result[_dfs(k)] = _dfs(v)
return frozendict(new_result)
elif isinstance(value, set):
return {_freeze(_helper(v)) for v in value}
elif isinstance(value, frozenset):
return frozenset({_freeze(_helper(v)) for v in value})
return frozenset(_dfs(v) for v in value)
elif isinstance(value, tuple):
return tuple([_helper(k) for k in value])
return tuple([_dfs(k) for k in value])
elif isinstance(value, list):
return [_helper(k) for k in value]
fl = FrozenList([_dfs(k) for k in value])
fl.freeze()
return fl
elif isinstance(value, IndefiniteList):
return IndefiniteList([_helper(k) for k in value])
fl = IndefiniteFrozenList([_dfs(k) for k in value])
fl.freeze()
return fl
elif isinstance(value, CBORTag):
return CBORTag(value.tag, _helper(value.value))
return CBORTag(value.tag, _dfs(value.value))
else:
return value

Expand All @@ -307,7 +280,7 @@ def _check_recursive(value, type_hint):
return _check_recursive(value, type_hint.__args__[0])
elif origin is Union:
return any(_check_recursive(value, arg) for arg in type_hint.__args__)
elif origin is Dict or isinstance(value, dict):
elif origin is Dict or isinstance(value, (dict, frozendict)):
key_type, value_type = type_hint.__args__
return all(
_check_recursive(k, key_type) and _check_recursive(v, value_type)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ cose = "0.9.dev8"
pprintpp = "^0.4.0"
mnemonic = "^0.20"
ECPy = "^1.2.5"
frozendict = "^2.3.7"
frozendict = "^2.3.4"
frozenlist = "^1.3.3"

[tool.poetry.dev-dependencies]
Expand Down