Skip to content

Commit 10afb86

Browse files
authored
[BREAKING] Python: Refactor SharedState to State with sync methods and superstep caching (#3667)
* Refactor SharedState to State with sync methods and superstep caching
* Fixes
* Address PR feedback
* Remove dead links
* Fix lab test import
1 parent 4e25917 commit 10afb86

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

48 files changed

+1973
-1726
lines changed

python/packages/core/agent_framework/_workflows/_agent_executor.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -330,7 +330,7 @@ async def _run_agent(self, ctx: WorkflowContext[Never, AgentResponse]) -> AgentR
330330
Returns:
331331
The complete AgentResponse, or None if waiting for user input.
332332
"""
333-
run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY)
333+
run_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})
334334

335335
response = await self._agent.run(
336336
self._cache,
@@ -357,7 +357,7 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp
357357
Returns:
358358
The complete AgentResponse, or None if waiting for user input.
359359
"""
360-
run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY)
360+
run_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {}
361361

362362
updates: list[AgentResponseUpdate] = []
363363
user_input_requests: list[Content] = []

python/packages/core/agent_framework/_workflows/_checkpoint.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,17 @@ class WorkflowCheckpoint:
2626
workflow_id: Identifier of the workflow this checkpoint belongs to
2727
timestamp: ISO 8601 timestamp when checkpoint was created
2828
messages: Messages exchanged between executors
29-
shared_state: Complete shared state including user data and executor states.
30-
Executor states are stored under the reserved key '_executor_state'.
29+
state: Committed workflow state including user data and executor states.
30+
This contains only committed state; pending state changes are not
31+
included in checkpoints. Executor states are stored under the
32+
reserved key '_executor_state'.
3133
iteration_count: Current iteration number when checkpoint was created
3234
metadata: Additional metadata (e.g., superstep info, graph signature)
3335
version: Checkpoint format version
3436
3537
Note:
36-
The shared_state dict may contain reserved keys managed by the framework.
37-
See SharedState class documentation for details on reserved keys.
38+
The state dict may contain reserved keys managed by the framework.
39+
See State class documentation for details on reserved keys.
3840
"""
3941

4042
checkpoint_id: str = field(default_factory=lambda: str(uuid.uuid4()))
@@ -43,7 +45,7 @@ class WorkflowCheckpoint:
4345

4446
# Core workflow state
4547
messages: dict[str, list[dict[str, Any]]] = field(default_factory=dict) # type: ignore[misc]
46-
shared_state: dict[str, Any] = field(default_factory=dict) # type: ignore[misc]
48+
state: dict[str, Any] = field(default_factory=dict) # type: ignore[misc]
4749
pending_request_info_events: dict[str, dict[str, Any]] = field(default_factory=dict) # type: ignore[misc]
4850

4951
# Runtime state

python/packages/core/agent_framework/_workflows/_checkpoint_summary.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class WorkflowCheckpointSummary:
2525

2626
def get_checkpoint_summary(checkpoint: WorkflowCheckpoint) -> WorkflowCheckpointSummary:
2727
targets = sorted(checkpoint.messages.keys())
28-
executor_ids = sorted(checkpoint.shared_state.get(EXECUTOR_STATE_KEY, {}).keys())
28+
executor_ids = sorted(checkpoint.state.get(EXECUTOR_STATE_KEY, {}).keys())
2929
pending_request_info_events = [
3030
RequestInfoEvent.from_dict(request) for request in checkpoint.pending_request_info_events.values()
3131
]

python/packages/core/agent_framework/_workflows/_const.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
# Default maximum iterations for workflow execution.
44
DEFAULT_MAX_ITERATIONS = 100
55

6-
# Key used to store executor state in shared state.
6+
# Key used to store executor state in state.
77
EXECUTOR_STATE_KEY = "_executor_state"
88

99
# Source identifier for internal workflow messages.
1010
INTERNAL_SOURCE_PREFIX = "internal"
1111

12-
# SharedState key for storing run kwargs that should be passed to agent invocations.
12+
# State key for storing run kwargs that should be passed to agent invocations.
1313
# Used by all orchestration patterns (Sequential, Concurrent, GroupChat, Handoff, Magentic)
1414
# to pass kwargs from workflow.run_stream() through to agent.run_stream() and @tool functions.
1515
WORKFLOW_RUN_KWARGS_KEY = "_workflow_run_kwargs"

python/packages/core/agent_framework/_workflows/_edge_runner.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
from ._executor import Executor
2121
from ._runner_context import Message, RunnerContext
22-
from ._shared_state import SharedState
22+
from ._state import State
2323

2424
logger = logging.getLogger(__name__)
2525

@@ -38,12 +38,12 @@ def __init__(self, edge_group: EdgeGroup, executors: dict[str, Executor]) -> Non
3838
self._executors = executors
3939

4040
@abstractmethod
41-
async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool:
41+
async def send_message(self, message: Message, state: State, ctx: RunnerContext) -> bool:
4242
"""Send a message through the edge group.
4343
4444
Args:
4545
message: The message to send.
46-
shared_state: The shared state to use for holding data.
46+
state: The workflow state.
4747
ctx: The context for the runner.
4848
4949
Returns:
@@ -63,7 +63,7 @@ async def _execute_on_target(
6363
target_id: str,
6464
source_ids: list[str],
6565
message: Message,
66-
shared_state: SharedState,
66+
state: State,
6767
ctx: RunnerContext,
6868
) -> None:
6969
"""Execute a message on a target executor with trace context."""
@@ -76,7 +76,7 @@ async def _execute_on_target(
7676
await target_executor.execute(
7777
message,
7878
source_ids, # source_executor_ids
79-
shared_state, # shared_state
79+
state, # state
8080
ctx, # runner_context
8181
trace_contexts=message.trace_contexts, # Pass trace contexts
8282
source_span_ids=message.source_span_ids, # Pass source span IDs for linking
@@ -90,7 +90,7 @@ def __init__(self, edge_group: SingleEdgeGroup | InternalEdgeGroup, executors: d
9090
super().__init__(edge_group, executors)
9191
self._edge = edge_group.edges[0]
9292

93-
async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool:
93+
async def send_message(self, message: Message, state: State, ctx: RunnerContext) -> bool:
9494
"""Send a message through the single edge."""
9595
should_execute = False
9696
target_id: str | None = None
@@ -144,7 +144,7 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R
144144

145145
# Execute outside the span
146146
if should_execute and target_id and source_id:
147-
await self._execute_on_target(target_id, [source_id], message, shared_state, ctx)
147+
await self._execute_on_target(target_id, [source_id], message, state, ctx)
148148
return True
149149

150150
return False
@@ -162,7 +162,7 @@ def __init__(self, edge_group: FanOutEdgeGroup, executors: dict[str, Executor])
162162
Callable[[Any, list[str]], list[str]] | None, getattr(edge_group, "selection_func", None)
163163
)
164164

165-
async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool:
165+
async def send_message(self, message: Message, state: State, ctx: RunnerContext) -> bool:
166166
"""Send a message through all edges in the fan-out edge group."""
167167
deliverable_edges: list[Edge] = []
168168
single_target_edge: Edge | None = None
@@ -253,14 +253,14 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R
253253
# Execute outside the span
254254
if single_target_edge:
255255
await self._execute_on_target(
256-
single_target_edge.target_id, [single_target_edge.source_id], message, shared_state, ctx
256+
single_target_edge.target_id, [single_target_edge.source_id], message, state, ctx
257257
)
258258
return True
259259

260260
if deliverable_edges:
261261

262262
async def send_to_edge(edge: Edge) -> bool:
263-
await self._execute_on_target(edge.target_id, [edge.source_id], message, shared_state, ctx)
263+
await self._execute_on_target(edge.target_id, [edge.source_id], message, state, ctx)
264264
return True
265265

266266
tasks = [send_to_edge(edge) for edge in deliverable_edges]
@@ -285,7 +285,7 @@ def __init__(self, edge_group: FanInEdgeGroup, executors: dict[str, Executor]) -
285285
# Key is the source executor ID, value is a list of messages
286286
self._buffer: dict[str, list[Message]] = defaultdict(list)
287287

288-
async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool:
288+
async def send_message(self, message: Message, state: State, ctx: RunnerContext) -> bool:
289289
"""Send a message through all edges in the fan-in edge group."""
290290
execution_data: dict[str, Any] | None = None
291291
with create_edge_group_processing_span(
@@ -362,7 +362,7 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R
362362
# Execute outside the span if needed
363363
if execution_data:
364364
await self._execute_on_target(
365-
execution_data["target_id"], execution_data["source_ids"], execution_data["message"], shared_state, ctx
365+
execution_data["target_id"], execution_data["source_ids"], execution_data["message"], state, ctx
366366
)
367367
return True
368368

python/packages/core/agent_framework/_workflows/_executor.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ._model_utils import DictConvertible
2121
from ._request_info_mixin import RequestInfoMixin
2222
from ._runner_context import Message, MessageType, RunnerContext
23-
from ._shared_state import SharedState
23+
from ._state import State
2424
from ._typing_utils import is_instance_of, normalize_type_to_list, resolve_type_annotation
2525
from ._workflow_context import WorkflowContext, validate_workflow_context_annotation
2626

@@ -221,7 +221,7 @@ async def execute(
221221
self,
222222
message: Any,
223223
source_executor_ids: list[str],
224-
shared_state: SharedState,
224+
state: State,
225225
runner_context: RunnerContext,
226226
trace_contexts: list[dict[str, str]] | None = None,
227227
source_span_ids: list[str] | None = None,
@@ -234,7 +234,7 @@ async def execute(
234234
Args:
235235
message: The message to be processed by the executor.
236236
source_executor_ids: The IDs of the source executors that sent messages to this executor.
237-
shared_state: The shared state for the workflow.
237+
state: The state for the workflow.
238238
runner_context: The runner context that provides methods to send messages and events.
239239
trace_contexts: Optional trace contexts from multiple sources for OpenTelemetry propagation.
240240
source_span_ids: Optional source span IDs from multiple sources for linking.
@@ -262,7 +262,7 @@ async def execute(
262262
# Create the appropriate WorkflowContext based on handler specs
263263
context = self._create_context_for_handler(
264264
source_executor_ids=source_executor_ids,
265-
shared_state=shared_state,
265+
state=state,
266266
runner_context=runner_context,
267267
trace_contexts=trace_contexts,
268268
source_span_ids=source_span_ids,
@@ -295,7 +295,7 @@ async def execute(
295295
def _create_context_for_handler(
296296
self,
297297
source_executor_ids: list[str],
298-
shared_state: SharedState,
298+
state: State,
299299
runner_context: RunnerContext,
300300
trace_contexts: list[dict[str, str]] | None = None,
301301
source_span_ids: list[str] | None = None,
@@ -305,7 +305,7 @@ def _create_context_for_handler(
305305
306306
Args:
307307
source_executor_ids: The IDs of the source executors that sent messages to this executor.
308-
shared_state: The shared state for the workflow.
308+
state: The state for the workflow.
309309
runner_context: The runner context that provides methods to send messages and events.
310310
trace_contexts: Optional trace contexts from multiple sources for OpenTelemetry propagation.
311311
source_span_ids: Optional source span IDs from multiple sources for linking.
@@ -318,7 +318,7 @@ def _create_context_for_handler(
318318
return WorkflowContext(
319319
executor=self,
320320
source_executor_ids=source_executor_ids,
321-
shared_state=shared_state,
321+
state=state,
322322
runner_context=runner_context,
323323
trace_contexts=trace_contexts,
324324
source_span_ids=source_span_ids,

python/packages/core/agent_framework/_workflows/_runner.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
Message,
2828
RunnerContext,
2929
)
30-
from ._shared_state import SharedState
30+
from ._state import State
3131

3232
logger = logging.getLogger(__name__)
3333

@@ -39,17 +39,17 @@ def __init__(
3939
self,
4040
edge_groups: Sequence[EdgeGroup],
4141
executors: dict[str, Executor],
42-
shared_state: SharedState,
42+
state: State,
4343
ctx: RunnerContext,
4444
max_iterations: int = 100,
4545
workflow_id: str | None = None,
4646
) -> None:
47-
"""Initialize the runner with edges, shared state, and context.
47+
"""Initialize the runner with edges, state, and context.
4848
4949
Args:
5050
edge_groups: The edge groups of the workflow.
5151
executors: Map of executor IDs to executor instances.
52-
shared_state: The shared state for the workflow.
52+
state: The state for the workflow.
5353
ctx: The runner context for the workflow.
5454
max_iterations: The maximum number of iterations to run.
5555
workflow_id: The workflow ID for checkpointing.
@@ -60,7 +60,7 @@ def __init__(
6060
self._ctx = ctx
6161
self._iteration = 0
6262
self._max_iterations = max_iterations
63-
self._shared_state = shared_state
63+
self._state = state
6464
self._workflow_id = workflow_id
6565
self._running = False
6666
self._resumed_from_checkpoint = False # Track whether we resumed
@@ -141,6 +141,9 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
141141

142142
logger.info(f"Completed superstep {self._iteration}")
143143

144+
# Commit pending state changes at superstep boundary
145+
self._state.commit()
146+
144147
# Create checkpoint after each superstep iteration
145148
await self._create_checkpoint_if_enabled(f"superstep_{self._iteration}")
146149

@@ -164,7 +167,7 @@ async def _deliver_messages(source_executor_id: str, messages: list[Message]) ->
164167

165168
async def _deliver_message_inner(edge_runner: EdgeRunner, message: Message) -> bool:
166169
"""Inner loop to deliver a single message through an edge runner."""
167-
return await edge_runner.send_message(message, self._shared_state, self._ctx)
170+
return await edge_runner.send_message(message, self._state, self._ctx)
168171

169172
def _normalize_message_payload(message: Message) -> None:
170173
data = message.data
@@ -212,7 +215,7 @@ async def _create_checkpoint_if_enabled(self, checkpoint_type: str) -> str | Non
212215
if self.graph_signature_hash:
213216
metadata["graph_signature"] = self.graph_signature_hash
214217
checkpoint_id = await self._ctx.create_checkpoint(
215-
self._shared_state,
218+
self._state,
216219
self._iteration,
217220
metadata=metadata,
218221
)
@@ -271,9 +274,9 @@ async def restore_from_checkpoint(
271274
)
272275

273276
self._workflow_id = checkpoint.workflow_id
274-
# Restore shared state
275-
await self._shared_state.import_state(decode_checkpoint_value(checkpoint.shared_state))
276-
# Restore executor states using the restored shared state
277+
# Restore state
278+
self._state.import_state(decode_checkpoint_value(checkpoint.state))
279+
# Restore executor states using the restored state
277280
await self._restore_executor_states()
278281
# Apply the checkpoint to the context
279282
await self._ctx.apply_checkpoint(checkpoint)
@@ -346,11 +349,11 @@ async def _restore_executor_states(self) -> None:
346349
This method will try the backward compatibility behavior first; if that does not restore state,
347350
it falls back to the updated behavior.
348351
"""
349-
has_executor_states = await self._shared_state.has(EXECUTOR_STATE_KEY)
352+
has_executor_states = self._state.has(EXECUTOR_STATE_KEY)
350353
if not has_executor_states:
351354
return
352355

353-
executor_states = await self._shared_state.get(EXECUTOR_STATE_KEY)
356+
executor_states = self._state.get(EXECUTOR_STATE_KEY)
354357
if not isinstance(executor_states, dict):
355358
raise WorkflowCheckpointException("Executor states in shared state is not a dictionary. Unable to restore.")
356359

@@ -416,19 +419,15 @@ def _mark_resumed(self, iteration: int) -> None:
416419
self._iteration = iteration
417420

418421
async def _set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None:
419-
"""Store executor state in shared state under a reserved key.
422+
"""Store executor state in state under a reserved key.
420423
421424
Executors call this with a JSON-serializable dict capturing the minimal
422425
state needed to resume. It replaces any previously stored state.
423426
"""
424-
has_existing_states = await self._shared_state.has(EXECUTOR_STATE_KEY)
425-
if has_existing_states:
426-
existing_states = await self._shared_state.get(EXECUTOR_STATE_KEY)
427-
else:
428-
existing_states = {}
427+
existing_states = self._state.get(EXECUTOR_STATE_KEY, {})
429428

430429
if not isinstance(existing_states, dict):
431-
raise WorkflowCheckpointException("Existing executor states in shared state is not a dictionary.")
430+
raise WorkflowCheckpointException("Existing executor states in state is not a dictionary.")
432431

433432
existing_states[executor_id] = state
434-
await self._shared_state.set(EXECUTOR_STATE_KEY, existing_states)
433+
self._state.set(EXECUTOR_STATE_KEY, existing_states)

0 commit comments

Comments
 (0)