
Commit 0a53b4a

lyglst (Orbax Authors) authored and committed

Add flags for PyTorch & DCP support in the Orbax checkpoint benchmark launcher.

PiperOrigin-RevId: 865621844

1 parent 1f4924e · commit 0a53b4a

8 files changed: +588 −20 lines


checkpoint/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ trigger creation if it has not already been started.
 - Add new `OrbaxV0Layout` that will handle specific v0 checkpoint format logic.
 - Add sharding fallback for target tree leaves in `StandardCheckpointHandler`
   restore, removing sharding/topology warnings.
+- Add PyTorch DCP (Distributed Checkpoint) to the benchmark suite.
 
 ## [0.11.31] - 2025-12-11
 

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/checkpoint_generation.py

Lines changed: 15 additions & 0 deletions
@@ -122,12 +122,27 @@ def _partition_axis_name(offset: int) -> str:
 
 
 
+def is_safetensor_checkpoint(path: str | epath.Path) -> bool:
+  """Checks if the checkpoint is a SafeTensor checkpoint."""
+  path = epath.Path(path)
+  for f in path.iterdir():
+    if f.is_file() and 'safetensors' in f.name:
+      return True
+  return False
+
+
 def load_checkpoint(path: str) -> Any:
   """Loads a PyTree of test checkpoint from a provided path."""
   logging.info('Loading checkpoint from path: %s', path)
   path = epath.Path(path)
 
 
+  # If the checkpoint is a SafeTensor checkpoint, return the path directly.
+  # This is because we don't need to load the checkpoint into a PyTree, and
+  # can directly use the path to load the checkpoint in the benchmark test.
+  if is_safetensor_checkpoint(path):
+    return path
+
   use_ocdbt = type_handlers.is_ocdbt_checkpoint(path)
   with checkpointer.Checkpointer(
       pytree_checkpoint_handler.PyTreeCheckpointHandler(use_ocdbt=use_ocdbt)
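
The new `is_safetensor_checkpoint` helper is a filename heuristic: any regular file with `safetensors` in its name marks the directory as a SafeTensors checkpoint. A minimal standalone sketch of the same check, using `pathlib` in place of `etils.epath`; the temporary directory and `model.safetensors` file name are hypothetical:

```python
# Standalone sketch of the detection heuristic above. pathlib stands in for
# etils.epath; the temp directory and file name are hypothetical.
import pathlib
import tempfile


def is_safetensor_checkpoint(path: pathlib.Path) -> bool:
  # A directory counts as a SafeTensors checkpoint if any regular file in it
  # has 'safetensors' in its name.
  return any(f.is_file() and 'safetensors' in f.name for f in path.iterdir())


with tempfile.TemporaryDirectory() as tmp:
  ckpt_dir = pathlib.Path(tmp)
  (ckpt_dir / 'model.safetensors').touch()
  print(is_safetensor_checkpoint(ckpt_dir))  # -> True
```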

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py

Lines changed: 35 additions & 11 deletions
@@ -20,6 +20,7 @@
 import hashlib
 import itertools
 import sys
+import threading
 from typing import Any, Callable
 
 from absl import logging
@@ -33,6 +34,29 @@
 from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib
 
 
+def _sync_global_processes(
+    name: str,
+):
+  """Syncs global processes using torch.distributed if available, else multihost."""
+  try:
+    import torch.distributed as dist  # pylint: disable=g-import-not-at-top
+
+    if dist.is_initialized():
+      logging.vlog(
+          1,
+          "[process=%s][thread=%s] sync_global_processes with torch"
+          " barrier: %s",
+          dist.get_rank(),
+          threading.current_thread().name,
+          name,
+      )
+      dist.barrier()
+      return
+  except ImportError:
+    pass
+  multihost.sync_global_processes(name)
+
+
 @dataclasses.dataclass(frozen=True)
 class BenchmarkOptions:
   """Base class for benchmark generator options."""
@@ -148,13 +172,13 @@ def run(self, repeat_index: int | None = None) -> TestResult:
       name += f"_repeat_{repeat_index}"
     logging.info(
         "[process_id=%s] Setting up test: %s",
-        multihost.process_index(),
+        metric_lib.get_process_index(),
         name,
     )
 
     benchmark_metrics = metric_lib.Metrics(name=f"{name} Internal")
     with benchmark_metrics.measure("sync_global_processes:benchmark:run"):
-      multihost.sync_global_processes("benchmark:run")
+      _sync_global_processes("benchmark:run")
 
     path = directory_setup.setup_test_directory(
         self.name, self.output_dir, repeat_index
@@ -163,7 +187,7 @@ def run(self, repeat_index: int | None = None) -> TestResult:
     with benchmark_metrics.measure(
         "sync_global_processes:benchmark:setup_test_directory"
     ):
-      multihost.sync_global_processes("benchmark:setup_test_directory")
+      _sync_global_processes("benchmark:setup_test_directory")
 
     if self.checkpoint_config.path is None:
       data = checkpoint_generation.generate_checkpoint(
@@ -175,7 +199,7 @@ def run(self, repeat_index: int | None = None) -> TestResult:
     with benchmark_metrics.measure(
         "sync_global_processes:benchmark:setup_pytree"
     ):
-      multihost.sync_global_processes("benchmark:setup_pytree")
+      _sync_global_processes("benchmark:setup_pytree")
 
     context = TestContext(
         pytree=data,
@@ -191,7 +215,7 @@ def run(self, repeat_index: int | None = None) -> TestResult:
 
     logging.info(
         "[process_id=%s] Executing test function: %s",
-        multihost.process_index(),
+        metric_lib.get_process_index(),
         name,
     )
     try:
@@ -201,13 +225,13 @@ def run(self, repeat_index: int | None = None) -> TestResult:
       # execution is recorded in the TestResult.
       if sys.version_info >= (3, 11):
         e.add_note(
-            f"[process_id={multihost.process_index()}],"
+            f"[process_id={metric_lib.get_process_index()}],"
            f" {test_context_summary[:100]}"
         )
       logging.error(
           "[process_id=%s] Test function '%s' context: %s, raised an"
           " exception: %s",
-          multihost.process_index(),
+          metric_lib.get_process_index(),
           name,
           test_context_summary[:100],
           e,
@@ -221,7 +245,7 @@ def run(self, repeat_index: int | None = None) -> TestResult:
 
     logging.info(
         "[process_id=%s] Test finished: %s",
-        multihost.process_index(),
+        metric_lib.get_process_index(),
         name,
     )
 
@@ -304,13 +328,13 @@ def _get_options_product(self) -> Sequence[BenchmarkOptions]:
         option_instances.append(option_instance)
         logging.info(
             "[process_id=%s] Generating valid option combination: %s",
-            multihost.process_index(),
+            metric_lib.get_process_index(),
             option_instance,
         )
       else:
         logging.info(
             "[process_id=%s] Skipping invalid option combination: %s",
-            multihost.process_index(),
+            metric_lib.get_process_index(),
             option_instance,
         )
     return option_instances
@@ -458,5 +482,5 @@ def run(self) -> Sequence[TestResult]:
     )
 
     logging.info(self._suite_metrics.generate_report())
-    multihost.sync_global_processes("test_suite:run_end")
+    _sync_global_processes("test_suite:run_end")
     return all_results
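
`_sync_global_processes` makes PyTorch a soft dependency: the torch barrier is used only when `torch.distributed` is importable and initialized, and everything else falls back to the existing multihost sync. A hedged sketch of the same pattern, with `jax.experimental.multihost_utils.sync_global_devices` standing in for Orbax's internal `multihost.sync_global_processes`:

```python
# Sketch of the soft-dependency barrier pattern used above. The JAX fallback
# (sync_global_devices) stands in for Orbax's internal multihost helper.
def sync_global(name: str) -> None:
  try:
    import torch.distributed as dist  # torch may not be installed

    if dist.is_initialized():
      dist.barrier()  # block until every rank in the default group arrives
      return
  except ImportError:
    pass  # no torch: fall through to the JAX barrier
  from jax.experimental import multihost_utils

  multihost_utils.sync_global_devices(name)
```

Checking `dist.is_initialized()` rather than mere importability matters: torch can be installed without a process group, and `dist.barrier()` raises if called before `init_process_group`.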

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/directory_setup.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@
 
 from absl import logging
 from etils import epath
-import jax
+from orbax.checkpoint._src.testing.benchmarks.core import metric
 
 
 def setup_test_directory(
@@ -39,7 +39,7 @@ def setup_test_directory(
   if repeat_index is not None:
     path = path / f"repeat_{repeat_index}"
   logging.info("Setting up test directory at: %s", path)
-  if jax.process_index() == 0:
+  if metric.get_process_index() == 0:
     if path.exists():
       logging.warning("Test directory %s already exists. Deleting it.", path)
       path.rmtree()
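
With `metric.get_process_index`, the destructive setup (delete and recreate) stays on the primary process under either backend. Callers still need a barrier afterwards so other processes don't touch the path before rank 0 has finished; `core.py` above does exactly that with `_sync_global_processes("benchmark:setup_test_directory")`. A compact illustration of the pairing, where `setup_dir` and the `barrier` argument are hypothetical stand-ins:

```python
import pathlib


# Illustrative pairing of rank-0-only directory setup with a global barrier,
# mirroring the flow in core.py. setup_dir and barrier are hypothetical.
def setup_dir(path: pathlib.Path, process_index: int, barrier) -> None:
  if process_index == 0:  # only the primary host mutates the filesystem
    if path.exists():
      ...  # delete stale state from a previous run
    path.mkdir(parents=True, exist_ok=True)
  barrier('setup_test_directory')  # all ranks wait for rank 0 to finish
```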

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/metric.py

Lines changed: 19 additions & 7 deletions
@@ -35,6 +35,18 @@
 import tensorstore as ts
 
 
+def get_process_index():
+  """Returns process index from torch.distributed if available, else from multihost."""
+  try:
+    import torch.distributed as dist  # pylint: disable=g-import-not-at-top
+
+    if dist.is_initialized():
+      return dist.get_rank()
+  except ImportError:
+    pass
+  return multihost.process_index()
+
+
 class BaseMetric:
   """Base class for a metric type."""
 
@@ -47,7 +59,7 @@ def start(self):
     self._start_time = time.perf_counter()
     logging.info(
         "[process_id=%s] Starting metric: '%s'...",
-        multihost.process_index(),
+        get_process_index(),
         self.name,
     )
 
@@ -56,7 +68,7 @@ def stop(self) -> dict[str, tuple[Any, str]]:
     duration = time.perf_counter() - self._start_time
     logging.info(
         "[process_id=%s] Finished metric: '%s' (took %.4fs)",
-        multihost.process_index(),
+        get_process_index(),
         self.name,
         duration,
     )
@@ -168,7 +180,7 @@ def stop(self) -> dict[str, tuple[Any, str]]:
 
     self._log_tracemalloc_snapshot_diff(
         self.name,
-        multihost.process_index(),
+        get_process_index(),
         self._start_snapshot,
         end_snapshot,
         top_n=15,
@@ -285,7 +297,7 @@ def stop(self) -> dict[str, tuple[Any, str]]:
     diff = self._diff_metrics(self._start_metrics, end_metrics)
     logging.info(
         "[process_id=%s] Finished metric: %s, num_diffs=%d",
-        multihost.process_index(),
+        get_process_index(),
         self.name,
         len(diff),
     )
@@ -423,12 +435,12 @@ def report(self):
     """Logs a formatted report of all collected metrics."""
     report_lines = []
     report_lines.append(
-        f"---[process_id={multihost.process_index()}] {self.name} Metrics"
+        f"---[process_id={get_process_index()}] {self.name} Metrics"
         " Report ---"
     )
     if not self.results:
       report_lines.append(
-          f"[process_id={multihost.process_index()}] No metrics recorded."
+          f"[process_id={get_process_index()}] No metrics recorded."
       )
     else:
       for name, (value, unit) in sorted(self.results.items()):
@@ -649,7 +661,7 @@ def export_to_tensorboard(self, tensorboard_dir: epath.Path):
     """Exports metrics to TensorBoard."""
     logging.info("Writing per-repetition metrics to TensorBoard...")
     for benchmark_name, results in self._runs.items():
-      is_primary_host = multihost.process_index() == 0
+      is_primary_host = get_process_index() == 0
       writer = metric_writers.create_default_writer(
           tensorboard_dir,
           just_logging=not is_primary_host,