We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent aeabbcf commit fd56e84Copy full SHA for fd56e84
ffn/jax/train.py
@@ -340,7 +340,7 @@ def _get_tf_writer(writers) -> metric_writers.SummaryWriter | None:
340
341
def _get_ocp_args(
342
train_iter: DataIterator, restore: bool = True
343
-) -> DataIterator:
+) -> DataIterator | ocp.args.CheckpointArgs:
344
if isinstance(train_iter, tf.data.Iterator):
345
return DatasetArgs(train_iter)
346
0 commit comments