Skip to content

Commit fd56e84

Browse files
lyglstcopybara-github
authored andcommitted
Fix a type incompatibility issues.
PiperOrigin-RevId: 853779316
1 parent aeabbcf commit fd56e84

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

ffn/jax/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def _get_tf_writer(writers) -> metric_writers.SummaryWriter | None:
340340

341341
def _get_ocp_args(
342342
train_iter: DataIterator, restore: bool = True
343-
) -> DataIterator:
343+
) -> DataIterator | ocp.args.CheckpointArgs:
344344
if isinstance(train_iter, tf.data.Iterator):
345345
return DatasetArgs(train_iter)
346346

0 commit comments

Comments
 (0)