Commit 78c6011

fix(anthropic): Token reporting (#5403)
Stop double-accumulating tokens: Anthropic already returns cumulative token counts in its streaming chunks. Adapts logic from the Anthropic Python SDK and introduces the `_RecordedUsage` class to mirror how the SDK mutates a usage instance.
1 parent 1e655ba commit 78c6011
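
The gist of the fix, as a minimal sketch with hypothetical chunk values (not taken from this repo or the SDK): streaming chunks report cumulative usage, so adding them up double counts.

    # Anthropic streaming chunks report *cumulative* usage, so "+=" double counts.
    chunks = [10, 25, 30]  # hypothetical cumulative output_tokens per chunk

    # Before this fix: each report was treated as an increment.
    accumulated = 0
    for reported in chunks:
        accumulated += reported
    print(accumulated)  # 65 -- inflated

    # After this fix: record the latest cumulative value instead.
    recorded = 0
    for reported in chunks:
        recorded = reported
    print(recorded)  # 30 -- the true total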

File tree: 2 files changed (+97, -77 lines)

sentry_sdk/integrations/anthropic.py

Lines changed: 75 additions & 57 deletions
@@ -50,6 +50,13 @@
 from sentry_sdk._types import TextPart


+class _RecordedUsage:
+    output_tokens: int = 0
+    input_tokens: int = 0
+    cache_write_input_tokens: "Optional[int]" = 0
+    cache_read_input_tokens: "Optional[int]" = 0
+
+
 class AnthropicIntegration(Integration):
     identifier = "anthropic"
     origin = f"auto.ai.{identifier}"
@@ -112,31 +119,15 @@ def _get_token_usage(result: "Messages") -> "tuple[int, int, int, int]":
 def _collect_ai_data(
     event: "MessageStreamEvent",
     model: "str | None",
-    input_tokens: int,
-    output_tokens: int,
-    cache_read_input_tokens: int,
-    cache_write_input_tokens: int,
+    usage: "_RecordedUsage",
     content_blocks: "list[str]",
-) -> "tuple[str | None, int, int, int, int, list[str]]":
+) -> "tuple[str | None, _RecordedUsage, list[str]]":
     """
     Collect model information, token usage, and collect content blocks from the AI streaming response.
     """
     with capture_internal_exceptions():
         if hasattr(event, "type"):
-            if event.type == "message_start":
-                usage = event.message.usage
-                input_tokens += usage.input_tokens
-                output_tokens += usage.output_tokens
-                if hasattr(usage, "cache_read_input_tokens") and isinstance(
-                    usage.cache_read_input_tokens, int
-                ):
-                    cache_read_input_tokens += usage.cache_read_input_tokens
-                if hasattr(usage, "cache_creation_input_tokens") and isinstance(
-                    usage.cache_creation_input_tokens, int
-                ):
-                    cache_write_input_tokens += usage.cache_creation_input_tokens
-                model = event.message.model or model
-            elif event.type == "content_block_start":
+            if event.type == "content_block_start":
                 pass
             elif event.type == "content_block_delta":
                 if hasattr(event.delta, "text"):
@@ -145,15 +136,60 @@ def _collect_ai_data(
                     content_blocks.append(event.delta.partial_json)
             elif event.type == "content_block_stop":
                 pass
-            elif event.type == "message_delta":
-                output_tokens += event.usage.output_tokens
+
+            # Token counting logic mirrors anthropic SDK, which also extracts already accumulated tokens.
+            # https://github.com/anthropics/anthropic-sdk-python/blob/9c485f6966e10ae0ea9eabb3a921d2ea8145a25b/src/anthropic/lib/streaming/_messages.py#L433-L518
+            if event.type == "message_start":
+                model = event.message.model or model
+
+                incoming_usage = event.message.usage
+                usage.output_tokens = incoming_usage.output_tokens
+                usage.input_tokens = incoming_usage.input_tokens
+
+                usage.cache_write_input_tokens = getattr(
+                    incoming_usage, "cache_creation_input_tokens", None
+                )
+                usage.cache_read_input_tokens = getattr(
+                    incoming_usage, "cache_read_input_tokens", None
+                )
+
+                return (
+                    model,
+                    usage,
+                    content_blocks,
+                )
+
+            # Counterintuitive, but message_delta contains cumulative token counts :)
+            if event.type == "message_delta":
+                usage.output_tokens = event.usage.output_tokens
+
+                # Update other usage fields if they exist in the event
+                input_tokens = getattr(event.usage, "input_tokens", None)
+                if input_tokens is not None:
+                    usage.input_tokens = input_tokens
+
+                cache_creation_input_tokens = getattr(
+                    event.usage, "cache_creation_input_tokens", None
+                )
+                if cache_creation_input_tokens is not None:
+                    usage.cache_write_input_tokens = cache_creation_input_tokens
+
+                cache_read_input_tokens = getattr(
+                    event.usage, "cache_read_input_tokens", None
+                )
+                if cache_read_input_tokens is not None:
+                    usage.cache_read_input_tokens = cache_read_input_tokens
+                # TODO: Record event.usage.server_tool_use
+
+                return (
+                    model,
+                    usage,
+                    content_blocks,
+                )

     return (
         model,
-        input_tokens,
-        output_tokens,
-        cache_read_input_tokens,
-        cache_write_input_tokens,
+        usage,
         content_blocks,
     )
@@ -414,27 +450,18 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A

         def new_iterator() -> "Iterator[MessageStreamEvent]":
             model = None
-            input_tokens = 0
-            output_tokens = 0
-            cache_read_input_tokens = 0
-            cache_write_input_tokens = 0
+            usage = _RecordedUsage()
             content_blocks: "list[str]" = []

             for event in old_iterator:
                 (
                     model,
-                    input_tokens,
-                    output_tokens,
-                    cache_read_input_tokens,
-                    cache_write_input_tokens,
+                    usage,
                     content_blocks,
                 ) = _collect_ai_data(
                     event,
                     model,
-                    input_tokens,
-                    output_tokens,
-                    cache_read_input_tokens,
-                    cache_write_input_tokens,
+                    usage,
                     content_blocks,
                 )
                 yield event
@@ -443,37 +470,28 @@ def new_iterator() -> "Iterator[MessageStreamEvent]":
                 span=span,
                 integration=integration,
                 model=model,
-                input_tokens=input_tokens,
-                output_tokens=output_tokens,
-                cache_read_input_tokens=cache_read_input_tokens,
-                cache_write_input_tokens=cache_write_input_tokens,
+                input_tokens=usage.input_tokens,
+                output_tokens=usage.output_tokens,
+                cache_read_input_tokens=usage.cache_read_input_tokens,
+                cache_write_input_tokens=usage.cache_write_input_tokens,
                 content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
                 finish_span=True,
             )

         async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
             model = None
-            input_tokens = 0
-            output_tokens = 0
-            cache_read_input_tokens = 0
-            cache_write_input_tokens = 0
+            usage = _RecordedUsage()
             content_blocks: "list[str]" = []

             async for event in old_iterator:
                 (
                     model,
-                    input_tokens,
-                    output_tokens,
-                    cache_read_input_tokens,
-                    cache_write_input_tokens,
+                    usage,
                     content_blocks,
                 ) = _collect_ai_data(
                     event,
                     model,
-                    input_tokens,
-                    output_tokens,
-                    cache_read_input_tokens,
-                    cache_write_input_tokens,
+                    usage,
                     content_blocks,
                 )
                 yield event
@@ -482,10 +500,10 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
                 span=span,
                 integration=integration,
                 model=model,
-                input_tokens=input_tokens,
-                output_tokens=output_tokens,
-                cache_read_input_tokens=cache_read_input_tokens,
-                cache_write_input_tokens=cache_write_input_tokens,
+                input_tokens=usage.input_tokens,
+                output_tokens=usage.output_tokens,
+                cache_read_input_tokens=usage.cache_read_input_tokens,
+                cache_write_input_tokens=usage.cache_write_input_tokens,
                 content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
                 finish_span=True,
             )
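
For illustration, the recording pattern in this diff can be exercised end-to-end with stand-in event objects. The `SimpleNamespace` events and token values below are hypothetical sketches, not the SDK's real `MessageStreamEvent` classes; only the overwrite-not-add logic mirrors the change above:

    from types import SimpleNamespace

    class _RecordedUsage:
        output_tokens: int = 0
        input_tokens: int = 0

    # Simplified stand-ins for Anthropic stream events (hypothetical values).
    events = [
        SimpleNamespace(
            type="message_start",
            message=SimpleNamespace(
                model="claude-3-5-sonnet",  # hypothetical model name
                usage=SimpleNamespace(input_tokens=10, output_tokens=1),
            ),
        ),
        SimpleNamespace(type="message_delta", usage=SimpleNamespace(output_tokens=12)),
        SimpleNamespace(type="message_delta", usage=SimpleNamespace(output_tokens=20)),
    ]

    usage = _RecordedUsage()
    for event in events:
        if event.type == "message_start":
            usage.input_tokens = event.message.usage.input_tokens
            usage.output_tokens = event.message.usage.output_tokens
        elif event.type == "message_delta":
            # message_delta carries cumulative counts: overwrite, never add.
            usage.output_tokens = event.usage.output_tokens

    print(usage.input_tokens, usage.output_tokens)  # 10 20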

tests/integrations/anthropic/test_anthropic.py

Lines changed: 22 additions & 20 deletions
@@ -49,6 +49,7 @@ async def __call__(self, *args, **kwargs):
     _set_output_data,
     _collect_ai_data,
     _transform_anthropic_content_block,
+    _RecordedUsage,
 )
 from sentry_sdk.ai.utils import transform_content_part, transform_message_content
 from sentry_sdk.utils import package_version
@@ -307,8 +308,8 @@ def test_streaming_create_message(
     assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]

     assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
-    assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 30
-    assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 40
+    assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+    assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
     assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True


@@ -412,8 +413,8 @@ async def test_streaming_create_message_async(
     assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]

     assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
-    assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 30
-    assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 40
+    assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+    assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
     assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True


@@ -546,8 +547,8 @@ def test_streaming_create_message_with_input_json_delta(
     assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]

     assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 366
-    assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 51
-    assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 417
+    assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 41
+    assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 407
     assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True


@@ -688,8 +689,8 @@ async def test_streaming_create_message_with_input_json_delta_async(
     assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]

     assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 366
-    assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 51
-    assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 417
+    assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 41
+    assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 407
     assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True


@@ -849,18 +850,19 @@ def test_collect_ai_data_with_input_json_delta():
         type="content_block_delta",
     )
     model = None
-    input_tokens = 10
-    output_tokens = 20
+
+    usage = _RecordedUsage()
+    usage.output_tokens = 20
+    usage.input_tokens = 10
+
     content_blocks = []

-    model, new_input_tokens, new_output_tokens, _, _, new_content_blocks = (
-        _collect_ai_data(
-            event, model, input_tokens, output_tokens, 0, 0, content_blocks
-        )
+    model, new_usage, new_content_blocks = _collect_ai_data(
+        event, model, usage, content_blocks
     )
     assert model is None
-    assert new_input_tokens == input_tokens
-    assert new_output_tokens == output_tokens
+    assert new_usage.input_tokens == usage.input_tokens
+    assert new_usage.output_tokens == usage.output_tokens
     assert new_content_blocks == ["test"]


@@ -1345,8 +1347,8 @@ def test_streaming_create_message_with_system_prompt(
     assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]

     assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
-    assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 30
-    assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 40
+    assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+    assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
     assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True


@@ -1465,8 +1467,8 @@ async def test_streaming_create_message_with_system_prompt_async(
     assert SPANDATA.GEN_AI_RESPONSE_TEXT not in span["data"]

     assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
-    assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 30
-    assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 40
+    assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+    assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 20
     assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
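The updated expectations follow from the new semantics: the cumulative count from the final message_delta replaces the old sum of message_start plus message_delta usage. A sketch of the arithmetic, assuming (inferred from the assertion diff above, not checked against the mock fixtures) that the mock stream reports 20 output tokens at message_start and a cumulative 10 at message_delta:

    # Hypothetical mock values, inferred from the assertion changes above.
    message_start_output = 20  # usage reported at message_start
    message_delta_output = 10  # cumulative usage reported at message_delta

    old_output = message_start_output + message_delta_output  # 30 -- double counted
    new_output = message_delta_output                         # 10 -- cumulative value
    assert (old_output, new_output) == (30, 10)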