@@ -265,18 +265,28 @@ def save_pytree(
metrics: tree_types.JsonType | None = None,
custom_metadata: tree_types.JsonType | None = None,
) -> bool:
-"""Saves a PyTree checkpoint at the given step.
+"""Saves a checkpoint, if dictated by :py:class:`.SaveDecisionPolicy`.

This function behaves similarly to :py:func:`.save_pytree` (see
documentation), but performs additional tasks related to managing a sequence
of checkpoint steps.

It consists roughly of the following steps:
- Check whether a checkpoint should be saved at the given step.
+- Check whether a save is already in progress. If so, wait for it to
+finish.
- Save to a directory given by `root_directory / <step_format>`.
- Perform garbage collection if necessary.
- Return whether a checkpoint was saved or not.

+It is important to note that the `Checkpointer` never allows saving more
+than one checkpoint at a time. Depending on the
+:py:class:`.SaveDecisionPolicy`, a checkpoint may be saved or skipped at a
+given step, but if a save is initiated, as dictated by the policy, then it
+will proceed as normal as long as no other save is currently in progress. If
+a save is already in progress, the function will block until the previous
+save has finished.

Args:
step: The step number to save.
pytree: The PyTree to save.
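
Review note: the step sequence in the new docstring can be sketched as a toy model. This is not the library's implementation. `ToyCheckpointer`, `save_interval`, `max_to_keep`, the `step_<N>` directory format, and the JSON file layout are all assumptions made for illustration; a fixed interval stands in for `SaveDecisionPolicy`, and a lock stands in for the "one save at a time" rule.

```python
import json
import shutil
import tempfile
import threading
from pathlib import Path


class ToyCheckpointer:
    """Hypothetical stand-in that mirrors the documented save steps."""

    def __init__(self, root_directory, save_interval=2, max_to_keep=2):
        self.root = Path(root_directory)
        self.save_interval = save_interval  # stands in for SaveDecisionPolicy
        self.max_to_keep = max_to_keep      # assumed garbage-collection rule
        self._save_lock = threading.Lock()  # at most one save at a time

    def _should_save(self, step):
        return step % self.save_interval == 0

    def _existing_steps(self):
        return sorted(int(p.name.split("_")[1]) for p in self.root.glob("step_*"))

    def save_pytree(self, step, pytree):
        # 1. Ask the policy whether this step should be saved at all.
        if not self._should_save(step):
            return False
        # 2. Acquiring the lock blocks until any in-progress save finishes.
        with self._save_lock:
            # 3. Save under root_directory / <step_format> (format assumed).
            step_dir = self.root / f"step_{step}"
            step_dir.mkdir(parents=True, exist_ok=True)
            (step_dir / "pytree.json").write_text(json.dumps(pytree))
            # 4. Garbage-collect everything except the newest checkpoints.
            steps = self._existing_steps()
            for old in steps[: max(0, len(steps) - self.max_to_keep)]:
                shutil.rmtree(self.root / f"step_{old}")
        # 5. Report that a checkpoint was written.
        return True


with tempfile.TemporaryDirectory() as tmp:
    ckpt = ToyCheckpointer(tmp, save_interval=2, max_to_keep=2)
    results = [ckpt.save_pytree(s, {"w": [s, s + 1]}) for s in range(6)]
    remaining = ckpt._existing_steps()
```

In this sketch the boolean return value is the only way a caller learns that the policy skipped a step, which is why the docstring lists it as an explicit final step.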
@@ -332,11 +342,12 @@ def save_pytree_async(
metrics: tree_types.JsonType | None = None,
custom_metadata: tree_types.JsonType | None = None,
) -> async_types.AsyncResponse[bool]:
-"""Saves a PyTree checkpoint asynchronously at the given step.
+"""Saves a checkpoint asynchronously, if dictated by :py:class:`.SaveDecisionPolicy`.

-See documentation for :py:func:`.save_pytree` for more details. This
-function executes in the background, and blocks for as little time as
-possible.
+See documentation for :py:func:`.save_pytree` for full details. This
+function is essentially the same, except that it executes mostly in the
+background, and blocks for as little time as possible (primarily to
+transfer weights from device to host).

Args:
step: The step number to save.
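
Review note: a minimal sketch of the "block briefly, then finish in the background" behavior the new docstring describes. It is not the library's implementation. `ToyAsyncSaver`, `ToyAsyncResponse`, the `should_save` callable, and the in-memory `written` dict are hypothetical; `copy.deepcopy` stands in for the device-to-host transfer, and `time.sleep` stands in for a slow storage write.

```python
import copy
import threading
import time


class ToyAsyncResponse:
    """Hypothetical handle; result() waits for the background work."""

    def __init__(self, thread, saved):
        self._thread = thread
        self._saved = saved

    def result(self):
        if self._thread is not None:
            self._thread.join()
        return self._saved


class ToyAsyncSaver:
    def __init__(self, should_save=lambda step: True):
        self._should_save = should_save  # stands in for SaveDecisionPolicy
        self._pending = None             # thread of the in-progress save
        self.written = {}                # step -> snapshot, stands in for storage

    def save_pytree_async(self, step, pytree):
        if not self._should_save(step):
            return ToyAsyncResponse(None, False)
        # Only one save at a time: wait for the previous one to finish.
        if self._pending is not None:
            self._pending.join()
        # Blocking part: snapshot the state so the caller may mutate it freely.
        snapshot = copy.deepcopy(pytree)

        def _write():
            time.sleep(0.05)  # simulated slow write
            self.written[step] = snapshot

        # Non-blocking part: the write itself runs in the background.
        self._pending = threading.Thread(target=_write)
        self._pending.start()
        return ToyAsyncResponse(self._pending, True)


saver = ToyAsyncSaver(should_save=lambda step: step % 2 == 0)
state = {"w": [1, 2]}
response = saver.save_pytree_async(0, state)
state["w"][0] = 99  # mutation after the call does not affect the checkpoint
saved = response.result()
skipped = saver.save_pytree_async(1, state).result()
```

Taking the snapshot before returning is the design point: the caller can resume training and overwrite its state immediately, while the saved copy stays consistent.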