Skip to content

Commit 5bb822a

Browse files
hejiang0116Orbax Authors
authored andcommitted
Fix argument handling in TfDataProcessor for single-argument tf.function.
PiperOrigin-RevId: 865485895
1 parent 67b9025 commit 5bb822a

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

export/orbax/export/data_processors/tf_data_processor_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,39 @@ def test_prepare_with_shlo_bf16_inputs(self):
214214
obm.ShloTensorSpec(shape=(1,), dtype=obm.ShloDType.bf16, name='x'),
215215
)
216216

217+
def test_prepare_with_single_list_argument(self):
218+
219+
def add_list(inputs):
220+
return inputs[0] + inputs[1]
221+
222+
processor = tf_data_processor.TfDataProcessor(add_list, name='add')
223+
processor.prepare(
224+
(
225+
tf.TensorSpec([None, 3], tf.float32),
226+
tf.TensorSpec([None, 3], tf.float32),
227+
),
228+
)
229+
230+
self.assertIsNotNone(processor.concrete_function)
231+
self.assertIsNotNone(processor.obm_function)
232+
self.assertEqual(
233+
processor.input_signature[0][0],
234+
(
235+
obm.ShloTensorSpec(
236+
shape=(None, 3), dtype=obm.ShloDType.f32, name='inputs_0'
237+
),
238+
obm.ShloTensorSpec(
239+
shape=(None, 3), dtype=obm.ShloDType.f32, name='inputs_1'
240+
),
241+
),
242+
)
243+
self.assertEqual(
244+
processor.output_signature,
245+
obm.ShloTensorSpec(
246+
shape=(None, 3), dtype=obm.ShloDType.f32, name='output_0'
247+
),
248+
)
249+
217250

218251
if __name__ == '__main__':
219252
googletest.main()

0 commit comments

Comments
 (0)