From 1f4924eec2bda214f34500cbd682546cbdd18e32 Mon Sep 17 00:00:00 2001 From: He Jiang Date: Thu, 5 Feb 2026 13:28:28 -0800 Subject: [PATCH] Fix argument handling in TfDataProcessor for single-argument tf.function. PiperOrigin-RevId: 866088445 --- .../data_processors/tf_data_processor_test.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/export/orbax/export/data_processors/tf_data_processor_test.py b/export/orbax/export/data_processors/tf_data_processor_test.py index aeedd0fae..815d9e463 100644 --- a/export/orbax/export/data_processors/tf_data_processor_test.py +++ b/export/orbax/export/data_processors/tf_data_processor_test.py @@ -214,6 +214,39 @@ def test_prepare_with_shlo_bf16_inputs(self): obm.ShloTensorSpec(shape=(1,), dtype=obm.ShloDType.bf16, name='x'), ) + def test_prepare_with_single_list_argument(self): + + def add_list(inputs): + return inputs[0] + inputs[1] + + processor = tf_data_processor.TfDataProcessor(add_list, name='add') + processor.prepare( + ( + tf.TensorSpec([None, 3], tf.float32), + tf.TensorSpec([None, 3], tf.float32), + ), + ) + + self.assertIsNotNone(processor.concrete_function) + self.assertIsNotNone(processor.obm_function) + self.assertEqual( + processor.input_signature[0][0], + ( + obm.ShloTensorSpec( + shape=(None, 3), dtype=obm.ShloDType.f32, name='inputs_0' + ), + obm.ShloTensorSpec( + shape=(None, 3), dtype=obm.ShloDType.f32, name='inputs_1' + ), + ), + ) + self.assertEqual( + processor.output_signature, + obm.ShloTensorSpec( + shape=(None, 3), dtype=obm.ShloDType.f32, name='output_0' + ), + ) + if __name__ == '__main__': googletest.main()