diff --git a/burr/core/application.py b/burr/core/application.py index 55f98acf6..dc8067c4b 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -2424,6 +2424,7 @@ def initialize_from( fork_from_app_id: str = None, fork_from_partition_key: str = None, fork_from_sequence_id: int = None, + override_state_values: Optional[dict] = None, ) -> "ApplicationBuilder[StateType]": """Initializes the application we will build from some prior state object. @@ -2460,6 +2461,7 @@ def initialize_from( self.fork_from_app_id = fork_from_app_id self.fork_from_partition_key = fork_from_partition_key self.fork_from_sequence_id = fork_from_sequence_id + self.override_state_values = override_state_values return self def with_state_persister( @@ -2614,6 +2616,9 @@ def _init_state_from_persister( # there was something last_position = load_result["position"] self.state = load_result["state"] + if self.override_state_values: + self.state = self.state.update(**self.override_state_values) + self.sequence_id = load_result["sequence_id"] status = load_result["status"] if self.resume_at_next_action: diff --git a/telemetry/ui/src/components/routes/app/GraphView.tsx b/telemetry/ui/src/components/routes/app/GraphView.tsx index ce85b823c..31a5f37e3 100644 --- a/telemetry/ui/src/components/routes/app/GraphView.tsx +++ b/telemetry/ui/src/components/routes/app/GraphView.tsx @@ -20,7 +20,7 @@ import { ActionModel, ApplicationModel, Step } from '../../../api'; import dagre from 'dagre'; -import React, { createContext, useCallback, useLayoutEffect, useRef, useState } from 'react'; +import React, { createContext, useLayoutEffect, useRef, useState } from 'react'; import ReactFlow, { BaseEdge, Controls, @@ -250,7 +250,10 @@ const getLayoutedElements = ( }); }; -const convertApplicationToGraph = (stateMachine: ApplicationModel): [NodeType[], EdgeType[]] => { +const convertApplicationToGraph = ( + stateMachine: ApplicationModel, + showInputs: boolean +): [NodeType[], EdgeType[]] => { const shouldDisplayInput = (input: string) => !input.startsWith('__'); const inputUniqueID = (action: ActionModel, input: string) => `${action.name}:${input}`; // Currently they're distinct by name @@ -285,10 +288,12 @@ const convertApplicationToGraph = (stateMachine: ApplicationModel): [NodeType[], markerEnd: { type: MarkerType.ArrowClosed, width: 20, height: 20 }, data: { from: transition.from_, to: transition.to, condition: transition.condition } })); - return [ - [...allActionNodes, ...allInputNodes], - [...allInputTransitions, ...allTransitionEdges] - ]; + return showInputs + ? [ + [...allActionNodes, ...allInputNodes], + [...allInputTransitions, ...allTransitionEdges] + ] + : [[...allActionNodes], [...allTransitionEdges]]; }; const nodeTypes = { @@ -317,34 +322,25 @@ export const _Graph = (props: { previousActions: Step[] | undefined; hoverAction: Step | undefined; }) => { - const [initialNodes, initialEdges] = React.useMemo(() => { - return convertApplicationToGraph(props.stateMachine); - }, [props.stateMachine]); + const [showInputs, setShowInputs] = useState(true); const [nodes, setNodes] = useState([]); const [edges, setEdges] = useState([]); const { fitView } = useReactFlow(); - const onLayout = useCallback( - ({ direction = 'TB', useInitialNodes = false }): void => { - const opts = { direction }; - const ns = useInitialNodes ? initialNodes : nodes; - const es = useInitialNodes ? initialEdges : edges; + useLayoutEffect(() => { + const [nextNodes, nextEdges] = convertApplicationToGraph(props.stateMachine, showInputs); - getLayoutedElements(ns, es, opts).then(({ nodes: layoutedNodes, edges: layoutedEdges }) => { + getLayoutedElements(nextNodes, nextEdges, { direction: 'TB' }).then( + ({ nodes: layoutedNodes, edges: layoutedEdges }) => { setNodes(layoutedNodes); setEdges(layoutedEdges); window.requestAnimationFrame(() => fitView()); - }); - }, - [nodes, edges] - ); - - useLayoutEffect(() => { - onLayout({ direction: 'TB', useInitialNodes: true }); - }, []); + } + ); + }, [showInputs, props.stateMachine, fitView]); return (
+ + State: assert sorted(halt_after) == ["test_action", "test_action_2"] assert halt_before == ["test_action"] assert inputs == {} + + +def test_initialize_from_applies_override_state_values(): + class FakeStateLoader(BaseStateLoader): + def load(self, partition_key, app_id, sequence_id): + return { + "state": State({"x": 1}), + "position": None, + "sequence_id": 0, + "status": "completed", + } + + def list_app_ids(self, partition_key): + return [] + + @action(reads=[], writes=[]) + def noop(state: State) -> State: + return state + + builder = ( + ApplicationBuilder() + .initialize_from( + initializer=FakeStateLoader(), + resume_at_next_action=False, + default_state={}, + default_entrypoint="noop", + override_state_values={"x": 100}, + ) + .with_actions(noop) + .with_transitions() + ) + + app = builder.build() + + assert app.state["x"] == 100