Skip to content

Commit f58234e

Browse files
committed
Use TypeVars for return types of RedisModel and its subtype's methods
1 parent 721f734 commit f58234e

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

aredis_om/model/model.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
model_registry = {}
4949
_T = TypeVar("_T")
50+
Model = TypeVar("Model", bound="RedisModel")
5051
log = logging.getLogger(__name__)
5152
escaper = TokenEscaper()
5253

@@ -1152,16 +1153,16 @@ async def delete(
11521153
return await cls._delete(db, cls.make_primary_key(pk))
11531154

11541155
@classmethod
1155-
async def get(cls, pk: Any) -> "RedisModel":
1156+
async def get(cls: Type["Model"], pk: Any) -> "Model":
11561157
raise NotImplementedError
11571158

11581159
async def update(self, **field_values):
11591160
"""Update this model instance with the specified key-value pairs."""
11601161
raise NotImplementedError
11611162

11621163
async def save(
1163-
self, pipeline: Optional[redis.client.Pipeline] = None
1164-
) -> "RedisModel":
1164+
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
1165+
) -> "Model":
11651166
raise NotImplementedError
11661167

11671168
async def expire(
@@ -1258,11 +1259,11 @@ def get_annotations(cls):
12581259

12591260
@classmethod
12601261
async def add(
1261-
cls,
1262-
models: Sequence["RedisModel"],
1262+
cls: Type["Model"],
1263+
models: Sequence["Model"],
12631264
pipeline: Optional[redis.client.Pipeline] = None,
12641265
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
1265-
) -> Sequence["RedisModel"]:
1266+
) -> Sequence["Model"]:
12661267
db = cls._get_db(pipeline, bulk=True)
12671268

12681269
for model in models:
@@ -1337,8 +1338,8 @@ def __init_subclass__(cls, **kwargs):
13371338
)
13381339

13391340
async def save(
1340-
self, pipeline: Optional[redis.client.Pipeline] = None
1341-
) -> "HashModel":
1341+
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
1342+
) -> "Model":
13421343
self.check()
13431344
db = self._get_db(pipeline)
13441345

@@ -1364,7 +1365,7 @@ async def all_pks(cls): # type: ignore
13641365
)
13651366

13661367
@classmethod
1367-
async def get(cls, pk: Any) -> "HashModel":
1368+
async def get(cls: Type["Model"], pk: Any) -> "Model":
13681369
document = await cls.db().hgetall(cls.make_primary_key(pk))
13691370
if not document:
13701371
raise NotFoundError
@@ -1509,8 +1510,8 @@ def __init__(self, *args, **kwargs):
15091510
super().__init__(*args, **kwargs)
15101511

15111512
async def save(
1512-
self, pipeline: Optional[redis.client.Pipeline] = None
1513-
) -> "JsonModel":
1513+
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
1514+
) -> "Model":
15141515
self.check()
15151516
db = self._get_db(pipeline)
15161517

@@ -1559,7 +1560,7 @@ async def update(self, **field_values):
15591560
await self.save()
15601561

15611562
@classmethod
1562-
async def get(cls, pk: Any) -> "JsonModel":
1563+
async def get(cls: Type["Model"], pk: Any) -> "Model":
15631564
document = json.dumps(await cls.db().json().get(cls.make_key(pk)))
15641565
if document == "null":
15651566
raise NotFoundError

0 commit comments

Comments
 (0)