Skip to content

Fix static typing #139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pycardano/certificate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Union
from typing import Optional, Union

from pycardano.hash import PoolKeyHash, ScriptHash, VerificationKeyHash
from pycardano.serialization import ArrayCBORSerializable
Expand All @@ -16,7 +16,7 @@
@dataclass(repr=False)
class StakeCredential(ArrayCBORSerializable):

_CODE: int = field(init=False, default=None)
_CODE: Optional[int] = field(init=False, default=None)

credential: Union[VerificationKeyHash, ScriptHash]

Expand Down
11 changes: 10 additions & 1 deletion pycardano/cip/cip8.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,19 @@ def verify(
if attach_cose_key:
# The cose key is attached as a dict object which contains the verification key
# the headers of the signature are emtpy
assert isinstance(
signed_message, dict
), "signed_message must be a dict if attach_cose_key is True"
key = signed_message.get("key")
signed_message = signed_message.get("signature")
signed_message = signed_message.get("signature") # type: ignore

else:
key = "" # key will be extracted later from the payload headers

# Add back the "D2" header byte and decode
assert isinstance(
signed_message, str
), "signed_message must be a hex string at this point"
decoded_message = CoseMessage.decode(bytes.fromhex("d2" + signed_message))

# generate/extract the cose key
Expand All @@ -146,6 +152,9 @@ def verify(

else:
# i,e key is sent separately
assert isinstance(
key, str
), "key must be a hex string if attach_cose_key is True"
cose_key = CoseKey.decode(bytes.fromhex(key))
verification_key = cose_key[OKPKpX]

Expand Down
30 changes: 19 additions & 11 deletions pycardano/coinselection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def select(
utxos: List[UTxO],
outputs: List[TransactionOutput],
context: ChainContext,
max_input_count: int = None,
include_max_fee: bool = True,
respect_min_utxo: bool = True,
max_input_count: Optional[int] = None,
include_max_fee: Optional[bool] = True,
respect_min_utxo: Optional[bool] = True,
) -> Tuple[List[UTxO], Value]:
"""From an input list of UTxOs, select a subset of UTxOs whose sum (including ADA and multi-assets)
is equal to or larger than the sum of a set of outputs.
Expand Down Expand Up @@ -115,7 +115,11 @@ def select(
if change.coin < min_change_amount:
additional, _ = self.select(
available,
[TransactionOutput(None, min_change_amount - change.coin)],
[
TransactionOutput(
_FAKE_ADDR, Value(min_change_amount - change.coin)
)
],
context,
max_input_count - len(selected) if max_input_count else None,
include_max_fee=False,
Expand Down Expand Up @@ -230,13 +234,13 @@ def _improve(
remaining: List[UTxO],
ideal: Value,
upper_bound: Value,
max_input_count: int,
max_input_count: Optional[int] = None,
):
if not remaining or self._find_diff_by_former(ideal, selected_amount) <= 0:
# In case where there is no remaining UTxOs or we already selected more than ideal,
# we cannot improve by randomly adding more UTxOs, therefore return immediate.
return
if max_input_count and len(selected) > max_input_count:
if max_input_count is not None and len(selected) > max_input_count:
raise MaxInputCountExceededException(
f"Max input count: {max_input_count} exceeded!"
)
Expand Down Expand Up @@ -269,9 +273,9 @@ def select(
utxos: List[UTxO],
outputs: List[TransactionOutput],
context: ChainContext,
max_input_count: int = None,
include_max_fee: bool = True,
respect_min_utxo: bool = True,
max_input_count: Optional[int] = None,
include_max_fee: Optional[bool] = True,
respect_min_utxo: Optional[bool] = True,
) -> Tuple[List[UTxO], Value]:
# Shallow copy the list
remaining = list(utxos)
Expand All @@ -284,7 +288,7 @@ def select(
request_sorted = sorted(assets, key=self._get_single_asset_val, reverse=True)

# Phase 1 - random select
selected = []
selected: List[UTxO] = []
selected_amount = Value()
for r in request_sorted:
self._random_select_subset(r, remaining, selected, selected_amount)
Expand Down Expand Up @@ -321,7 +325,11 @@ def select(
if change.coin < min_change_amount:
additional, _ = self.select(
remaining,
[TransactionOutput(None, min_change_amount - change.coin)],
[
TransactionOutput(
_FAKE_ADDR, Value(min_change_amount - change.coin)
)
],
context,
max_input_count - len(selected) if max_input_count else None,
include_max_fee=False,
Expand Down
37 changes: 23 additions & 14 deletions pycardano/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import json
import os
from typing import Type
from typing import Optional, Type

from nacl.encoding import RawEncoder
from nacl.hash import blake2b
Expand Down Expand Up @@ -41,7 +41,12 @@ class Key(CBORSerializable):
KEY_TYPE = ""
DESCRIPTION = ""

def __init__(self, payload: bytes, key_type: str = None, description: str = None):
def __init__(
self,
payload: bytes,
key_type: Optional[str] = None,
description: Optional[str] = None,
):
self._payload = payload
self._key_type = key_type or self.KEY_TYPE
self._description = description or self.KEY_TYPE
Expand Down Expand Up @@ -83,7 +88,7 @@ def to_json(self) -> str:
)

@classmethod
def from_json(cls, data: str, validate_type=False) -> Key:
def from_json(cls: Type[Key], data: str, validate_type=False) -> Key:
"""Restore a key from a JSON string.

Args:
Expand All @@ -105,8 +110,12 @@ def from_json(cls, data: str, validate_type=False) -> Key:
f"Expect key type: {cls.KEY_TYPE}, got {obj['type']} instead."
)

k = cls.from_cbor(obj["cborHex"])

assert isinstance(k, cls)

return cls(
cls.from_cbor(obj["cborHex"]).payload,
k.payload,
key_type=obj["type"],
description=obj["description"],
)
Expand Down Expand Up @@ -244,19 +253,19 @@ class PaymentExtendedVerificationKey(ExtendedVerificationKey):


class PaymentKeyPair:
def __init__(
self, signing_key: PaymentSigningKey, verification_key: PaymentVerificationKey
):
def __init__(self, signing_key: SigningKey, verification_key: VerificationKey):
self.signing_key = signing_key
self.verification_key = verification_key

@classmethod
def generate(cls) -> PaymentKeyPair:
def generate(cls: Type[PaymentKeyPair]) -> PaymentKeyPair:
signing_key = PaymentSigningKey.generate()
return cls.from_signing_key(signing_key)

@classmethod
def from_signing_key(cls, signing_key: PaymentSigningKey) -> PaymentKeyPair:
def from_signing_key(
cls: Type[PaymentKeyPair], signing_key: SigningKey
) -> PaymentKeyPair:
return cls(signing_key, PaymentVerificationKey.from_signing_key(signing_key))

def __eq__(self, other):
Expand Down Expand Up @@ -288,17 +297,17 @@ class StakeExtendedVerificationKey(ExtendedVerificationKey):


class StakeKeyPair:
def __init__(
self, signing_key: StakeSigningKey, verification_key: StakeVerificationKey
):
def __init__(self, signing_key: SigningKey, verification_key: VerificationKey):
self.signing_key = signing_key
self.verification_key = verification_key

@classmethod
def generate(cls) -> StakeKeyPair:
def generate(cls: Type[StakeKeyPair]) -> StakeKeyPair:
signing_key = StakeSigningKey.generate()
return cls.from_signing_key(signing_key)

@classmethod
def from_signing_key(cls, signing_key: StakeSigningKey) -> StakeKeyPair:
def from_signing_key(
cls: Type[StakeKeyPair], signing_key: SigningKey
) -> StakeKeyPair:
return cls(signing_key, StakeVerificationKey.from_signing_key(signing_key))
24 changes: 13 additions & 11 deletions pycardano/metadata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, ClassVar, List, Type, Union
from typing import Any, ClassVar, List, Optional, Type, Union

from cbor2 import CBORTag
from nacl.encoding import RawEncoder
Expand All @@ -20,7 +20,7 @@
list_hook,
)

