Skip to content

Commit 6dc63a3

Browse files
committed
Improved dynamic pipeline step monitoring
1 parent 09ca399 commit 6dc63a3

File tree

20 files changed

+1327
-407
lines changed

20 files changed

+1327
-407
lines changed

src/zenml/config/step_run_info.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from zenml.config.pipeline_configurations import PipelineConfiguration
2121
from zenml.config.step_configurations import StepConfiguration, StepSpec
2222
from zenml.logger import get_logger
23-
from zenml.models import PipelineSnapshotResponse
23+
from zenml.models import PipelineSnapshotResponse, StepRunResponse
2424

2525
logger = get_logger(__name__)
2626

@@ -37,6 +37,7 @@ class StepRunInfo(FrozenBaseModel):
3737
spec: StepSpec
3838
pipeline: PipelineConfiguration
3939
snapshot: PipelineSnapshotResponse
40+
step_run: StepRunResponse
4041

4142
force_write_logs: Callable[..., Any]
4243

src/zenml/constants.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,23 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
128128
return default
129129

130130

131+
def handle_float_env_var(var: str, default: float = 0.0) -> float:
132+
"""Converts normal env var to float.
133+
134+
Args:
135+
var: The environment variable to convert.
136+
default: The value to return if the env var is unset or not a float.
137+
138+
Returns:
139+
The converted value.
140+
"""
141+
value = os.getenv(var, "")
142+
try:
143+
return float(value)
144+
except (ValueError, TypeError):
145+
return default
146+
147+
131148
# Global constants
132149
APP_NAME = "zenml"
133150

@@ -220,6 +237,14 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
220237
ENV_ZENML_DEFAULT_OUTPUT = "ZENML_DEFAULT_OUTPUT"
221238
ENV_ZENML_CLI_COLUMN_WIDTH = "ZENML_CLI_COLUMN_WIDTH"
222239

240+
ENV_ZENML_DYNAMIC_PIPELINE_WORKER_COUNT = "ZENML_DYNAMIC_PIPELINE_WORKER_COUNT"
241+
ENV_ZENML_DYNAMIC_PIPELINE_MONITORING_INTERVAL = (
242+
"ZENML_DYNAMIC_PIPELINE_MONITORING_INTERVAL"
243+
)
244+
ENV_ZENML_DYNAMIC_PIPELINE_MONITORING_DELAY = (
245+
"ZENML_DYNAMIC_PIPELINE_MONITORING_DELAY"
246+
)
247+
223248
# Logging variables
224249
IS_DEBUG_ENV: bool = handle_bool_env_var(ENV_ZENML_DEBUG, default=False)
225250

src/zenml/execution/pipeline/dynamic/outputs.py

Lines changed: 137 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
from abc import ABC, abstractmethod
1717
from concurrent.futures import Future
1818
from typing import Any, Iterator, List, Optional, Tuple, Union, overload
19+
from uuid import UUID
1920

2021
from zenml.logger import get_logger
21-
from zenml.models import ArtifactVersionResponse
22+
from zenml.models import ArtifactVersionResponse, StepRunResponse
2223

2324
logger = get_logger(__name__)
2425

