Skip to content

Commit 35e9d36

Browse files
authored
fix(standard-tests): ensure non-negative token counts in usage metadata assertions (#32593)
1 parent 8b90eae commit 35e9d36

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

libs/standard-tests/langchain_tests/integration_tests/chat_models.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,26 +1098,36 @@ def supported_usage_metadata_details(self) -> dict:
10981098
)
10991099
if "cache_read_input" in self.supported_usage_metadata_details["invoke"]:
11001100
msg = self.invoke_with_cache_read_input()
1101-
assert (usage_metadata := msg.usage_metadata) is not None
1102-
assert (
1103-
input_token_details := usage_metadata.get("input_token_details")
1104-
) is not None
1105-
assert isinstance(input_token_details.get("cache_read"), int)
1101+
usage_metadata = msg.usage_metadata
1102+
assert usage_metadata is not None
1103+
input_token_details = usage_metadata.get("input_token_details")
1104+
assert input_token_details is not None
1105+
cache_read_tokens = input_token_details.get("cache_read")
1106+
assert isinstance(cache_read_tokens, int)
1107+
assert cache_read_tokens >= 0
11061108
# Asserts that total input tokens are at least the sum of the token counts
1107-
assert usage_metadata.get("input_tokens", 0) >= sum(
1108-
v for v in input_token_details.values() if isinstance(v, int)
1109+
total_detailed_tokens = sum(
1110+
v for v in input_token_details.values() if isinstance(v, int) and v >= 0
11091111
)
1112+
input_tokens = usage_metadata.get("input_tokens", 0)
1113+
assert isinstance(input_tokens, int)
1114+
assert input_tokens >= total_detailed_tokens
11101115
if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]:
11111116
msg = self.invoke_with_cache_creation_input()
1112-
assert (usage_metadata := msg.usage_metadata) is not None
1113-
assert (
1114-
input_token_details := usage_metadata.get("input_token_details")
1115-
) is not None
1116-
assert isinstance(input_token_details.get("cache_creation"), int)
1117+
usage_metadata = msg.usage_metadata
1118+
assert usage_metadata is not None
1119+
input_token_details = usage_metadata.get("input_token_details")
1120+
assert input_token_details is not None
1121+
cache_creation_tokens = input_token_details.get("cache_creation")
1122+
assert isinstance(cache_creation_tokens, int)
1123+
assert cache_creation_tokens >= 0
11171124
# Asserts that total input tokens are at least the sum of the token counts
1118-
assert usage_metadata.get("input_tokens", 0) >= sum(
1119-
v for v in input_token_details.values() if isinstance(v, int)
1125+
total_detailed_tokens = sum(
1126+
v for v in input_token_details.values() if isinstance(v, int) and v >= 0
11201127
)
1128+
input_tokens = usage_metadata.get("input_tokens", 0)
1129+
assert isinstance(input_tokens, int)
1130+
assert input_tokens >= total_detailed_tokens
11211131

11221132
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
11231133
"""Test usage metadata in streaming mode.

0 commit comments

Comments (0)