33
44from __future__ import annotations
55import json
6+ import threading
67import time
78import traceback
89from typing import List , Dict , Union , Any , Optional , Callable , Protocol
2425from cozeloop .integration .langchain .trace_model .runtime import RuntimeInfo
2526from cozeloop .integration .langchain .util import calc_token_usage , get_prompt_tag
2627
27- _trace_callback_client : Optional [Client ] = None
28-
2928
3029class LoopTracer :
3130 @classmethod
@@ -35,19 +34,28 @@ def get_callback_handler(
3534 modify_name_fn : Optional [Callable [[str ], str ]] = None ,
3635 add_tags_fn : Optional [Callable [[str ], Dict [str , Any ]]] = None ,
3736 tags : Dict [str , Any ] = None ,
37+ child_of : Optional [Span ] = None ,
38+ state_span_ctx_key : str = None ,
3839 ):
3940 """
4041 Do not hold it for a long time, get a new callback_handler for each request.
41- modify_name_fn: modify name function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is span name.
42- add_tags_fn: add tags function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is tags dict.
42+ client: cozeloop client instance. If not provided, use the default client.
43+ modify_name_fn: modify name function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is span name.
44+ add_tags_fn: add tags function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is tags dict.
45+ It's priority higher than parameter tags.
46+ tags: default tags dict. It's priority lower than parameter add_tags_fn.
47+ child_of: parent span of this callback_handler.
48+ state_span_ctx_key: span context field name in state. If provided, you need set the field in sate, and we will use it to set span context in state.
49+ You can get it from state for creating inner span in async node.
4350 """
44- global _trace_callback_client
45- if client :
46- _trace_callback_client = client
47- else :
48- _trace_callback_client = get_default_client ()
49-
50- return LoopTraceCallbackHandler (modify_name_fn , add_tags_fn , tags )
51+ return LoopTraceCallbackHandler (
52+ name_fn = modify_name_fn ,
53+ tags_fn = add_tags_fn ,
54+ tags = tags ,
55+ child_of = child_of ,
56+ client = client ,
57+ state_span_ctx_key = state_span_ctx_key ,
58+ )
5159
5260
5361class LoopTraceCallbackHandler (BaseCallbackHandler ):
@@ -56,13 +64,22 @@ def __init__(
5664 name_fn : Optional [Callable [[str ], str ]] = None ,
5765 tags_fn : Optional [Callable [[str ], Dict [str , Any ]]] = None ,
5866 tags : Dict [str , Any ] = None ,
67+ child_of : Optional [Span ] = None ,
68+ client : Client = None ,
69+ state_span_ctx_key : str = None ,
5970 ):
6071 super ().__init__ ()
61- self ._space_id = _trace_callback_client .workspace_id
72+ self ._client = client if client else get_default_client ()
73+ self ._space_id = self ._client .workspace_id
6274 self .run_map : Dict [str , Run ] = {}
6375 self .name_fn = name_fn
6476 self .tags_fn = tags_fn
6577 self ._tags = tags if tags else {}
78+ self .trace_id : Optional [str ] = None
79+ self .root_span_id : Optional [str ] = None
80+ self ._id_lock = threading .Lock ()
81+ self ._child_of = child_of
82+ self ._state_span_ctx_key = state_span_ctx_key
6683
6784 def on_llm_start (self , serialized : Dict [str , Any ], prompts : List [str ], ** kwargs : Any ) -> Any :
6885 span_tags = {}
@@ -73,14 +90,16 @@ def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs:
7390 span_tags ['input' ] = ModelTraceInput ([BaseMessage (type = '' , content = prompt ) for prompt in prompts ],
7491 kwargs .get ('invocation_params' , {})).to_json ()
7592 except Exception as e :
76- flow_span .set_error (e )
93+ span_tags ['internal_error' ] = repr (e )
94+ span_tags ['internal_error_trace' ] = traceback .format_exc ()
7795 finally :
7896 span_tags .update (_get_model_span_tags (** kwargs ))
7997 self ._set_span_tags (flow_span , span_tags )
8098 # Store some pre-aspect information.
8199 self .run_map [str (kwargs ['run_id' ])].model_meta = ModelMeta (
82100 message = [{'role' : '' , 'content' : prompt } for prompt in prompts ],
83101 model_name = span_tags .get ('model_name' , '' ))
102+ return flow_span
84103
85104 def on_chat_model_start (self , serialized : Dict [str , Any ], messages : List [List [BaseMessage ]], ** kwargs : Any ) -> Any :
86105 span_tags = {}
@@ -90,14 +109,16 @@ def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[Ba
90109 try :
91110 span_tags ['input' ] = ModelTraceInput (messages , kwargs .get ('invocation_params' , {})).to_json ()
92111 except Exception as e :
93- flow_span .set_error (e )
112+ span_tags ['internal_error' ] = repr (e )
113+ span_tags ['internal_error_trace' ] = traceback .format_exc ()
94114 finally :
95115 span_tags .update (_get_model_span_tags (** kwargs ))
96116 self ._set_span_tags (flow_span , span_tags )
97117 # Store some pre-aspect information.
98118 self .run_map [str (kwargs ['run_id' ])].model_meta = (
99119 ModelMeta (message = [{'role' : message .type , 'content' : message .content } for inner_messages in messages for
100120 message in inner_messages ], model_name = span_tags .get ('model_name' , '' )))
121+ return flow_span
101122
102123 async def on_llm_new_token (self , token : str , * , chunk : Optional [Union [GenerationChunk , ChatGenerationChunk ]] = None ,
103124 ** kwargs : Any ) -> None :
@@ -119,10 +140,14 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
119140 if tags :
120141 self ._set_span_tags (flow_span , tags , need_convert_tag_value = False )
121142 except Exception as e :
122- flow_span .set_error (e )
143+ span_tags = {"internal_error" : repr (e ), 'internal_error_trace' : traceback .format_exc ()}
144+ self ._set_span_tags (flow_span , span_tags , need_convert_tag_value = False )
123145 # finish flow_span
124146 self ._end_flow_span (flow_span )
125147
def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
    """Handle an LLM run failure by delegating to the chain-error path.

    The annotation is widened from ``Exception`` to match ``on_chain_error``,
    which this method forwards to and which explicitly handles
    ``KeyboardInterrupt`` as well.

    error: the exception (or KeyboardInterrupt) raised by the LLM run.
    kwargs: langchain callback kwargs; must carry ``run_id`` so the span
        for this run can be located (or a fallback error span created).
    """
    # Reuse the chain error handling: look up / create the span, tag the
    # error + traceback, and finish the span.
    self.on_chain_error(error, **kwargs)
150+
126151 def on_chain_start (self , serialized : Dict [str , Any ], inputs : Dict [str , Any ], ** kwargs : Any ) -> Any :
127152 flow_span = None
128153 try :
@@ -131,13 +156,27 @@ def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **k
131156 self ._on_prompt_start (flow_span , serialized , inputs , ** kwargs )
132157 else :
133158 span_type = 'chain'
134- if kwargs ['name' ] == 'LangGraph' : # LangGraph is Graph span_type,for trajectory evaluation aggregate to an agent
159+ if kwargs [
160+ 'name' ] == 'LangGraph' : # LangGraph is Graph span_type,for trajectory evaluation aggregate to an agent
135161 span_type = 'graph'
136162 flow_span = self ._new_flow_span (kwargs ['name' ], span_type , ** kwargs )
137163 flow_span .set_tags ({'input' : _convert_2_json (inputs )})
138164 except Exception as e :
139165 if flow_span is not None :
140- flow_span .set_error (e )
166+ span_tags = {"internal_error" : repr (e ), 'internal_error_trace' : traceback .format_exc ()}
167+ self ._set_span_tags (flow_span , span_tags , need_convert_tag_value = False )
168+ finally :
169+ if flow_span is not None :
170+ # set trace_id
171+ with self ._id_lock :
172+ if hasattr (flow_span , 'context' ):
173+ self .trace_id = flow_span .context .trace_id
174+ else :
175+ self .trace_id = flow_span .trace_id
176+ # set span_ctx in state
177+ if self ._state_span_ctx_key :
178+ inputs [self ._state_span_ctx_key ] = flow_span
179+ return flow_span
141180
142181 def on_chain_end (self , outputs : Union [Dict [str , Any ], Any ], ** kwargs : Any ) -> Any :
143182 flow_span = self .run_map [str (kwargs ['run_id' ])].span
@@ -151,16 +190,17 @@ def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> An
151190 else :
152191 flow_span .set_tags ({'output' : _convert_2_json (outputs )})
153192 except Exception as e :
154- flow_span .set_error (e )
193+ if flow_span :
194+ span_tags = {"internal_error" : repr (e ), 'internal_error_trace' : traceback .format_exc ()}
195+ self ._set_span_tags (flow_span , span_tags , need_convert_tag_value = False )
155196 self ._end_flow_span (flow_span )
156197
def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
    """Record a chain failure on its span and finish the span.

    If no span was registered for this run_id (the error happened before
    ``on_chain_start`` could create one), a fallback span is created so the
    failure is still traced.

    error: the exception or KeyboardInterrupt raised by the chain.
    kwargs: langchain callback kwargs carrying ``run_id``.
    """
    flow_span = self._get_flow_span(**kwargs)
    if flow_span is None:
        span_name = '_Exception' if isinstance(error, Exception) else '_KeyboardInterrupt'
        flow_span = self._new_flow_span(span_name, 'chain_error', **kwargs)
    # Use _set_span_tags with need_convert_tag_value=False for consistency
    # with on_tool_error's error tagging (raw set_tags bypassed the shared
    # tag-setting path).
    span_tags = {'error': repr(error), 'error_trace': traceback.format_exc()}
    self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
    self._end_flow_span(flow_span)
165205
166206 def on_tool_start (
@@ -170,13 +210,15 @@ def on_tool_start(
170210 span_name = serialized .get ('name' , 'unknown' )
171211 flow_span = self ._new_flow_span (span_name , 'tool' , ** kwargs )
172212 self ._set_span_tags (flow_span , span_tags )
213+ return flow_span
173214
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
    """Attach the tool's output to its span and finish the span.

    output: the raw tool output; serialized via _convert_2_json.
    kwargs: langchain callback kwargs carrying ``run_id``.
    """
    flow_span = self._get_flow_span(**kwargs)
    try:
        flow_span.set_tags({'output': _convert_2_json(output)})
    except Exception as e:
        # Tag serialization problems instead of failing the callback.
        # Guard against a missing span (None from _get_flow_span) the same
        # way on_chain_end does, so the except path cannot raise again.
        if flow_span:
            span_tags = {"internal_error": repr(e), 'internal_error_trace': traceback.format_exc()}
            self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
    self._end_flow_span(flow_span)
181223
182224 def on_tool_error (
@@ -186,8 +228,8 @@ def on_tool_error(
186228 if flow_span is None :
187229 span_name = '_Exception' if isinstance (error , Exception ) else '_KeyboardInterrupt'
188230 flow_span = self ._new_flow_span (span_name , 'tool_error' , ** kwargs )
189- flow_span . set_error (error )
190- flow_span . set_tags ({ 'error_trace' : traceback . format_exc ()} )
231+ span_tags = { 'error' : repr (error ), 'error_trace' : traceback . format_exc ()}
232+ self . _set_span_tags ( flow_span , span_tags , need_convert_tag_value = False )
191233 self ._end_flow_span (flow_span )
192234
193235 def on_text (self , text : str , ** kwargs : Any ) -> Any :
@@ -200,7 +242,8 @@ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
200242 return
201243
def _end_flow_span(self, span: Span):
    """Finish *span*; a missing (falsy/None) span is ignored."""
    if not span:
        return
    span.finish()
204247
def _get_model_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str, Any]:
    """Collect model-related span tags; currently only token-usage tags."""
    token_tags = self._get_model_token_tags(response, **kwargs)
    return token_tags
@@ -224,20 +267,25 @@ def _get_model_token_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str,
224267 result ['input_cached_tokens' ] = input_cached_tokens
225268 elif response .generations is not None and len (response .generations ) > 0 and response .generations [0 ] is not None :
226269 for i , generation in enumerate (response .generations [0 ]):
227- if isinstance (generation , ChatGeneration ) and isinstance (generation .message ,(AIMessageChunk , AIMessage )) and generation .message .usage_metadata :
270+ if isinstance (generation , ChatGeneration ) and isinstance (generation .message , (
271+ AIMessageChunk , AIMessage )) and generation .message .usage_metadata :
228272 is_get_from_langchain = True
229273 result ['input_tokens' ] = generation .message .usage_metadata .get ('input_tokens' , 0 )
230274 result ['output_tokens' ] = generation .message .usage_metadata .get ('output_tokens' , 0 )
231275 result ['tokens' ] = result ['input_tokens' ] + result ['output_tokens' ]
232276 if generation .message .usage_metadata .get ('output_token_details' , {}):
233- reasoning_tokens = generation .message .usage_metadata .get ('output_token_details' , {}).get ('reasoning' , 0 )
277+ reasoning_tokens = generation .message .usage_metadata .get ('output_token_details' , {}).get (
278+ 'reasoning' , 0 )
234279 if reasoning_tokens :
235280 result ['reasoning_tokens' ] = reasoning_tokens
236281 if generation .message .usage_metadata .get ('input_token_details' , {}):
237- input_read_cached_tokens = generation .message .usage_metadata .get ('input_token_details' , {}).get ('cache_read' , 0 )
282+ input_read_cached_tokens = generation .message .usage_metadata .get ('input_token_details' , {}).get (
283+ 'cache_read' , 0 )
238284 if input_read_cached_tokens :
239285 result ['input_cached_tokens' ] = input_read_cached_tokens
240- input_creation_cached_tokens = generation .message .usage_metadata .get ('input_token_details' , {}).get ('cache_creation' , 0 )
286+ input_creation_cached_tokens = generation .message .usage_metadata .get ('input_token_details' ,
287+ {}).get ('cache_creation' ,
288+ 0 )
241289 if input_creation_cached_tokens :
242290 result ['input_creation_cached_tokens' ] = input_creation_cached_tokens
243291 if is_get_from_langchain :
@@ -259,7 +307,8 @@ def _get_model_token_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str,
259307 span_tags = {'error_info' : repr (e ), 'error_trace' : traceback .format_exc ()}
260308 return span_tags
261309
262- def _on_prompt_start (self , flow_span , serialized : Dict [str , Any ], inputs : (Dict [str , Any ], str ), ** kwargs : Any ) -> None :
310+ def _on_prompt_start (self , flow_span , serialized : Dict [str , Any ], inputs : (Dict [str , Any ], str ),
311+ ** kwargs : Any ) -> None :
263312 # get inputs
264313 params : List [Argument ] = []
265314 if isinstance (inputs , str ):
@@ -309,8 +358,14 @@ def _new_flow_span(self, node_name: str, span_type: str, **kwargs: Any) -> Span:
309358 span_name = node_name
310359 # set parent span
311360 parent_span : Span = None
361+ is_root_span = False
312362 if 'parent_run_id' in kwargs and kwargs ['parent_run_id' ] is not None and str (kwargs ['parent_run_id' ]) in self .run_map :
313363 parent_span = self .run_map [str (kwargs ['parent_run_id' ])].span
364+ # only root span use child_of
365+ if parent_span is None :
366+ is_root_span = True
367+ if self ._child_of :
368+ parent_span = self ._child_of
314369 # modify name
315370 error_tag = {}
316371 try :
@@ -321,15 +376,20 @@ def _new_flow_span(self, node_name: str, span_type: str, **kwargs: Any) -> Span:
321376 except Exception as e :
322377 error_tag = {'error_info' : f'name_fn error { repr (e )} ' , 'error_trace' : traceback .format_exc ()}
323378 # new span
324- flow_span = _trace_callback_client .start_span (span_name , span_type , child_of = parent_span )
379+ flow_span = self ._client .start_span (span_name , span_type , child_of = parent_span )
380+ if is_root_span :
381+ if hasattr (flow_span , 'context' ):
382+ self .root_span_id = flow_span .context .span_id
383+ else :
384+ self .trace_id = flow_span .span_id
325385 run_id = str (kwargs ['run_id' ])
326386 self .run_map [run_id ] = Run (run_id , flow_span , span_type )
327387 # set runtime
328388 flow_span .set_runtime (RuntimeInfo ())
329389 # set extra tags
330- flow_span .set_tags (self ._tags ) # global tags
390+ flow_span .set_tags (self ._tags ) # global tags
331391 try :
332- if self .tags_fn : # add tags fn
392+ if self .tags_fn : # add tags fn
333393 tags = self .tags_fn (node_name )
334394 if isinstance (tags , dict ):
335395 flow_span .set_tags (tags )
@@ -365,7 +425,10 @@ def _set_extra_span_tags(self, flow_span: Span, tag_list: list, **kwargs: Any):
class Run:
    """Bookkeeping record linking a langchain run_id to its loop span."""

    def __init__(self, run_id: str, span: Span, span_type: str) -> None:
        self.run_id = run_id  # langchain run_id
        # loop span_id; the relationship between run_id and span_id is a
        # one-to-one mapping. Newer span implementations expose ids on a
        # nested `.context` object.
        if hasattr(span, 'context'):
            self.span_id = span.context.span_id
        else:
            self.span_id = span.span_id
        self.span = span
        self.span_type = span_type
        # BUG FIX: the original assigned pydantic's Field(default_factory=list)
        # directly, which stores a FieldInfo object (not a list) on a plain
        # (non-pydantic) class; use a real empty list instead.
        self.child_runs: List[Run] = []
@@ -519,7 +582,8 @@ def _convert_inputs(inputs: Any) -> Any:
519582 format_inputs ['content' ] = inputs .content
520583 return format_inputs
521584 if isinstance (inputs , BaseMessage ):
522- message = Message (role = inputs .type , content = inputs .content , tool_calls = inputs .additional_kwargs .get ('tool_calls' , []))
585+ message = Message (role = inputs .type , content = inputs .content ,
586+ tool_calls = inputs .additional_kwargs .get ('tool_calls' , []))
523587 return message
524588 if isinstance (inputs , ChatPromptValue ):
525589 return _convert_inputs (inputs .messages )
0 commit comments