@@ -89,24 +90,118 @@ def result(self) -> Any:
8990
"""
9091

9192

92-
class BaseStepRunFuture(BaseFuture):
93-
"""Base step run future."""
93+
class _InlineStepFuture(BaseFuture):
94+
"""Future for an inline step run."""
95+
96+
def __init__(
97+
self, wrapped: Future["StepRunResponse"], invocation_id: str
98+
) -> None:
99+
"""Initialize the inline step run future.
100+
101+
Args:
102+
wrapped: The wrapped future object.
103+
invocation_id: The invocation ID of the step run.
104+
"""
105+
self._wrapped = wrapped
106+
self.invocation_id = invocation_id
107+
108+
def running(self) -> bool:
109+
"""Check if the step run future is running.
110+
111+
Returns:
112+
True if the step run future is running, False otherwise.
113+
"""
114+
return self._wrapped.running()
115+
116+
def result(self) -> "StepRunResponse":
117+
"""Get the result of the step run future.
118+
119+
Returns:
120+
The result of the step run future.
121+
"""
122+
return self._wrapped.result()
123+
124+
125+
class _IsolatedStepFuture(BaseFuture):
125+
"""Future for an isolated step run."""
94127

95128
def __init__(
96129
self,
97-
wrapped: Future[StepRunOutputs],
130+
pipeline_run_id: UUID,
98131
invocation_id: str,
132+
wrapped: Optional[Future["StepRunResponse"]] = None,
133+
) -> None:
134+
"""Initialize the step run future.
135+
136+
Args:
137+
pipeline_run_id: The ID of the pipeline run.
138+
invocation_id: The invocation ID of the step run.
139+
wrapped: Optional future to wait for that submits the step run.
140+
"""
141+
self._wrapped = wrapped
142+
self.pipeline_run_id = pipeline_run_id
143+
self.invocation_id = invocation_id
144+
145+
def running(self) -> bool:
146+
"""Check if the isolated step future is running.
147+
148+
Returns:
149+
True if the isolated step future is running, False otherwise.
150+
"""
151+
from zenml.execution.pipeline.dynamic.utils import get_latest_step_run
152+
153+
if self._wrapped and self._wrapped.running():
154+
return True
155+
156+
step_run = get_latest_step_run(
157+
self.pipeline_run_id, self.invocation_id, hydrate=False
158+
)
159+
160+
return not step_run.status.is_finished
161+
162+
def result(self) -> "StepRunResponse":
163+
"""Get the result of the step future.
164+
165+
Raises:
166+
RuntimeError: If the step failed.
167+
168+
Returns:
169+
The result of the step future.
170+
"""
171+
from zenml.execution.pipeline.dynamic.utils import (
172+
wait_for_step_to_finish,
173+
)
174+
175+
if self._wrapped:
176+
# We first wait until the step run is submitted and only then
177+
# start monitoring the actual step.
178+
self._wrapped.result()
179+
180+
step_run = wait_for_step_to_finish(
181+
pipeline_run_id=self.pipeline_run_id, step_name=self.invocation_id
182+
)
183+
184+
if step_run.status.is_failed:
185+
raise RuntimeError(f"Step `{self.invocation_id}` failed.")
186+
187+
return step_run
188+
189+
190+
class BaseStepFuture(BaseFuture):
191+
"""Base step future."""
192+
193+
def __init__(
194+
self,
195+
wrapped: Union[_InlineStepFuture, _IsolatedStepFuture],
99196
**kwargs: Any,
100197
) -> None:
101198
"""Initialize the dynamic step run future.
102199
103200
Args:
104201
wrapped: The wrapped future object.
105-
invocation_id: The invocation ID of the step run.
106202
**kwargs: Additional keyword arguments.
107203
"""
108204
self._wrapped = wrapped
109-
self._invocation_id = invocation_id
110205

111206
@property
112207
def invocation_id(self) -> str:
@@ -115,7 +210,7 @@ def invocation_id(self) -> str:
115210
Returns:
116211
The step run invocation ID.
117212
"""
118-
return self._invocation_id
213+
return self._wrapped.invocation_id
119214

120215
def running(self) -> bool:
121216
"""Check if the step run future is running.
@@ -126,20 +221,21 @@ def running(self) -> bool:
126221
return self._wrapped.running()
127222

128223

129-
class ArtifactFuture(BaseStepRunFuture):
224+
class ArtifactFuture(BaseStepFuture):
130225
"""Future for a step run output artifact."""
131226

132227
def __init__(
133-
self, wrapped: Future[StepRunOutputs], invocation_id: str, index: int
228+
self,
229+
wrapped: Union[_InlineStepFuture, _IsolatedStepFuture],
230+
index: int,
134231
) -> None:
135232
"""Initialize the future.
136233
137234
Args:
138235
wrapped: The wrapped future object.
139-
invocation_id: The invocation ID of the step run.
140236
index: The index of the output artifact.
141237
"""
142-
super().__init__(wrapped=wrapped, invocation_id=invocation_id)
238+
super().__init__(wrapped=wrapped)
143239
self._index = index
144240

145241
def result(self) -> OutputArtifact:
@@ -151,14 +247,20 @@ def result(self) -> OutputArtifact:
151247
Returns:
152248
The output artifact.
153249
"""
154-
result = self._wrapped.result()
250+
step_run = self._wrapped.result()
251+
from zenml.execution.pipeline.dynamic.utils import (
252+
load_step_run_outputs,
253+
)
254+
255+
result = load_step_run_outputs(step_run.id)
256+
155257
if isinstance(result, OutputArtifact):
156258
return result
157259
elif isinstance(result, tuple):
158260
return result[self._index]
159261
else:
160262
raise RuntimeError(
161-
f"Step {self._invocation_id} returned an invalid output: "
263+
f"Step {self.invocation_id} returned an invalid output: "
162264
f"{result}."
163265
)
164266

@@ -188,23 +290,21 @@ def chunk(self, index: int) -> "OutputArtifact":
188290
return self.result().chunk(index=index)
189291

190292

191-
class StepRunOutputsFuture(BaseStepRunFuture):
293+
class StepFuture(BaseStepFuture):
192294
"""Future for a step run output."""
193295

