Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,17 @@ async def _save(
ckpt_args = checkpointer.construct_checkpoint_args(
self._handler, True, *args, **kwargs
)
commit_ops.extend(
await self._handler.async_save(tmpdir.get(), args=ckpt_args) or []
)
if isinstance(
self._handler,
async_checkpoint_handler.TemporaryPathAwareAsyncCheckpointHandler,
):
commit_ops.extend(
await self._handler.async_save(tmpdir, args=ckpt_args) or []
)
else:
commit_ops.extend(
await self._handler.async_save(tmpdir.get(), args=ckpt_args) or []
)
commit_ops, _ = jax.tree.flatten(commit_ops)
commit_ops = [op for op in commit_ops if op is not None]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from etils import epath
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.handlers import checkpoint_handler
from orbax.checkpoint._src.path import atomicity_types


class AsyncCheckpointHandler(checkpoint_handler.CheckpointHandler):
Expand All @@ -44,3 +45,43 @@ async def async_save(
**kwargs: additional arguments for save.
"""
pass


class TemporaryPathAwareAsyncCheckpointHandler(AsyncCheckpointHandler):
"""Handler interface that receives TemporaryPath for deferred path support.

This interface extends AsyncCheckpointHandler with an async_save method that
accepts either an epath.Path or a TemporaryPath directly, allowing handlers
to work with deferred paths (e.g., TFHub) where the actual path is allocated
asynchronously.

Handlers implementing this interface can:
1. Receive the TemporaryPath before the path is allocated
2. Wait for STEP_DIRECTORY_CREATION signal inside their CommitFuture
3. Call tmpdir.get() after the signal to access the allocated path
"""

@abc.abstractmethod
async def async_save(
self,
directory: epath.Path | atomicity_types.TemporaryPath,
*args,
**kwargs,
) -> Optional[List[future.Future]]:
"""Constructs a save operation with support for TemporaryPath.

This method accepts either an epath.Path or a TemporaryPath. When a
TemporaryPath is passed, handler coroutines should wait for the
STEP_DIRECTORY_CREATION signal before calling directory.get().

Args:
directory: The directory to save to. May be an epath.Path or a
TemporaryPath. For deferred paths, the actual path is not yet allocated.
*args: additional arguments for save.
**kwargs: additional arguments for save.

Returns:
A list of futures that will commit the data when awaited.
"""

pass
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.path import format_utils
from orbax.checkpoint._src.serialization import limits
from orbax.checkpoint._src.serialization import ocdbt_utils
Expand Down Expand Up @@ -308,7 +309,7 @@ def _format_bytes(bytes_value: Optional[int]) -> str:


class BasePyTreeCheckpointHandler(
async_checkpoint_handler.AsyncCheckpointHandler
async_checkpoint_handler.TemporaryPathAwareAsyncCheckpointHandler
):
"""A CheckpointHandler implementation for any PyTree structure.

Expand Down Expand Up @@ -433,7 +434,7 @@ def get_param_names(self, item: PyTree) -> PyTree:
def _get_param_infos(
self,
item: PyTree,
directory: epath.Path,
directory: atomicity_types.TemporaryPath | epath.Path,
*,
use_ocdbt: bool = True,
use_compression: bool | None = True,
Expand Down Expand Up @@ -581,7 +582,7 @@ def _handle_diffs(keypath, diff):

async def async_save(
self,
directory: epath.Path,
directory: epath.Path | atomicity_types.TemporaryPath,
args: BasePyTreeSaveArgs,
) -> Optional[List[future.Future]]:
"""Saves a PyTree to a given directory.
Expand Down Expand Up @@ -643,8 +644,9 @@ async def async_save(
use_compression=self._use_compression,
use_zarr3=self._use_zarr3,
)

assert all(
leaf.parent_dir == directory for leaf in jax.tree.leaves(param_infos)
leaf.parent_dir is directory for leaf in jax.tree.leaves(param_infos)
)

serialize_ops = [] # List of (coros -> List of futures)
Expand All @@ -660,11 +662,11 @@ async def async_save(
# Cannot rely solely on the metadata file existing pre-empted saves may be
# misclassified as partial saves.
partial_save = (
await async_path.exists(directory / PYTREE_METADATA_FILE)
isinstance(directory, epath.Path)
and await async_path.exists(directory / PYTREE_METADATA_FILE)
# TODO: b/428711337 - Use method from v1/_src/partial/path.py instead.
and '.partial_save' in directory.parent.name
)

batch_requests_ready_time = time.time()
if partial_save:
serialize_ops, tree_memory_size, param_infos, save_args = (
Expand Down Expand Up @@ -1190,7 +1192,7 @@ async def _write_metadata_file(
async def _write_metadata_after_commits(
self,
commit_futures: List[future.Future],
checkpoint_dir: epath.Path,
checkpoint_dir: atomicity_types.TemporaryPath | epath.Path,
*,
param_infos: PyTree,
save_args: PyTree,
Expand All @@ -1205,6 +1207,9 @@ async def _write_metadata_after_commits(
for commit_future in commit_futures:
await asyncio.to_thread(commit_future.result)

if isinstance(checkpoint_dir, atomicity_types.TemporaryPath):
checkpoint_dir = checkpoint_dir.get()

commit_time = time.time()
# `write_shape` is extracted from ArrayMetadata store saved during
# materialization of commit_futures. Then it is written to the pytree
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint._src.serialization import types
from orbax.checkpoint.experimental.v1._src.path import types as path_types


class PathResolver:
Expand Down Expand Up @@ -258,7 +259,7 @@ async def read(
)
return None

if isinstance(file_paths, epath.Path):
if isinstance(file_paths, (epath.Path, path_types.PathAwaitingCreation)):
_, result = await self._get_array_metadatas(file_paths)
logging.vlog(
1,
Expand Down
42 changes: 36 additions & 6 deletions checkpoint/orbax/checkpoint/_src/path/atomicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,24 @@ def get(self) -> epath.Path:
)
return self._tmp_path

def as_posix(self) -> str:
"""Returns the temporary path as a POSIX string.

Convenience method for compatibility with epath.Path interface.
"""
return self.get().as_posix()

def __truediv__(self, other: str) -> epath.Path:
"""Supports the path / operator for joining paths.

Args:
other: The path component to append.

Returns:
The joined path.
"""
return self.get() / other


class ReadOnlyTemporaryPath(atomicity_types.TemporaryPath):
"""A read-only, serializable object providing path properties access.
Expand Down Expand Up @@ -345,9 +363,7 @@ async def validate_final(

async def _shared_validate(class_name: str, path: epath.Path):
if not await async_path.is_dir(path):
raise ValidationError(
f'Expected {class_name} ({path}) to be a directory.'
)
raise ValidationError(f'Expected {class_name} ({path}) to be a directory.')
if not await async_path.exists(path):
raise ValidationError(f'Expected {class_name} ({path}) to exist.')

Expand Down Expand Up @@ -398,9 +414,7 @@ async def validate(
cls,
temporary_path: epath.Path,
):
await validate_atomic_rename_temporary_path(
cls.__name__, temporary_path
)
await validate_atomic_rename_temporary_path(cls.__name__, temporary_path)

@classmethod
async def validate_final(
Expand Down Expand Up @@ -626,6 +640,21 @@ async def finalize(
)


def get_path_or_raise_if_deferred(
path: atomicity_types.TemporaryPath,
) -> epath.Path:
"""Gets the temporary path.

Args:
path: A TemporaryPath.

Returns:
The temporary path.
"""

return path.get()


async def create_all(
paths: Sequence[atomicity_types.TemporaryPath],
*,
Expand Down Expand Up @@ -658,6 +687,7 @@ async def create_all(
timeout=multihost.coordination_timeout(),
processes=active_processes,
)

directory_creation_secs = time.time() - start
jax.monitoring.record_event_duration_secs(
'/jax/orbax/write/directory_creation_secs', directory_creation_secs
Expand Down
23 changes: 23 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/atomicity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
# limitations under the License.

import stat

import unittest
from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import test_utils


from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import atomicity
from orbax.checkpoint._src.path import atomicity_types
Expand Down Expand Up @@ -205,6 +208,26 @@ async def test_finalize_raises(self):
)


class GetPathOrRaiseIfDeferredTest(
parameterized.TestCase,
unittest.IsolatedAsyncioTestCase,
):

def setUp(self):
super().setUp()
self.directory = epath.Path(self.create_tempdir().full_path)

def test_concrete_path_returns_value(self):
tmp = self.directory / 'tmp'
final = self.directory / 'final'
path = AtomicRenameTemporaryPath(
temporary_path=tmp,
final_path=final,
)
result = atomicity.get_path_or_raise_if_deferred(path)
self.assertEqual(result, tmp)



if __name__ == '__main__':
absltest.main()
Loading
Loading