Skip to content

Commit 711e717

Browse files
chore: improve hook typing and registration
Allow hook registration to accept both typed hook types and plain callables by importing and using After*/Before*CallHookCallable types; add explicit LLMCallHookContext and ToolCallHookContext typing in crew_base. Introduce a post-initialize crew hook list and invoke hooks after Crew instance initialization. Refactor filtered hook factory functions to include precise typing and clearer local names (before_llm_hook/after_llm_hook/before_tool_hook/after_tool_hook) and register those with the instance. Update CrewInstance protocol to include _registered_hook_functions and _hooks_being_registered fields.
1 parent 76b5f72 commit 711e717

File tree

4 files changed

+89
-42
lines changed

4 files changed

+89
-42
lines changed

lib/crewai/src/crewai/hooks/llm_hooks.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
from typing import TYPE_CHECKING, Any, cast
44

55
from crewai.events.event_listener import event_listener
6-
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
6+
from crewai.hooks.types import (
7+
AfterLLMCallHookCallable,
8+
AfterLLMCallHookType,
9+
BeforeLLMCallHookCallable,
10+
BeforeLLMCallHookType,
11+
)
712
from crewai.utilities.printer import Printer
813

914

@@ -149,12 +154,12 @@ def request_human_input(
149154
event_listener.formatter.resume_live_updates()
150155

151156

152-
_before_llm_call_hooks: list[BeforeLLMCallHookType] = []
153-
_after_llm_call_hooks: list[AfterLLMCallHookType] = []
157+
_before_llm_call_hooks: list[BeforeLLMCallHookType | BeforeLLMCallHookCallable] = []
158+
_after_llm_call_hooks: list[AfterLLMCallHookType | AfterLLMCallHookCallable] = []
154159

155160

156161
def register_before_llm_call_hook(
157-
hook: BeforeLLMCallHookType,
162+
hook: BeforeLLMCallHookType | BeforeLLMCallHookCallable,
158163
) -> None:
159164
"""Register a global before_llm_call hook.
160165
@@ -190,7 +195,7 @@ def register_before_llm_call_hook(
190195

191196

192197
def register_after_llm_call_hook(
193-
hook: AfterLLMCallHookType,
198+
hook: AfterLLMCallHookType | AfterLLMCallHookCallable,
194199
) -> None:
195200
"""Register a global after_llm_call hook.
196201
@@ -217,7 +222,9 @@ def register_after_llm_call_hook(
217222
_after_llm_call_hooks.append(hook)
218223

219224

220-
def get_before_llm_call_hooks() -> list[BeforeLLMCallHookType]:
225+
def get_before_llm_call_hooks() -> list[
226+
BeforeLLMCallHookType | BeforeLLMCallHookCallable
227+
]:
221228
"""Get all registered global before_llm_call hooks.
222229
223230
Returns:
@@ -226,7 +233,7 @@ def get_before_llm_call_hooks() -> list[BeforeLLMCallHookType]:
226233
return _before_llm_call_hooks.copy()
227234

228235

229-
def get_after_llm_call_hooks() -> list[AfterLLMCallHookType]:
236+
def get_after_llm_call_hooks() -> list[AfterLLMCallHookType | AfterLLMCallHookCallable]:
230237
"""Get all registered global after_llm_call hooks.
231238
232239
Returns:
@@ -236,7 +243,7 @@ def get_after_llm_call_hooks() -> list[AfterLLMCallHookType]:
236243

237244

238245
def unregister_before_llm_call_hook(
239-
hook: BeforeLLMCallHookType,
246+
hook: BeforeLLMCallHookType | BeforeLLMCallHookCallable,
240247
) -> bool:
241248
"""Unregister a specific global before_llm_call hook.
242249
@@ -262,7 +269,7 @@ def unregister_before_llm_call_hook(
262269

263270

264271
def unregister_after_llm_call_hook(
265-
hook: AfterLLMCallHookType,
272+
hook: AfterLLMCallHookType | AfterLLMCallHookCallable,
266273
) -> bool:
267274
"""Unregister a specific global after_llm_call hook.
268275

lib/crewai/src/crewai/hooks/tool_hooks.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
from typing import TYPE_CHECKING, Any
44

55
from crewai.events.event_listener import event_listener
6-
from crewai.hooks.types import AfterToolCallHookType, BeforeToolCallHookType
6+
from crewai.hooks.types import (
7+
AfterToolCallHookCallable,
8+
AfterToolCallHookType,
9+
BeforeToolCallHookCallable,
10+
BeforeToolCallHookType,
11+
)
712
from crewai.utilities.printer import Printer
813

914

@@ -112,12 +117,12 @@ def request_human_input(
112117

113118

114119
# Global hook registries
115-
_before_tool_call_hooks: list[BeforeToolCallHookType] = []
116-
_after_tool_call_hooks: list[AfterToolCallHookType] = []
120+
_before_tool_call_hooks: list[BeforeToolCallHookType | BeforeToolCallHookCallable] = []
121+
_after_tool_call_hooks: list[AfterToolCallHookType | AfterToolCallHookCallable] = []
117122

118123

119124
def register_before_tool_call_hook(
120-
hook: BeforeToolCallHookType,
125+
hook: BeforeToolCallHookType | BeforeToolCallHookCallable,
121126
) -> None:
122127
"""Register a global before_tool_call hook.
123128
@@ -154,7 +159,7 @@ def register_before_tool_call_hook(
154159

155160

156161
def register_after_tool_call_hook(
157-
hook: AfterToolCallHookType,
162+
hook: AfterToolCallHookType | AfterToolCallHookCallable,
158163
) -> None:
159164
"""Register a global after_tool_call hook.
160165
@@ -184,7 +189,9 @@ def register_after_tool_call_hook(
184189
_after_tool_call_hooks.append(hook)
185190

186191

187-
def get_before_tool_call_hooks() -> list[BeforeToolCallHookType]:
192+
def get_before_tool_call_hooks() -> list[
193+
BeforeToolCallHookType | BeforeToolCallHookCallable
194+
]:
188195
"""Get all registered global before_tool_call hooks.
189196
190197
Returns:
@@ -193,7 +200,9 @@ def get_before_tool_call_hooks() -> list[BeforeToolCallHookType]:
193200
return _before_tool_call_hooks.copy()
194201

195202

196-
def get_after_tool_call_hooks() -> list[AfterToolCallHookType]:
203+
def get_after_tool_call_hooks() -> list[
204+
AfterToolCallHookType | AfterToolCallHookCallable
205+
]:
197206
"""Get all registered global after_tool_call hooks.
198207
199208
Returns:
@@ -203,7 +212,7 @@ def get_after_tool_call_hooks() -> list[AfterToolCallHookType]:
203212

204213

205214
def unregister_before_tool_call_hook(
206-
hook: BeforeToolCallHookType,
215+
hook: BeforeToolCallHookType | BeforeToolCallHookCallable,
207216
) -> bool:
208217
"""Unregister a specific global before_tool_call hook.
209218
@@ -229,7 +238,7 @@ def unregister_before_tool_call_hook(
229238

230239

231240
def unregister_after_tool_call_hook(
232-
hook: AfterToolCallHookType,
241+
hook: AfterToolCallHookType | AfterToolCallHookCallable,
233242
) -> bool:
234243
"""Unregister a specific global after_tool_call hook.
235244

lib/crewai/src/crewai/project/crew_base.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,17 @@
2727
from crewai import Agent, Task
2828
from crewai.agents.cache.cache_handler import CacheHandler
2929
from crewai.crews.crew_output import CrewOutput
30+
from crewai.hooks.llm_hooks import LLMCallHookContext
31+
from crewai.hooks.tool_hooks import ToolCallHookContext
3032
from crewai.project.wrappers import (
3133
CrewInstance,
3234
OutputJsonClass,
3335
OutputPydanticClass,
3436
)
3537
from crewai.tasks.task_output import TaskOutput
3638

39+
_post_initialize_crew_hooks: list[Callable[[Any], None]] = []
40+
3741

3842
class AgentConfig(TypedDict, total=False):
3943
"""Type definition for agent configuration dictionary.
@@ -266,6 +270,9 @@ def _initialize_crew_instance(instance: CrewInstance, cls: type) -> None:
266270
instance.map_all_agent_variables()
267271
instance.map_all_task_variables()
268272

273+
for hook in _post_initialize_crew_hooks:
274+
hook(instance)
275+
269276
original_methods = {
270277
name: method
271278
for name, method in cls.__dict__.items()
@@ -485,47 +492,61 @@ def _register_crew_hooks(instance: CrewInstance, cls: type) -> None:
485492
if has_agent_filter:
486493
agents_filter = hook_method._filter_agents
487494

488-
def make_filtered_before_llm(bound_fn, agents_list):
489-
def filtered(context):
495+
def make_filtered_before_llm(
496+
bound_fn: Callable[[LLMCallHookContext], bool | None],
497+
agents_list: list[str],
498+
) -> Callable[[LLMCallHookContext], bool | None]:
499+
def filtered(context: LLMCallHookContext) -> bool | None:
490500
if context.agent and context.agent.role not in agents_list:
491501
return None
492502
return bound_fn(context)
493503

494504
return filtered
495505

496-
final_hook = make_filtered_before_llm(bound_hook, agents_filter)
506+
before_llm_hook = make_filtered_before_llm(bound_hook, agents_filter)
497507
else:
498-
final_hook = bound_hook
508+
before_llm_hook = bound_hook
499509

500-
register_before_llm_call_hook(final_hook)
501-
instance._registered_hook_functions.append(("before_llm_call", final_hook))
510+
register_before_llm_call_hook(before_llm_hook)
511+
instance._registered_hook_functions.append(
512+
("before_llm_call", before_llm_hook)
513+
)
502514

503515
if hasattr(hook_method, "is_after_llm_call_hook"):
504516
if has_agent_filter:
505517
agents_filter = hook_method._filter_agents
506518

507-
def make_filtered_after_llm(bound_fn, agents_list):
508-
def filtered(context):
519+
def make_filtered_after_llm(
520+
bound_fn: Callable[[LLMCallHookContext], str | None],
521+
agents_list: list[str],
522+
) -> Callable[[LLMCallHookContext], str | None]:
523+
def filtered(context: LLMCallHookContext) -> str | None:
509524
if context.agent and context.agent.role not in agents_list:
510525
return None
511526
return bound_fn(context)
512527

513528
return filtered
514529

515-
final_hook = make_filtered_after_llm(bound_hook, agents_filter)
530+
after_llm_hook = make_filtered_after_llm(bound_hook, agents_filter)
516531
else:
517-
final_hook = bound_hook
532+
after_llm_hook = bound_hook
518533

519-
register_after_llm_call_hook(final_hook)
520-
instance._registered_hook_functions.append(("after_llm_call", final_hook))
534+
register_after_llm_call_hook(after_llm_hook)
535+
instance._registered_hook_functions.append(
536+
("after_llm_call", after_llm_hook)
537+
)
521538

522539
if hasattr(hook_method, "is_before_tool_call_hook"):
523540
if has_tool_filter or has_agent_filter:
524541
tools_filter = getattr(hook_method, "_filter_tools", None)
525542
agents_filter = getattr(hook_method, "_filter_agents", None)
526543

527-
def make_filtered_before_tool(bound_fn, tools_list, agents_list):
528-
def filtered(context):
544+
def make_filtered_before_tool(
545+
bound_fn: Callable[[ToolCallHookContext], bool | None],
546+
tools_list: list[str] | None,
547+
agents_list: list[str] | None,
548+
) -> Callable[[ToolCallHookContext], bool | None]:
549+
def filtered(context: ToolCallHookContext) -> bool | None:
529550
if tools_list and context.tool_name not in tools_list:
530551
return None
531552
if (
@@ -538,22 +559,28 @@ def filtered(context):
538559

539560
return filtered
540561

541-
final_hook = make_filtered_before_tool(
562+
before_tool_hook = make_filtered_before_tool(
542563
bound_hook, tools_filter, agents_filter
543564
)
544565
else:
545-
final_hook = bound_hook
566+
before_tool_hook = bound_hook
546567

547-
register_before_tool_call_hook(final_hook)
548-
instance._registered_hook_functions.append(("before_tool_call", final_hook))
568+
register_before_tool_call_hook(before_tool_hook)
569+
instance._registered_hook_functions.append(
570+
("before_tool_call", before_tool_hook)
571+
)
549572

550573
if hasattr(hook_method, "is_after_tool_call_hook"):
551574
if has_tool_filter or has_agent_filter:
552575
tools_filter = getattr(hook_method, "_filter_tools", None)
553576
agents_filter = getattr(hook_method, "_filter_agents", None)
554577

555-
def make_filtered_after_tool(bound_fn, tools_list, agents_list):
556-
def filtered(context):
578+
def make_filtered_after_tool(
579+
bound_fn: Callable[[ToolCallHookContext], str | None],
580+
tools_list: list[str] | None,
581+
agents_list: list[str] | None,
582+
) -> Callable[[ToolCallHookContext], str | None]:
583+
def filtered(context: ToolCallHookContext) -> str | None:
557584
if tools_list and context.tool_name not in tools_list:
558585
return None
559586
if (
@@ -566,14 +593,16 @@ def filtered(context):
566593

567594
return filtered
568595

569-
final_hook = make_filtered_after_tool(
596+
after_tool_hook = make_filtered_after_tool(
570597
bound_hook, tools_filter, agents_filter
571598
)
572599
else:
573-
final_hook = bound_hook
600+
after_tool_hook = bound_hook
574601

575-
register_after_tool_call_hook(final_hook)
576-
instance._registered_hook_functions.append(("after_tool_call", final_hook))
602+
register_after_tool_call_hook(after_tool_hook)
603+
instance._registered_hook_functions.append(
604+
("after_tool_call", after_tool_hook)
605+
)
577606

578607
instance._hooks_being_registered = False
579608

lib/crewai/src/crewai/project/wrappers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ class CrewInstance(Protocol):
7272
__crew_metadata__: CrewMetadata
7373
_mcp_server_adapter: Any
7474
_all_methods: dict[str, Callable[..., Any]]
75+
_registered_hook_functions: list[tuple[str, Callable[..., Any]]]
76+
_hooks_being_registered: bool
7577
agents: list[Agent]
7678
tasks: list[Task]
7779
base_directory: Path

0 commit comments

Comments
 (0)