__all__ = ["Metadata", "ShellayMarryMetadata", "AlonzoMetadata", "AuxiliaryData"]
__all__ = ["Metadata", "ShelleyMarryMetadata", "AlonzoMetadata", "AuxiliaryData"]


class Metadata(DictCBORSerializable):
Expand Down Expand Up @@ -68,9 +68,9 @@ def __init__(self, *args, **kwargs):


@dataclass
class ShellayMarryMetadata(ArrayCBORSerializable):
class ShelleyMarryMetadata(ArrayCBORSerializable):
metadata: Metadata
native_scripts: List[NativeScript] = field(
native_scripts: Optional[List[NativeScript]] = field(
default=None, metadata={"object_hook": list_hook(NativeScript)}
)

Expand All @@ -79,12 +79,14 @@ class ShellayMarryMetadata(ArrayCBORSerializable):
class AlonzoMetadata(MapCBORSerializable):
TAG: ClassVar[int] = 259

metadata: Metadata = field(default=None, metadata={"optional": True, "key": 0})
native_scripts: List[NativeScript] = field(
metadata: Optional[Metadata] = field(
default=None, metadata={"optional": True, "key": 0}
)
native_scripts: Optional[List[NativeScript]] = field(
default=None,
metadata={"optional": True, "key": 1, "object_hook": list_hook(NativeScript)},
)
plutus_scripts: List[bytes] = field(
plutus_scripts: Optional[List[bytes]] = field(
default=None, metadata={"optional": True, "key": 2}
)

Expand All @@ -107,23 +109,23 @@ def from_primitive(cls: Type[AlonzoMetadata], value: CBORTag) -> AlonzoMetadata:

@dataclass
class AuxiliaryData(CBORSerializable):
data: Union[Metadata, ShellayMarryMetadata, AlonzoMetadata]
data: Union[Metadata, ShelleyMarryMetadata, AlonzoMetadata]

def to_primitive(self) -> Primitive:
return self.data.to_primitive()

@classmethod
def from_primitive(cls: Type[AuxiliaryData], value: Primitive) -> AuxiliaryData:
for t in [AlonzoMetadata, ShellayMarryMetadata, Metadata]:
for t in [AlonzoMetadata, ShelleyMarryMetadata, Metadata]:
# The schema of metadata in different eras are mutually exclusive, so we can try deserializing
# them one by one without worrying about mismatch.
try:
return AuxiliaryData(t.from_primitive(value))
return AuxiliaryData(t.from_primitive(value)) # type: ignore
except DeserializeException:
pass
raise DeserializeException(f"Couldn't parse auxiliary data: {value}")

def hash(self) -> AuxiliaryDataHash:
return AuxiliaryDataHash(
blake2b(self.to_cbor("bytes"), AUXILIARY_DATA_HASH_SIZE, encoder=RawEncoder)
blake2b(self.to_cbor("bytes"), AUXILIARY_DATA_HASH_SIZE, encoder=RawEncoder) # type: ignore
)
36 changes: 20 additions & 16 deletions pycardano/plutus.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CBORSerializable,
DictCBORSerializable,
IndefiniteList,
Primitive,
RawCBOR,
default_encoder,
limit_primitive_type,
Expand All @@ -39,6 +40,7 @@
"PlutusV2Script",
"RawPlutusData",
"Redeemer",
"ScriptType",
"datum_hash",
"plutus_script_hash",
"script_hash",
Expand Down Expand Up @@ -471,7 +473,7 @@ def __post_init__(self):
)

def to_shallow_primitive(self) -> CBORTag:
primitives = super().to_shallow_primitive()
primitives: Primitive = super().to_shallow_primitive()
if primitives:
primitives = IndefiniteList(primitives)
tag = get_tag(self.CONSTR_ID)
Expand Down Expand Up @@ -544,7 +546,7 @@ def _dfs(obj):
return json.dumps(_dfs(self), **kwargs)

@classmethod
def from_dict(cls: PlutusData, data: dict) -> PlutusData:
def from_dict(cls: Type[PlutusData], data: dict) -> PlutusData:
"""Convert a dictionary to PlutusData

Args:
Expand Down Expand Up @@ -606,7 +608,7 @@ def _dfs(obj):
return _dfs(data)

@classmethod
def from_json(cls: PlutusData, data: str) -> PlutusData:
def from_json(cls: Type[PlutusData], data: str) -> PlutusData:
"""Restore a json encoded string to a PlutusData.

Args:
Expand Down Expand Up @@ -701,7 +703,7 @@ class Redeemer(ArrayCBORSerializable):

data: Any

ex_units: ExecutionUnits = None
ex_units: Optional[ExecutionUnits] = None

@classmethod
@limit_primitive_type(list)
Expand Down Expand Up @@ -729,13 +731,23 @@ def plutus_script_hash(
return script_hash(script)


def script_hash(
script: Union[bytes, NativeScript, PlutusV1Script, PlutusV2Script]
) -> ScriptHash:
class PlutusV1Script(bytes):
pass


class PlutusV2Script(bytes):
pass


ScriptType = Union[bytes, NativeScript, PlutusV1Script, PlutusV2Script]
"""Script type. A Union type that contains all valid script types."""


def script_hash(script: ScriptType) -> ScriptHash:
"""Calculates the hash of a script, which could be either native script or plutus script.

Args:
script (Union[bytes, NativeScript, PlutusV1Script, PlutusV2Script]): A script.
script (ScriptType): A script.

Returns:
ScriptHash: blake2b hash of the script.
Expand All @@ -752,11 +764,3 @@ def script_hash(
)
else:
raise TypeError(f"Unexpected script type: {type(script)}")


class PlutusV1Script(bytes):
pass


class PlutusV2Script(bytes):
pass
Loading