Skip to content

Commit 49addb8

Browse files
committed
llc enhance
1 parent f6d433b commit 49addb8

File tree

6 files changed

+143
-50
lines changed

6 files changed

+143
-50
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ venv.bak/
3535
/.vscode
3636
/output
3737
dist/
38-
.coda/
38+
.coda/
39+
.DS_Store

CHANGLOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
## [0.1.22] - 2026-01-05
2+
### Added
3+
- lcc support child_of and state_span_ctx_key
4+
- lcc support multi clients
5+
- llc support get trace_id and root_span_id
6+
17
## [0.1.21] - 2025-12-23
28
### Added
39
- runtime scene support get from env

cozeloop/integration/langchain/trace_callback.py

Lines changed: 98 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from __future__ import annotations
55
import json
6+
import threading
67
import time
78
import traceback
89
from typing import List, Dict, Union, Any, Optional, Callable, Protocol
@@ -24,8 +25,6 @@
2425
from cozeloop.integration.langchain.trace_model.runtime import RuntimeInfo
2526
from cozeloop.integration.langchain.util import calc_token_usage, get_prompt_tag
2627

27-
_trace_callback_client: Optional[Client] = None
28-
2928

3029
class LoopTracer:
3130
@classmethod
@@ -35,19 +34,28 @@ def get_callback_handler(
3534
modify_name_fn: Optional[Callable[[str], str]] = None,
3635
add_tags_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
3736
tags: Dict[str, Any] = None,
37+
child_of: Optional[Span] = None,
38+
state_span_ctx_key: str = None,
3839
):
3940
"""
4041
Do not hold it for a long time, get a new callback_handler for each request.
41-
modify_name_fn: modify name function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is span name.
42-
add_tags_fn: add tags function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is tags dict.
42+
client: cozeloop client instance. If not provided, use the default client.
43+
modify_name_fn: modify name function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is span name.
44+
add_tags_fn: add tags function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is tags dict.
45+
It's priority higher than parameter tags.
46+
tags: default tags dict. It's priority lower than parameter add_tags_fn.
47+
child_of: parent span of this callback_handler.
48+
state_span_ctx_key: span context field name in state. If provided, you need to set the field in state, and we will use it to set span context in state.
49+
You can get it from state for creating inner span in async node.
4350
"""
44-
global _trace_callback_client
45-
if client:
46-
_trace_callback_client = client
47-
else:
48-
_trace_callback_client = get_default_client()
49-
50-
return LoopTraceCallbackHandler(modify_name_fn, add_tags_fn, tags)
51+
return LoopTraceCallbackHandler(
52+
name_fn=modify_name_fn,
53+
tags_fn=add_tags_fn,
54+
tags=tags,
55+
child_of=child_of,
56+
client=client,
57+
state_span_ctx_key=state_span_ctx_key,
58+
)
5159

5260

5361
class LoopTraceCallbackHandler(BaseCallbackHandler):
@@ -56,13 +64,22 @@ def __init__(
5664
name_fn: Optional[Callable[[str], str]] = None,
5765
tags_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
5866
tags: Dict[str, Any] = None,
67+
child_of: Optional[Span] = None,
68+
client: Client = None,
69+
state_span_ctx_key: str = None,
5970
):
6071
super().__init__()
61-
self._space_id = _trace_callback_client.workspace_id
72+
self._client = client if client else get_default_client()
73+
self._space_id = self._client.workspace_id
6274
self.run_map: Dict[str, Run] = {}
6375
self.name_fn = name_fn
6476
self.tags_fn = tags_fn
6577
self._tags = tags if tags else {}
78+
self.trace_id: Optional[str] = None
79+
self.root_span_id: Optional[str] = None
80+
self._id_lock = threading.Lock()
81+
self._child_of = child_of
82+
self._state_span_ctx_key = state_span_ctx_key
6683

