Skip to content

Commit a5dadf8

Browse files
Fix mypy type errors and update purview middleware for MiddlewareTermination
- Fix type annotations in _middleware.py for context.result assignments - Fix type annotations in _agents.py for chat_client.get_response calls - Update ResponseStream result_hooks type to allow Awaitable[TFinal | None] - Convert Sequence to list in FunctionInvocationLayer init - Add type: ignore comments for cooperative inheritance patterns - Update AgentRunContext.result type to use ResponseStream instead of AsyncIterable - Update purview middleware to use raise MiddlewareTermination instead of terminate flag - Ensure MiddlewareTermination is not caught by generic Exception handlers - Update purview tests to expect MiddlewareTermination
1 parent 70fecbc commit a5dadf8

File tree

8 files changed

+331
-333
lines changed

8 files changed

+331
-333
lines changed

python/packages/core/agent_framework/_agents.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,7 @@ async def _run_non_streaming() -> AgentResponse[Any]:
874874
options=options,
875875
kwargs=kwargs,
876876
)
877-
response = await self.chat_client.get_response(
877+
response = await self.chat_client.get_response( # type: ignore[call-overload]
878878
messages=ctx["thread_messages"],
879879
stream=False,
880880
options=ctx["chat_options"],
@@ -944,8 +944,8 @@ async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse]:
944944
options=options,
945945
kwargs=kwargs,
946946
)
947-
ctx = ctx_holder["ctx"]
948-
return self.chat_client.get_response(
947+
ctx: _RunContext = ctx_holder["ctx"] # type: ignore[assignment] # Safe: we just assigned it
948+
return self.chat_client.get_response( # type: ignore[call-overload, no-any-return]
949949
messages=ctx["thread_messages"],
950950
stream=True,
951951
options=ctx["chat_options"],

python/packages/core/agent_framework/_middleware.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class AgentRunContext:
125125
result: Agent execution result. Can be observed after calling ``next()``
126126
to see the actual execution result or can be set to override the execution result.
127127
For non-streaming: should be AgentResponse.
128-
For streaming: should be AsyncIterable[AgentResponseUpdate].
128+
For streaming: should be ResponseStream[AgentResponseUpdate, AgentResponse].
129129
kwargs: Additional keyword arguments passed to the agent run method.
130130
131131
Examples:
@@ -160,7 +160,7 @@ def __init__(
160160
options: Mapping[str, Any] | None = None,
161161
stream: bool = False,
162162
metadata: Mapping[str, Any] | None = None,
163-
result: AgentResponse | AsyncIterable[AgentResponseUpdate] | None = None,
163+
result: AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None = None,
164164
kwargs: Mapping[str, Any] | None = None,
165165
stream_transform_hooks: Sequence[
166166
Callable[[AgentResponseUpdate], AgentResponseUpdate | Awaitable[AgentResponseUpdate]]
@@ -767,7 +767,7 @@ async def execute(
767767
The agent response after processing through all middleware.
768768
"""
769769
if not self._middleware:
770-
context.result = final_handler(context)
770+
context.result = final_handler(context) # type: ignore[assignment]
771771
if isinstance(context.result, Awaitable):
772772
context.result = await context.result
773773
return context.result
@@ -776,7 +776,7 @@ def create_next_handler(index: int) -> Callable[[AgentRunContext], Awaitable[Non
776776
if index >= len(self._middleware):
777777

778778
async def final_wrapper(c: AgentRunContext) -> None:
779-
c.result = final_handler(c)
779+
c.result = final_handler(c) # type: ignore[assignment]
780780
if inspect.isawaitable(c.result):
781781
c.result = await c.result
782782

@@ -904,7 +904,7 @@ async def execute(
904904
final_handler: Callable[
905905
[ChatContext], Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]
906906
],
907-
) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
907+
) -> ChatResponse | ResponseStream[ChatResponseUpdate, ChatResponse] | None:
908908
"""Execute the chat middleware pipeline.
909909
910910
Args:
@@ -915,7 +915,7 @@ async def execute(
915915
The chat response after processing through all middleware.
916916
"""
917917
if not self._middleware:
918-
context.result = final_handler(context)
918+
context.result = final_handler(context) # type: ignore[assignment]
919919
if isinstance(context.result, Awaitable):
920920
context.result = await context.result
921921
if context.stream and not isinstance(context.result, ResponseStream):
@@ -926,7 +926,7 @@ def create_next_handler(index: int) -> Callable[[ChatContext], Awaitable[None]]:
926926
if index >= len(self._middleware):
927927

928928
async def final_wrapper(c: ChatContext) -> None:
929-
c.result = final_handler(c)
929+
c.result = final_handler(c) # type: ignore[assignment]
930930
if inspect.isawaitable(c.result):
931931
c.result = await c.result
932932

@@ -1027,15 +1027,15 @@ def get_response(
10271027
*middleware["chat"],
10281028
)
10291029
if not pipeline.has_middlewares:
1030-
return super_get_response(
1030+
return super_get_response( # type: ignore[no-any-return]
10311031
messages=messages,
10321032
stream=stream,
10331033
options=options,
10341034
**kwargs,
10351035
)
10361036

10371037
context = ChatContext(
1038-
chat_client=self,
1038+
chat_client=self, # type: ignore[arg-type]
10391039
messages=prepare_messages(messages),
10401040
options=options,
10411041
stream=stream,
@@ -1063,13 +1063,13 @@ async def _execute_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse]:
10631063
return ResponseStream.from_awaitable(_execute_stream())
10641064

10651065
# For non-streaming, return the coroutine directly
1066-
return _execute()
1066+
return _execute() # type: ignore[return-value]
10671067

10681068
def _middleware_handler(
10691069
self, context: ChatContext
10701070
) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
10711071
"""Internal middleware handler to adapt to pipeline."""
1072-
return super().get_response(
1072+
return super().get_response( # type: ignore[misc, no-any-return]
10731073
messages=context.messages,
10741074
stream=context.stream,
10751075
options=context.options or {},
@@ -1089,7 +1089,7 @@ def __init__(
10891089
middleware_list = categorize_middleware(middleware)
10901090
self.agent_middleware = middleware_list["agent"]
10911091
# Pass middleware to super so BaseAgent can store it for dynamic rebuild
1092-
super().__init__(*args, middleware=middleware, **kwargs)
1092+
super().__init__(*args, middleware=middleware, **kwargs) # type: ignore[call-arg]
10931093
if chat_client := getattr(self, "chat_client", None):
10941094
client_chat_middleware = getattr(chat_client, "chat_middleware", [])
10951095
client_chat_middleware.extend(middleware_list["chat"])
@@ -1157,10 +1157,10 @@ def run(
11571157

11581158
# Execute with middleware if available
11591159
if not pipeline.has_middlewares:
1160-
return super().run(messages, stream=stream, thread=thread, options=options, **combined_kwargs)
1160+
return super().run(messages, stream=stream, thread=thread, options=options, **combined_kwargs) # type: ignore[misc, no-any-return]
11611161

11621162
context = AgentRunContext(
1163-
agent=self,
1163+
agent=self, # type: ignore[arg-type]
11641164
messages=prepare_messages(messages),
11651165
thread=thread,
11661166
options=options,
@@ -1189,12 +1189,12 @@ async def _execute_stream() -> ResponseStream[AgentResponseUpdate, AgentResponse
11891189
return ResponseStream.from_awaitable(_execute_stream())
11901190

11911191
# For non-streaming, return the coroutine directly
1192-
return _execute()
1192+
return _execute() # type: ignore[return-value]
11931193

11941194
def _middleware_handler(
11951195
self, context: AgentRunContext
11961196
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
1197-
return super().run(
1197+
return super().run( # type: ignore[misc, no-any-return]
11981198
context.messages,
11991199
stream=context.stream,
12001200
thread=context.thread,

python/packages/core/agent_framework/_tools.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2047,7 +2047,9 @@ def __init__(
20472047
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
20482048
**kwargs: Any,
20492049
) -> None:
2050-
self.function_middleware: list[FunctionMiddlewareTypes] = function_middleware or []
2050+
self.function_middleware: list[FunctionMiddlewareTypes] = (
2051+
list(function_middleware) if function_middleware else []
2052+
)
20512053
self.function_invocation_configuration = normalize_function_invocation_configuration(
20522054
function_invocation_configuration
20532055
)

python/packages/core/agent_framework/_types.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2464,7 +2464,7 @@ def __init__(
24642464
finalizer: Callable[[Sequence[TUpdate]], TFinal | Awaitable[TFinal]] | None = None,
24652465
transform_hooks: list[Callable[[TUpdate], TUpdate | Awaitable[TUpdate] | None]] | None = None,
24662466
cleanup_hooks: list[Callable[[], Awaitable[None] | None]] | None = None,
2467-
result_hooks: list[Callable[[TFinal], TFinal | Awaitable[TFinal] | None]] | None = None,
2467+
result_hooks: list[Callable[[TFinal], TFinal | Awaitable[TFinal | None] | None]] | None = None,
24682468
) -> None:
24692469
"""A Async Iterable stream of updates.
24702470
@@ -2489,7 +2489,7 @@ def __init__(
24892489
self._transform_hooks: list[Callable[[TUpdate], TUpdate | Awaitable[TUpdate] | None]] = (
24902490
transform_hooks if transform_hooks is not None else []
24912491
)
2492-
self._result_hooks: list[Callable[[TFinal], TFinal | Awaitable[TFinal] | None]] = (
2492+
self._result_hooks: list[Callable[[TFinal], TFinal | Awaitable[TFinal | None] | None]] = (
24932493
result_hooks if result_hooks is not None else []
24942494
)
24952495
self._cleanup_hooks: list[Callable[[], Awaitable[None] | None]] = (
@@ -2748,7 +2748,7 @@ def with_transform_hook(
27482748

27492749
def with_result_hook(
27502750
self,
2751-
hook: Callable[[TFinal], TFinal | Awaitable[TFinal] | None],
2751+
hook: Callable[[TFinal], TFinal | Awaitable[TFinal | None] | None],
27522752
) -> ResponseStream[TUpdate, TFinal]:
27532753
"""Register a result hook executed after finalization."""
27542754
self._result_hooks.append(hook)

python/packages/purview/agent_framework_purview/_middleware.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Awaitable, Callable
44

5-
from agent_framework import AgentMiddleware, AgentRunContext, ChatContext, ChatMiddleware
5+
from agent_framework import AgentMiddleware, AgentRunContext, ChatContext, ChatMiddleware, MiddlewareTermination
66
from agent_framework._logging import get_logger
77
from azure.core.credentials import TokenCredential
88
from azure.core.credentials_async import AsyncTokenCredential
@@ -62,8 +62,9 @@ async def process(
6262
context.result = AgentResponse(
6363
messages=[ChatMessage(role=Role.SYSTEM, text=self._settings.blocked_prompt_message)]
6464
)
65-
context.terminate = True
66-
return
65+
raise MiddlewareTermination
66+
except MiddlewareTermination:
67+
raise
6768
except PurviewPaymentRequiredError as ex:
6869
logger.error(f"Purview payment required error in policy pre-check: {ex}")
6970
if not self._settings.ignore_payment_required:
@@ -151,8 +152,9 @@ async def process(
151152

152153
blocked_message = ChatMessage(role="system", text=self._settings.blocked_prompt_message)
153154
context.result = ChatResponse(messages=[blocked_message])
154-
context.terminate = True
155-
return
155+
raise MiddlewareTermination
156+
except MiddlewareTermination:
157+
raise
156158
except PurviewPaymentRequiredError as ex:
157159
logger.error(f"Purview payment required error in policy pre-check: {ex}")
158160
if not self._settings.ignore_payment_required:

python/packages/purview/tests/test_chat_middleware.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from unittest.mock import AsyncMock, MagicMock, patch
66

77
import pytest
8-
from agent_framework import ChatContext, ChatMessage, Role
8+
from agent_framework import ChatContext, ChatMessage, MiddlewareTermination, Role
99
from azure.core.credentials import AccessToken
1010

1111
from agent_framework_purview import PurviewChatPolicyMiddleware, PurviewSettings
@@ -71,8 +71,8 @@ async def test_blocks_prompt(self, middleware: PurviewChatPolicyMiddleware, chat
7171
async def mock_next(ctx: ChatContext) -> None: # should not run
7272
raise AssertionError("next should not be called when prompt blocked")
7373

74-
await middleware.process(chat_context, mock_next)
75-
assert chat_context.terminate
74+
with pytest.raises(MiddlewareTermination):
75+
await middleware.process(chat_context, mock_next)
7676
assert chat_context.result
7777
assert hasattr(chat_context.result, "messages")
7878
msg = chat_context.result.messages[0]

python/packages/purview/tests/test_middleware.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from unittest.mock import AsyncMock, MagicMock, patch
66

77
import pytest
8-
from agent_framework import AgentResponse, AgentRunContext, ChatMessage, Role
8+
from agent_framework import AgentResponse, AgentRunContext, ChatMessage, MiddlewareTermination, Role
99
from azure.core.credentials import AccessToken
1010

1111
from agent_framework_purview import PurviewPolicyMiddleware, PurviewSettings
@@ -79,11 +79,11 @@ async def mock_next(ctx: AgentRunContext) -> None:
7979
nonlocal next_called
8080
next_called = True
8181

82-
await middleware.process(context, mock_next)
82+
with pytest.raises(MiddlewareTermination):
83+
await middleware.process(context, mock_next)
8384

8485
assert not next_called
8586
assert context.result is not None
86-
assert context.terminate
8787
assert len(context.result.messages) == 1
8888
assert context.result.messages[0].role == Role.SYSTEM
8989
assert "blocked by policy" in context.result.messages[0].text.lower()

0 commit comments

Comments
 (0)