 from sentry_sdk._types import TextPart


+class _RecordedUsage:
+    output_tokens: int = 0
+    input_tokens: int = 0
+    cache_write_input_tokens: "Optional[int]" = 0
+    cache_read_input_tokens: "Optional[int]" = 0
+
+
 class AnthropicIntegration(Integration):
     identifier = "anthropic"
     origin = f"auto.ai.{identifier}"
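
A note on the new `_RecordedUsage` holder: it replaces the four loose counters previously threaded through `_collect_ai_data`. A minimal sketch of the Python semantics it relies on, with hypothetical values rather than real Anthropic data: class-level defaults give every fresh instance zeroed fields, and attribute assignment shadows them per instance.

```python
# Sketch only, not the integration code: _RecordedUsage uses class-level
# defaults, so a fresh instance reads 0 until a field is assigned.
from typing import Optional


class _RecordedUsage:
    output_tokens: int = 0
    input_tokens: int = 0
    cache_write_input_tokens: Optional[int] = 0
    cache_read_input_tokens: Optional[int] = 0


usage = _RecordedUsage()
assert usage.output_tokens == 0            # falls through to the class default
usage.output_tokens = 42                   # assignment creates an instance attribute
assert _RecordedUsage.output_tokens == 0   # the shared default stays untouched
```

Since each stream builds its own `_RecordedUsage()` in the iterators below, and the defaults are immutable ints, the shared class attributes are never mutated.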
@@ -112,31 +119,15 @@ def _get_token_usage(result: "Messages") -> "tuple[int, int, int, int]":
 def _collect_ai_data(
     event: "MessageStreamEvent",
     model: "str | None",
-    input_tokens: int,
-    output_tokens: int,
-    cache_read_input_tokens: int,
-    cache_write_input_tokens: int,
+    usage: "_RecordedUsage",
     content_blocks: "list[str]",
-) -> "tuple[str | None, int, int, int, int, list[str]]":
+) -> "tuple[str | None, _RecordedUsage, list[str]]":
     """
     Collect model information, token usage, and content blocks from the AI streaming response.
     """
     with capture_internal_exceptions():
         if hasattr(event, "type"):
-            if event.type == "message_start":
-                usage = event.message.usage
-                input_tokens += usage.input_tokens
-                output_tokens += usage.output_tokens
-                if hasattr(usage, "cache_read_input_tokens") and isinstance(
-                    usage.cache_read_input_tokens, int
-                ):
-                    cache_read_input_tokens += usage.cache_read_input_tokens
-                if hasattr(usage, "cache_creation_input_tokens") and isinstance(
-                    usage.cache_creation_input_tokens, int
-                ):
-                    cache_write_input_tokens += usage.cache_creation_input_tokens
-                model = event.message.model or model
-            elif event.type == "content_block_start":
+            if event.type == "content_block_start":
                 pass
             elif event.type == "content_block_delta":
                 if hasattr(event.delta, "text"):
@@ -145,15 +136,60 @@ def _collect_ai_data(
                     content_blocks.append(event.delta.partial_json)
             elif event.type == "content_block_stop":
                 pass
-            elif event.type == "message_delta":
-                output_tokens += event.usage.output_tokens
+
+            # Token counting mirrors the Anthropic SDK, which likewise reads
+            # already-accumulated token counts off the stream events:
+            # https://github.com/anthropics/anthropic-sdk-python/blob/9c485f6966e10ae0ea9eabb3a921d2ea8145a25b/src/anthropic/lib/streaming/_messages.py#L433-L518
+            if event.type == "message_start":
+                model = event.message.model or model
+
+                incoming_usage = event.message.usage
+                usage.output_tokens = incoming_usage.output_tokens
+                usage.input_tokens = incoming_usage.input_tokens
+
+                usage.cache_write_input_tokens = getattr(
+                    incoming_usage, "cache_creation_input_tokens", None
+                )
+                usage.cache_read_input_tokens = getattr(
+                    incoming_usage, "cache_read_input_tokens", None
+                )
+
+                return (
+                    model,
+                    usage,
+                    content_blocks,
+                )
+
+            # Counterintuitive, but message_delta carries *cumulative* token
+            # counts, so assign rather than accumulate. :)
+            if event.type == "message_delta":
+                usage.output_tokens = event.usage.output_tokens
+
+                # Update the other usage fields only when present on the event.
+                input_tokens = getattr(event.usage, "input_tokens", None)
+                if input_tokens is not None:
+                    usage.input_tokens = input_tokens
+
+                cache_creation_input_tokens = getattr(
+                    event.usage, "cache_creation_input_tokens", None
+                )
+                if cache_creation_input_tokens is not None:
+                    usage.cache_write_input_tokens = cache_creation_input_tokens
+
+                cache_read_input_tokens = getattr(
+                    event.usage, "cache_read_input_tokens", None
+                )
+                if cache_read_input_tokens is not None:
+                    usage.cache_read_input_tokens = cache_read_input_tokens
+                # TODO: Record event.usage.server_tool_use
+
+                return (
+                    model,
+                    usage,
+                    content_blocks,
+                )

     return (
         model,
-        input_tokens,
-        output_tokens,
-        cache_read_input_tokens,
-        cache_write_input_tokens,
+        usage,
         content_blocks,
     )

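
To make the cumulative semantics concrete, here is a self-contained sketch of why the `message_delta` handler assigns instead of adding. The events are `SimpleNamespace` stand-ins, not real SDK objects:

```python
# Sketch: message_delta reports the running total, not the increment,
# so the handler must overwrite rather than sum.
from types import SimpleNamespace

usage = SimpleNamespace(output_tokens=0)

# Two message_delta stand-ins: the second carries the cumulative count (12),
# not the per-event increment (7).
deltas = [
    SimpleNamespace(usage=SimpleNamespace(output_tokens=5)),
    SimpleNamespace(usage=SimpleNamespace(output_tokens=12)),
]

for event in deltas:
    usage.output_tokens = event.usage.output_tokens  # assign, don't +=

assert usage.output_tokens == 12  # += would have double-counted to 17
```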
@@ -414,27 +450,18 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":

         def new_iterator() -> "Iterator[MessageStreamEvent]":
             model = None
-            input_tokens = 0
-            output_tokens = 0
-            cache_read_input_tokens = 0
-            cache_write_input_tokens = 0
+            usage = _RecordedUsage()
             content_blocks: "list[str]" = []

             for event in old_iterator:
                 (
                     model,
-                    input_tokens,
-                    output_tokens,
-                    cache_read_input_tokens,
-                    cache_write_input_tokens,
+                    usage,
                     content_blocks,
                 ) = _collect_ai_data(
                     event,
                     model,
-                    input_tokens,
-                    output_tokens,
-                    cache_read_input_tokens,
-                    cache_write_input_tokens,
+                    usage,
                     content_blocks,
                 )
                 yield event
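
Note the shape of that loop: `_collect_ai_data` mutates `usage` in place and also returns it, so the tuple unpacking rebinds the same object on every event. A tiny sketch of that mutate-and-return pattern, using hypothetical `Holder`/`bump` names:

```python
# Sketch: mutate in place, then return the same object so the caller's
# rebinding is explicit but cheap (no copy is made).
class Holder:
    def __init__(self) -> None:
        self.count = 0


def bump(holder: "Holder") -> "Holder":
    holder.count += 1  # in-place mutation...
    return holder      # ...handed back for rebinding at the call site


h = Holder()
assert bump(h) is h and h.count == 1
```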
@@ -443,37 +470,28 @@ def new_iterator() -> "Iterator[MessageStreamEvent]":
                 span=span,
                 integration=integration,
                 model=model,
-                input_tokens=input_tokens,
-                output_tokens=output_tokens,
-                cache_read_input_tokens=cache_read_input_tokens,
-                cache_write_input_tokens=cache_write_input_tokens,
+                input_tokens=usage.input_tokens,
+                output_tokens=usage.output_tokens,
+                cache_read_input_tokens=usage.cache_read_input_tokens,
+                cache_write_input_tokens=usage.cache_write_input_tokens,
                 content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
                 finish_span=True,
             )

         async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
             model = None
-            input_tokens = 0
-            output_tokens = 0
-            cache_read_input_tokens = 0
-            cache_write_input_tokens = 0
+            usage = _RecordedUsage()
             content_blocks: "list[str]" = []

             async for event in old_iterator:
                 (
                     model,
-                    input_tokens,
-                    output_tokens,
-                    cache_read_input_tokens,
-                    cache_write_input_tokens,
+                    usage,
                     content_blocks,
                 ) = _collect_ai_data(
                     event,
                     model,
-                    input_tokens,
-                    output_tokens,
-                    cache_read_input_tokens,
-                    cache_write_input_tokens,
+                    usage,
                     content_blocks,
                 )
                 yield event
@@ -482,10 +500,10 @@ async def new_iterator_async() -> "AsyncIterator[MessageStreamEvent]":
                 span=span,
                 integration=integration,
                 model=model,
-                input_tokens=input_tokens,
-                output_tokens=output_tokens,
-                cache_read_input_tokens=cache_read_input_tokens,
-                cache_write_input_tokens=cache_write_input_tokens,
+                input_tokens=usage.input_tokens,
+                output_tokens=usage.output_tokens,
+                cache_read_input_tokens=usage.cache_read_input_tokens,
+                cache_write_input_tokens=usage.cache_write_input_tokens,
                 content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
                 finish_span=True,
             )
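
Both wrappers follow the same pass-through shape: observe each event, yield it unchanged, and report totals only once the consumer drains the stream. A generic sketch under that assumption, where `instrument`, `collect`, and `finalize` are hypothetical stand-ins for the patched iterators, `_collect_ai_data`, and `_add_ai_data_to_span`:

```python
# Sketch of the wrapper shape used by new_iterator()/new_iterator_async():
# the stream is observed as a side effect and finalized after exhaustion.
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")


def instrument(
    events: Iterable[T],
    collect: Callable[[T], None],   # stand-in for _collect_ai_data
    finalize: Callable[[], None],   # stand-in for _add_ai_data_to_span
) -> Iterator[T]:
    for event in events:
        collect(event)  # record usage/content without altering the event
        yield event     # the caller sees exactly the original stream
    finalize()          # runs only once the consumer exhausts the iterator


seen: list = []
total = {"n": 0}
events = instrument("abc", seen.append, lambda: total.update(n=len(seen)))
assert list(events) == ["a", "b", "c"] and total["n"] == 3
```

One caveat about this sketch: if the consumer abandons the stream early, `finalize()` never runs; teardown-on-close would need a `try/finally` around the loop.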