6784
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
6885
span_tags = {}
@@ -73,14 +90,16 @@ def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs:
7390
span_tags['input'] = ModelTraceInput([BaseMessage(type='', content=prompt) for prompt in prompts],
7491
kwargs.get('invocation_params', {})).to_json()
7592
except Exception as e:
76-
flow_span.set_error(e)
93+
span_tags['internal_error'] = repr(e)
94+
span_tags['internal_error_trace'] = traceback.format_exc()
7795
finally:
7896
span_tags.update(_get_model_span_tags(**kwargs))
7997
self._set_span_tags(flow_span, span_tags)
8098
# Store some pre-aspect information.
8199
self.run_map[str(kwargs['run_id'])].model_meta = ModelMeta(
82100
message=[{'role': '', 'content': prompt} for prompt in prompts],
83101
model_name=span_tags.get('model_name', ''))
102+
return flow_span
84103

85104
def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any) -> Any:
86105
span_tags = {}
@@ -90,14 +109,16 @@ def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[Ba
90109
try:
91110
span_tags['input'] = ModelTraceInput(messages, kwargs.get('invocation_params', {})).to_json()
92111
except Exception as e:
93-
flow_span.set_error(e)
112+
span_tags['internal_error'] = repr(e)
113+
span_tags['internal_error_trace'] = traceback.format_exc()
94114
finally:
95115
span_tags.update(_get_model_span_tags(**kwargs))
96116
self._set_span_tags(flow_span, span_tags)
97117
# Store some pre-aspect information.
98118
self.run_map[str(kwargs['run_id'])].model_meta = (
99119
ModelMeta(message=[{'role': message.type, 'content': message.content} for inner_messages in messages for
100120
message in inner_messages], model_name=span_tags.get('model_name', '')))
121+
return flow_span
101122

102123
async def on_llm_new_token(self, token: str, *, chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
103124
**kwargs: Any) -> None:
@@ -119,10 +140,14 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
119140
if tags:
120141
self._set_span_tags(flow_span, tags, need_convert_tag_value=False)
121142
except Exception as e:
122-
flow_span.set_error(e)
143+
span_tags = {"internal_error": repr(e), 'internal_error_trace': traceback.format_exc()}
144+
self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
123145
# finish flow_span
124146
self._end_flow_span(flow_span)
125147

148+
def on_llm_error(self, error: Exception, **kwargs: Any) -> Any:
149+
self.on_chain_error(error, **kwargs)
150+
126151
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
127152
flow_span = None
128153
try:
@@ -131,13 +156,27 @@ def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **k
131156
self._on_prompt_start(flow_span, serialized, inputs, **kwargs)
132157
else:
133158
span_type = 'chain'
134-
if kwargs['name'] == 'LangGraph': # LangGraph is Graph span_type,for trajectory evaluation aggregate to an agent
159+
if kwargs[
160+
'name'] == 'LangGraph': # LangGraph is Graph span_type,for trajectory evaluation aggregate to an agent
135161
span_type = 'graph'
136162
flow_span = self._new_flow_span(kwargs['name'], span_type, **kwargs)
137163
flow_span.set_tags({'input': _convert_2_json(inputs)})
138164
except Exception as e:
139165
if flow_span is not None:
140-
flow_span.set_error(e)
166+
span_tags = {"internal_error": repr(e), 'internal_error_trace': traceback.format_exc()}
167+
self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
168+
finally:
169+
if flow_span is not None:
170+
# set trace_id
171+
with self._id_lock:
172+
if hasattr(flow_span, 'context'):
173+
self.trace_id = flow_span.context.trace_id
174+
else:
175+
self.trace_id = flow_span.trace_id
176+
# set span_ctx in state
177+
if self._state_span_ctx_key:
178+
inputs[self._state_span_ctx_key] = flow_span
179+
return flow_span
141180

142181
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> Any:
143182
flow_span = self.run_map[str(kwargs['run_id'])].span
@@ -151,16 +190,17 @@ def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> An
151190
else:
152191
flow_span.set_tags({'output': _convert_2_json(outputs)})
153192
except Exception as e:
154-
flow_span.set_error(e)
193+
if flow_span:
194+
span_tags = {"internal_error": repr(e), 'internal_error_trace': traceback.format_exc()}
195+
self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
155196
self._end_flow_span(flow_span)
156197

157198
def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
158199
flow_span = self._get_flow_span(**kwargs)
159200
if flow_span is None:
160201
span_name = '_Exception' if isinstance(error, Exception) else '_KeyboardInterrupt'
161202
flow_span = self._new_flow_span(span_name, 'chain_error', **kwargs)
162-
flow_span.set_error(error)
163-
flow_span.set_tags({'error_trace': traceback.format_exc()})
203+
flow_span.set_tags({'error': repr(error), 'error_trace': traceback.format_exc()})
164204
self._end_flow_span(flow_span)
165205

166206
def on_tool_start(
@@ -170,13 +210,15 @@ def on_tool_start(
170210
span_name = serialized.get('name', 'unknown')
171211
flow_span = self._new_flow_span(span_name, 'tool', **kwargs)
172212
self._set_span_tags(flow_span, span_tags)
213+
return flow_span
173214

174215
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
175216
flow_span = self._get_flow_span(**kwargs)
176217
try:
177218
flow_span.set_tags({'output': _convert_2_json(output)})
178219
except Exception as e:
179-
flow_span.set_error(e)
220+
span_tags = {"internal_error": repr(e), 'internal_error_trace': traceback.format_exc()}
221+
self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
180222
self._end_flow_span(flow_span)
181223

182224
def on_tool_error(
@@ -186,8 +228,8 @@ def on_tool_error(
186228
if flow_span is None:
187229
span_name = '_Exception' if isinstance(error, Exception) else '_KeyboardInterrupt'
188230
flow_span = self._new_flow_span(span_name, 'tool_error', **kwargs)
189-
flow_span.set_error(error)
190-
flow_span.set_tags({'error_trace': traceback.format_exc()})
231+
span_tags = {'error': repr(error), 'error_trace': traceback.format_exc()}
232+
self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
191233
self._end_flow_span(flow_span)
192234

193235
def on_text(self, text: str, **kwargs: Any) -> Any:
@@ -200,7 +242,8 @@ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
200242
return
201243

202244
def _end_flow_span(self, span: Span):
203-
span.finish()
245+
if span:
246+
span.finish()
204247

205248
def _get_model_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str, Any]:
206249
return self._get_model_token_tags(response, **kwargs)
@@ -224,20 +267,25 @@ def _get_model_token_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str,
224267
result['input_cached_tokens'] = input_cached_tokens
225268
elif response.generations is not None and len(response.generations) > 0 and response.generations[0] is not None:
226269
for i, generation in enumerate(response.generations[0]):
227-
if isinstance(generation, ChatGeneration) and isinstance(generation.message,(AIMessageChunk, AIMessage)) and generation.message.usage_metadata:
270+
if isinstance(generation, ChatGeneration) and isinstance(generation.message, (
271+
AIMessageChunk, AIMessage)) and generation.message.usage_metadata:
228272
is_get_from_langchain = True
229273
result['input_tokens'] = generation.message.usage_metadata.get('input_tokens', 0)
230274
result['output_tokens'] = generation.message.usage_metadata.get('output_tokens', 0)
231275
result['tokens'] = result['input_tokens'] + result['output_tokens']
232276
if generation.message.usage_metadata.get('output_token_details', {}):
233-
reasoning_tokens = generation.message.usage_metadata.get('output_token_details', {}).get('reasoning', 0)
277+
reasoning_tokens = generation.message.usage_metadata.get('output_token_details', {}).get(
278+
'reasoning', 0)
234279
if reasoning_tokens:
235280
result['reasoning_tokens'] = reasoning_tokens
236281
if generation.message.usage_metadata.get('input_token_details', {}):
237-
input_read_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get('cache_read', 0)
282+
input_read_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get(
283+
'cache_read', 0)
238284
if input_read_cached_tokens:
239285
result['input_cached_tokens'] = input_read_cached_tokens
240-
input_creation_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get('cache_creation', 0)
286+
input_creation_cached_tokens = generation.message.usage_metadata.get('input_token_details',
287+
{}).get('cache_creation',
288+
0)
241289
if input_creation_cached_tokens:
242290
result['input_creation_cached_tokens'] = input_creation_cached_tokens
243291
if is_get_from_langchain:
@@ -259,7 +307,8 @@ def _get_model_token_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str,
259307
span_tags = {'error_info': repr(e), 'error_trace': traceback.format_exc()}
260308
return span_tags
261309

262-
def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: (Dict[str, Any], str), **kwargs: Any) -> None:
310+
def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: (Dict[str, Any], str),
311+
**kwargs: Any) -> None:
263312
# get inputs
264313
params: List[Argument] = []
265314
if isinstance(inputs, str):
@@ -309,8 +358,14 @@ def _new_flow_span(self, node_name: str, span_type: str, **kwargs: Any) -> Span:
309358
span_name = node_name
310359
# set parent span
311360
parent_span: Span = None
361+
is_root_span = False
312362
if 'parent_run_id' in kwargs and kwargs['parent_run_id'] is not None and str(kwargs['parent_run_id']) in self.run_map:
313363
parent_span = self.run_map[str(kwargs['parent_run_id'])].span
364+
# only root span use child_of
365+
if parent_span is None:
366+
is_root_span = True
367+
if self._child_of:
368+
parent_span = self._child_of
314369
# modify name
315370
error_tag = {}
316371
try:
@@ -321,15 +376,20 @@ def _new_flow_span(self, node_name: str, span_type: str, **kwargs: Any) -> Span:
321376
except Exception as e:
322377
error_tag = {'error_info': f'name_fn error {repr(e)}', 'error_trace': traceback.format_exc()}
323378
# new span
324-
flow_span = _trace_callback_client.start_span(span_name, span_type, child_of=parent_span)
379+
flow_span = self._client.start_span(span_name, span_type, child_of=parent_span)
380+
if is_root_span:
381+
if hasattr(flow_span, 'context'):
382+
self.root_span_id = flow_span.context.span_id
383+
else:
384+
self.root_span_id = flow_span.span_id
325385
run_id = str(kwargs['run_id'])
326386
self.run_map[run_id] = Run(run_id, flow_span, span_type)
327387
# set runtime
328388
flow_span.set_runtime(RuntimeInfo())
329389
# set extra tags
330-
flow_span.set_tags(self._tags) # global tags
390+
flow_span.set_tags(self._tags) # global tags
331391
try:
332-
if self.tags_fn: # add tags fn
392+
if self.tags_fn: # add tags fn
333393
tags = self.tags_fn(node_name)
334394
if isinstance(tags, dict):
335395
flow_span.set_tags(tags)
@@ -365,7 +425,10 @@ def _set_extra_span_tags(self, flow_span: Span, tag_list: list, **kwargs: Any):
365425
class Run:
366426
def __init__(self, run_id: str, span: Span, span_type: str) -> None:
367427
self.run_id = run_id # langchain run_id
368-
self.span_id = span.span_id # loop span_id,the relationship between run_id and span_id is one-to-one mapping.
428+
if hasattr(span, 'context'):
429+
self.span_id = span.context.span_id
430+
else:
431+
self.span_id = span.span_id # loop span_id,the relationship between run_id and span_id is one-to-one mapping.
369432
self.span = span
370433
self.span_type = span_type
371434
self.child_runs: List[Run] = Field(default_factory=list)
@@ -519,7 +582,8 @@ def _convert_inputs(inputs: Any) -> Any:
519582
format_inputs['content'] = inputs.content
520583
return format_inputs
521584
if isinstance(inputs, BaseMessage):
522-
message = Message(role=inputs.type, content=inputs.content, tool_calls=inputs.additional_kwargs.get('tool_calls', []))
585+
message = Message(role=inputs.type, content=inputs.content,
586+
tool_calls=inputs.additional_kwargs.get('tool_calls', []))
523587
return message
524588
if isinstance(inputs, ChatPromptValue):
525589
return _convert_inputs(inputs.messages)

0 commit comments

Comments
 (0)