47
47
48
48
model_registry = {}
49
49
_T = TypeVar ("_T" )
50
+ Model = TypeVar ("Model" , bound = "RedisModel" )
50
51
log = logging .getLogger (__name__ )
51
52
escaper = TokenEscaper ()
52
53
@@ -1152,16 +1153,16 @@ async def delete(
1152
1153
return await cls ._delete (db , cls .make_primary_key (pk ))
1153
1154
1154
1155
@classmethod
1155
- async def get (cls , pk : Any ) -> "RedisModel " :
1156
+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
1156
1157
raise NotImplementedError
1157
1158
1158
1159
async def update (self , ** field_values ):
1159
1160
"""Update this model instance with the specified key-value pairs."""
1160
1161
raise NotImplementedError
1161
1162
1162
1163
async def save (
1163
- self , pipeline : Optional [redis .client .Pipeline ] = None
1164
- ) -> "RedisModel " :
1164
+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1165
+ ) -> "Model " :
1165
1166
raise NotImplementedError
1166
1167
1167
1168
async def expire (
@@ -1258,11 +1259,11 @@ def get_annotations(cls):
1258
1259
1259
1260
@classmethod
1260
1261
async def add (
1261
- cls ,
1262
- models : Sequence ["RedisModel " ],
1262
+ cls : Type [ "Model" ] ,
1263
+ models : Sequence ["Model " ],
1263
1264
pipeline : Optional [redis .client .Pipeline ] = None ,
1264
1265
pipeline_verifier : Callable [..., Any ] = verify_pipeline_response ,
1265
- ) -> Sequence ["RedisModel " ]:
1266
+ ) -> Sequence ["Model " ]:
1266
1267
db = cls ._get_db (pipeline , bulk = True )
1267
1268
1268
1269
for model in models :
@@ -1337,8 +1338,8 @@ def __init_subclass__(cls, **kwargs):
1337
1338
)
1338
1339
1339
1340
async def save (
1340
- self , pipeline : Optional [redis .client .Pipeline ] = None
1341
- ) -> "HashModel " :
1341
+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1342
+ ) -> "Model " :
1342
1343
self .check ()
1343
1344
db = self ._get_db (pipeline )
1344
1345
@@ -1364,7 +1365,7 @@ async def all_pks(cls): # type: ignore
1364
1365
)
1365
1366
1366
1367
@classmethod
1367
- async def get (cls , pk : Any ) -> "HashModel " :
1368
+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
1368
1369
document = await cls .db ().hgetall (cls .make_primary_key (pk ))
1369
1370
if not document :
1370
1371
raise NotFoundError
@@ -1509,8 +1510,8 @@ def __init__(self, *args, **kwargs):
1509
1510
super ().__init__ (* args , ** kwargs )
1510
1511
1511
1512
async def save (
1512
- self , pipeline : Optional [redis .client .Pipeline ] = None
1513
- ) -> "JsonModel " :
1513
+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1514
+ ) -> "Model " :
1514
1515
self .check ()
1515
1516
db = self ._get_db (pipeline )
1516
1517
@@ -1559,7 +1560,7 @@ async def update(self, **field_values):
1559
1560
await self .save ()
1560
1561
1561
1562
@classmethod
1562
- async def get (cls , pk : Any ) -> "JsonModel " :
1563
+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
1563
1564
document = json .dumps (await cls .db ().json ().get (cls .make_key (pk )))
1564
1565
if document == "null" :
1565
1566
raise NotFoundError
0 commit comments