1616from abc import ABC , abstractmethod
1717from concurrent .futures import Future
1818from typing import Any , Iterator , List , Optional , Tuple , Union , overload
19+ from uuid import UUID
1920
2021from zenml .logger import get_logger
21- from zenml .models import ArtifactVersionResponse
22+ from zenml .models import ArtifactVersionResponse , StepRunResponse
2223
2324logger = get_logger (__name__ )
2425
@@ -89,24 +90,118 @@ def result(self) -> Any:
8990 """
9091
9192
92- class BaseStepRunFuture (BaseFuture ):
93- """Base step run future."""
93+ class _InlineStepFuture (BaseFuture ):
94+ """Future for an inline step run."""
95+
96+ def __init__ (
97+ self , wrapped : Future ["StepRunResponse" ], invocation_id : str
98+ ) -> None :
99+ """Initialize the inline step run future.
100+
101+ Args:
102+ wrapped: The wrapped future object.
103+ invocation_id: The invocation ID of the step run.
104+ """
105+ self ._wrapped = wrapped
106+ self .invocation_id = invocation_id
107+
108+ def running (self ) -> bool :
109+ """Check if the step run future is running.
110+
111+ Returns:
112+ True if the step run future is running, False otherwise.
113+ """
114+ return self ._wrapped .running ()
115+
116+ def result (self ) -> "StepRunResponse" :
117+ """Get the result of the step run future.
118+
119+ Returns:
120+ The result of the step run future.
121+ """
122+ return self ._wrapped .result ()
123+
124+
125+ class _IsolatedStepFuture (BaseFuture ):
126+ """Future for an inline step run."""
94127
95128 def __init__ (
96129 self ,
97- wrapped : Future [ StepRunOutputs ] ,
130+ pipeline_run_id : UUID ,
98131 invocation_id : str ,
132+ wrapped : Optional [Future ["StepRunResponse" ]] = None ,
133+ ) -> None :
134+ """Initialize the step run future.
135+
136+ Args:
137+ pipeline_run_id: The ID of the pipeline run.
138+ invocation_id: The invocation ID of the step run.
139+ wrapped: Optional future to wait for that submits the step run.
140+ """
141+ self ._wrapped = wrapped
142+ self .pipeline_run_id = pipeline_run_id
143+ self .invocation_id = invocation_id
144+
145+ def running (self ) -> bool :
146+ """Check if the isolated step future is running.
147+
148+ Returns:
149+ True if the isolated step future is running, False otherwise.
150+ """
151+ from zenml .execution .pipeline .dynamic .utils import get_latest_step_run
152+
153+ if self ._wrapped and self ._wrapped .running ():
154+ return True
155+
156+ step_run = get_latest_step_run (
157+ self .pipeline_run_id , self .invocation_id , hydrate = False
158+ )
159+
160+ return not step_run .status .is_finished
161+
162+ def result (self ) -> "StepRunResponse" :
163+ """Get the result of the step future.
164+
165+ Raises:
166+ RuntimeError: If the step failed.
167+
168+ Returns:
169+ The result of the step future.
170+ """
171+ from zenml .execution .pipeline .dynamic .utils import (
172+ wait_for_step_to_finish ,
173+ )
174+
175+ if self ._wrapped :
176+ # We first wait until the step run is submitted and only then
177+ # start monitoring the actual step.
178+ self ._wrapped .result ()
179+
180+ step_run = wait_for_step_to_finish (
181+ pipeline_run_id = self .pipeline_run_id , step_name = self .invocation_id
182+ )
183+
184+ if step_run .status .is_failed :
185+ raise RuntimeError (f"Step `{ self .invocation_id } ` failed." )
186+
187+ return step_run
188+
189+
190+ class BaseStepFuture (BaseFuture ):
191+ """Base step future."""
192+
193+ def __init__ (
194+ self ,
195+ wrapped : Union [_InlineStepFuture , _IsolatedStepFuture ],
99196 ** kwargs : Any ,
100197 ) -> None :
101198 """Initialize the dynamic step run future.
102199
103200 Args:
104201 wrapped: The wrapped future object.
105- invocation_id: The invocation ID of the step run.
106202 **kwargs: Additional keyword arguments.
107203 """
108204 self ._wrapped = wrapped
109- self ._invocation_id = invocation_id
110205
111206 @property
112207 def invocation_id (self ) -> str :
@@ -115,7 +210,7 @@ def invocation_id(self) -> str:
115210 Returns:
116211 The step run invocation ID.
117212 """
118- return self ._invocation_id
213+ return self ._wrapped . invocation_id
119214
120215 def running (self ) -> bool :
121216 """Check if the step run future is running.
@@ -126,20 +221,21 @@ def running(self) -> bool:
126221 return self ._wrapped .running ()
127222
128223
129- class ArtifactFuture (BaseStepRunFuture ):
224+ class ArtifactFuture (BaseStepFuture ):
130225 """Future for a step run output artifact."""
131226
132227 def __init__ (
133- self , wrapped : Future [StepRunOutputs ], invocation_id : str , index : int
228+ self ,
229+ wrapped : Union [_InlineStepFuture , _IsolatedStepFuture ],
230+ index : int ,
134231 ) -> None :
135232 """Initialize the future.
136233
137234 Args:
138235 wrapped: The wrapped future object.
139- invocation_id: The invocation ID of the step run.
140236 index: The index of the output artifact.
141237 """
142- super ().__init__ (wrapped = wrapped , invocation_id = invocation_id )
238+ super ().__init__ (wrapped = wrapped )
143239 self ._index = index
144240
145241 def result (self ) -> OutputArtifact :
@@ -151,14 +247,20 @@ def result(self) -> OutputArtifact:
151247 Returns:
152248 The output artifact.
153249 """
154- result = self ._wrapped .result ()
250+ step_run = self ._wrapped .result ()
251+ from zenml .execution .pipeline .dynamic .utils import (
252+ load_step_run_outputs ,
253+ )
254+
255+ result = load_step_run_outputs (step_run .id )
256+
155257 if isinstance (result , OutputArtifact ):
156258 return result
157259 elif isinstance (result , tuple ):
158260 return result [self ._index ]
159261 else :
160262 raise RuntimeError (
161- f"Step { self ._invocation_id } returned an invalid output: "
263+ f"Step { self .invocation_id } returned an invalid output: "
162264 f"{ result } ."
163265 )
164266
@@ -188,23 +290,21 @@ def chunk(self, index: int) -> "OutputArtifact":
188290 return self .result ().chunk (index = index )
189291
190292
191- class StepRunOutputsFuture ( BaseStepRunFuture ):
293+ class StepFuture ( BaseStepFuture ):
192294 """Future for a step run output."""
193295
194296 def __init__ (
195297 self ,
196- wrapped : Future [StepRunOutputs ],
197- invocation_id : str ,
298+ wrapped : Union [_InlineStepFuture , _IsolatedStepFuture ],
198299 output_keys : List [str ],
199300 ) -> None :
200301 """Initialize the future.
201302
202303 Args:
203304 wrapped: The wrapped future object.
204- invocation_id: The invocation ID of the step run.
205305 output_keys: The output keys of the step run.
206306 """
207- super ().__init__ (wrapped = wrapped , invocation_id = invocation_id )
307+ super ().__init__ (wrapped = wrapped )
208308 self ._output_keys = output_keys
209309
210310 def get_artifact (self , key : str ) -> ArtifactFuture :
@@ -221,31 +321,39 @@ def get_artifact(self, key: str) -> ArtifactFuture:
221321 """
222322 if key not in self ._output_keys :
223323 raise KeyError (
224- f"Step run { self ._invocation_id } does not have an output with "
324+ f"Step run { self .invocation_id } does not have an output with "
225325 f"the name: { key } ."
226326 )
227327
228328 return ArtifactFuture (
229329 wrapped = self ._wrapped ,
230- invocation_id = self ._invocation_id ,
231330 index = self ._output_keys .index (key ),
232331 )
233332
333+ def wait (self ) -> None :
334+ """Wait for the step to finish."""
335+ self ._wrapped .result ()
336+
234337 def artifacts (self ) -> StepRunOutputs :
235338 """Get the step run output artifacts.
236339
237340 Returns:
238341 The step run output artifacts.
239342 """
240- return self ._wrapped . result ()
343+ return self .result ()
241344
242345 def result (self ) -> StepRunOutputs :
243346 """Get the step run outputs this future represents.
244347
245348 Returns:
246349 The step run outputs.
247350 """
248- return self ._wrapped .result ()
351+ from zenml .execution .pipeline .dynamic .utils import (
352+ load_step_run_outputs ,
353+ )
354+
355+ step_run = self ._wrapped .result ()
356+ return load_step_run_outputs (step_run .id )
249357
250358 def load (self , disable_cache : bool = False ) -> Any :
251359 """Get the step run output artifact data.
@@ -297,15 +405,13 @@ def __getitem__(
297405
298406 return ArtifactFuture (
299407 wrapped = self ._wrapped ,
300- invocation_id = self ._invocation_id ,
301408 index = self ._output_keys .index (output_key ),
302409 )
303410 elif isinstance (key , slice ):
304411 output_keys = self ._output_keys [key ]
305412 return tuple (
306413 ArtifactFuture (
307414 wrapped = self ._wrapped ,
308- invocation_id = self ._invocation_id ,
309415 index = self ._output_keys .index (output_key ),
310416 )
311417 for output_key in output_keys
@@ -324,13 +430,12 @@ def __iter__(self) -> Any:
324430 """
325431 if not self ._output_keys :
326432 raise ValueError (
327- f"Step { self ._invocation_id } does not return any outputs."
433+ f"Step { self .invocation_id } does not return any outputs."
328434 )
329435
330436 for index in range (len (self ._output_keys )):
331437 yield ArtifactFuture (
332438 wrapped = self ._wrapped ,
333- invocation_id = self ._invocation_id ,
334439 index = index ,
335440 )
336441
@@ -346,7 +451,7 @@ def __len__(self) -> int:
346451class MapResultsFuture (BaseFuture ):
347452 """Future that represents the results of a `step.map/product(...)` call."""
348453
349- def __init__ (self , futures : List [StepRunOutputsFuture ]) -> None :
454+ def __init__ (self , futures : List [StepFuture ]) -> None :
350455 """Initialize the map results future.
351456
352457 Args:
@@ -417,14 +522,14 @@ def map_pipeline():
417522 return tuple (map (list , zip (* self .futures )))
418523
419524 @overload
420- def __getitem__ (self , key : int ) -> StepRunOutputsFuture : ...
525+ def __getitem__ (self , key : int ) -> StepFuture : ...
421526
422527 @overload
423- def __getitem__ (self , key : slice ) -> List [StepRunOutputsFuture ]: ...
528+ def __getitem__ (self , key : slice ) -> List [StepFuture ]: ...
424529
425530 def __getitem__ (
426531 self , key : Union [int , slice ]
427- ) -> Union [StepRunOutputsFuture , List [StepRunOutputsFuture ]]:
532+ ) -> Union [StepFuture , List [StepFuture ]]:
428533 """Get a step run future.
429534
430535 Args:
@@ -435,7 +540,7 @@ def __getitem__(
435540 """
436541 return self .futures [key ]
437542
438- def __iter__ (self ) -> Iterator [StepRunOutputsFuture ]:
543+ def __iter__ (self ) -> Iterator [StepFuture ]:
439544 """Iterate over the step run futures.
440545
441546 Yields:
@@ -452,6 +557,4 @@ def __len__(self) -> int:
452557 return len (self .futures )
453558
454559
455- AnyStepRunFuture = Union [
456- ArtifactFuture , StepRunOutputsFuture , MapResultsFuture
457- ]
560+ AnyStepFuture = Union [ArtifactFuture , StepFuture , MapResultsFuture ]
0 commit comments