194296
def __init__(
195297
self,
196-
wrapped: Future[StepRunOutputs],
197-
invocation_id: str,
298+
wrapped: Union[_InlineStepFuture, _IsolatedStepFuture],
198299
output_keys: List[str],
199300
) -> None:
200301
"""Initialize the future.
201302
202303
Args:
203304
wrapped: The wrapped future object.
204-
invocation_id: The invocation ID of the step run.
205305
output_keys: The output keys of the step run.
206306
"""
207-
super().__init__(wrapped=wrapped, invocation_id=invocation_id)
307+
super().__init__(wrapped=wrapped)
208308
self._output_keys = output_keys
209309

210310
def get_artifact(self, key: str) -> ArtifactFuture:
@@ -221,31 +321,39 @@ def get_artifact(self, key: str) -> ArtifactFuture:
221321
"""
222322
if key not in self._output_keys:
223323
raise KeyError(
224-
f"Step run {self._invocation_id} does not have an output with "
324+
f"Step run {self.invocation_id} does not have an output with "
225325
f"the name: {key}."
226326
)
227327

228328
return ArtifactFuture(
229329
wrapped=self._wrapped,
230-
invocation_id=self._invocation_id,
231330
index=self._output_keys.index(key),
232331
)
233332

333+
def wait(self) -> None:
334+
"""Wait for the step to finish."""
335+
self._wrapped.result()
336+
234337
def artifacts(self) -> StepRunOutputs:
235338
"""Get the step run output artifacts.
236339
237340
Returns:
238341
The step run output artifacts.
239342
"""
240-
return self._wrapped.result()
343+
return self.result()
241344

242345
def result(self) -> StepRunOutputs:
243346
"""Get the step run outputs this future represents.
244347
245348
Returns:
246349
The step run outputs.
247350
"""
248-
return self._wrapped.result()
351+
from zenml.execution.pipeline.dynamic.utils import (
352+
load_step_run_outputs,
353+
)
354+
355+
step_run = self._wrapped.result()
356+
return load_step_run_outputs(step_run.id)
249357

250358
def load(self, disable_cache: bool = False) -> Any:
251359
"""Get the step run output artifact data.
@@ -297,15 +405,13 @@ def __getitem__(
297405

298406
return ArtifactFuture(
299407
wrapped=self._wrapped,
300-
invocation_id=self._invocation_id,
301408
index=self._output_keys.index(output_key),
302409
)
303410
elif isinstance(key, slice):
304411
output_keys = self._output_keys[key]
305412
return tuple(
306413
ArtifactFuture(
307414
wrapped=self._wrapped,
308-
invocation_id=self._invocation_id,
309415
index=self._output_keys.index(output_key),
310416
)
311417
for output_key in output_keys
@@ -324,13 +430,12 @@ def __iter__(self) -> Any:
324430
"""
325431
if not self._output_keys:
326432
raise ValueError(
327-
f"Step {self._invocation_id} does not return any outputs."
433+
f"Step {self.invocation_id} does not return any outputs."
328434
)
329435

330436
for index in range(len(self._output_keys)):
331437
yield ArtifactFuture(
332438
wrapped=self._wrapped,
333-
invocation_id=self._invocation_id,
334439
index=index,
335440
)
336441

@@ -346,7 +451,7 @@ def __len__(self) -> int:
346451
class MapResultsFuture(BaseFuture):
347452
"""Future that represents the results of a `step.map/product(...)` call."""
348453

349-
def __init__(self, futures: List[StepRunOutputsFuture]) -> None:
454+
def __init__(self, futures: List[StepFuture]) -> None:
350455
"""Initialize the map results future.
351456
352457
Args:
@@ -417,14 +522,14 @@ def map_pipeline():
417522
return tuple(map(list, zip(*self.futures)))
418523

419524
@overload
420-
def __getitem__(self, key: int) -> StepRunOutputsFuture: ...
525+
def __getitem__(self, key: int) -> StepFuture: ...
421526

422527
@overload
423-
def __getitem__(self, key: slice) -> List[StepRunOutputsFuture]: ...
528+
def __getitem__(self, key: slice) -> List[StepFuture]: ...
424529

425530
def __getitem__(
426531
self, key: Union[int, slice]
427-
) -> Union[StepRunOutputsFuture, List[StepRunOutputsFuture]]:
532+
) -> Union[StepFuture, List[StepFuture]]:
428533
"""Get a step run future.
429534
430535
Args:
@@ -435,7 +540,7 @@ def __getitem__(
435540
"""
436541
return self.futures[key]
437542

438-
def __iter__(self) -> Iterator[StepRunOutputsFuture]:
543+
def __iter__(self) -> Iterator[StepFuture]:
439544
"""Iterate over the step run futures.
440545
441546
Yields:
@@ -452,6 +557,4 @@ def __len__(self) -> int:
452557
return len(self.futures)
453558

454559

455-
AnyStepRunFuture = Union[
456-
ArtifactFuture, StepRunOutputsFuture, MapResultsFuture
457-
]
560+
AnyStepFuture = Union[ArtifactFuture, StepFuture, MapResultsFuture]

0 commit comments

Comments
 (0)