@@ -1098,26 +1098,36 @@ def supported_usage_metadata_details(self) -> dict:
1098
1098
)
1099
1099
if "cache_read_input" in self .supported_usage_metadata_details ["invoke" ]:
1100
1100
msg = self .invoke_with_cache_read_input ()
1101
- assert (usage_metadata := msg .usage_metadata ) is not None
1102
- assert (
1103
- input_token_details := usage_metadata .get ("input_token_details" )
1104
- ) is not None
1105
- assert isinstance (input_token_details .get ("cache_read" ), int )
1101
+ usage_metadata = msg .usage_metadata
1102
+ assert usage_metadata is not None
1103
+ input_token_details = usage_metadata .get ("input_token_details" )
1104
+ assert input_token_details is not None
1105
+ cache_read_tokens = input_token_details .get ("cache_read" )
1106
+ assert isinstance (cache_read_tokens , int )
1107
+ assert cache_read_tokens >= 0
1106
1108
# Asserts that total input tokens are at least the sum of the token counts
1107
- assert usage_metadata . get ( "input_tokens" , 0 ) > = sum (
1108
- v for v in input_token_details .values () if isinstance (v , int )
1109
+ total_detailed_tokens = sum (
1110
+ v for v in input_token_details .values () if isinstance (v , int ) and v >= 0
1109
1111
)
1112
+ input_tokens = usage_metadata .get ("input_tokens" , 0 )
1113
+ assert isinstance (input_tokens , int )
1114
+ assert input_tokens >= total_detailed_tokens
1110
1115
if "cache_creation_input" in self .supported_usage_metadata_details ["invoke" ]:
1111
1116
msg = self .invoke_with_cache_creation_input ()
1112
- assert (usage_metadata := msg .usage_metadata ) is not None
1113
- assert (
1114
- input_token_details := usage_metadata .get ("input_token_details" )
1115
- ) is not None
1116
- assert isinstance (input_token_details .get ("cache_creation" ), int )
1117
+ usage_metadata = msg .usage_metadata
1118
+ assert usage_metadata is not None
1119
+ input_token_details = usage_metadata .get ("input_token_details" )
1120
+ assert input_token_details is not None
1121
+ cache_creation_tokens = input_token_details .get ("cache_creation" )
1122
+ assert isinstance (cache_creation_tokens , int )
1123
+ assert cache_creation_tokens >= 0
1117
1124
# Asserts that total input tokens are at least the sum of the token counts
1118
- assert usage_metadata . get ( "input_tokens" , 0 ) > = sum (
1119
- v for v in input_token_details .values () if isinstance (v , int )
1125
+ total_detailed_tokens = sum (
1126
+ v for v in input_token_details .values () if isinstance (v , int ) and v >= 0
1120
1127
)
1128
+ input_tokens = usage_metadata .get ("input_tokens" , 0 )
1129
+ assert isinstance (input_tokens , int )
1130
+ assert input_tokens >= total_detailed_tokens
1121
1131
1122
1132
def test_usage_metadata_streaming (self , model : BaseChatModel ) -> None :
1123
1133
"""Test usage metadata in streaming mode.
0 commit comments