
Commit 67c538d

supoort_profile_data_load (#3836)
1 parent 5ed20bc


paddleformers/trainer/trainer.py

Lines changed: 24 additions & 1 deletion
@@ -1944,7 +1944,15 @@ def _inner_training_loop(
 
             step = -1
 
+            # Data loading timing for global_step
+            _data_load_time_for_global_step = 0.0
+            _data_load_start_time = time.time()
+
             for step, inputs in enumerate(epoch_iterator):
+                # Record data loading time for this iteration
+                _data_load_end_time = time.time()
+                _data_load_time_for_global_step += _data_load_end_time - _data_load_start_time
+
                 if self.args.profile and step % self.args.gradient_accumulation_steps == 0:
                     perf_utils.switch_profile(
                         self.state.global_step,
@@ -1979,6 +1987,8 @@ def _inner_training_loop(
                     if steps_trained_in_current_epoch == 0:
                         self._load_rng_state(resume_from_checkpoint)
                     self.timers and self.timers("read-data").start()
+                    # Reset data loading timer for skipped steps
+                    _data_load_start_time = time.time()
                     continue
                 elif steps_trained_progress_bar is not None:
                     steps_trained_progress_bar.close()
@@ -2040,6 +2050,9 @@ def _inner_training_loop(
                             break
 
                     self.timers and self.timers("read-data").start()
+                    # Reset data loading timer for skipped data
+                    _data_load_time_for_global_step = 0.0
+                    _data_load_start_time = time.time()
                     continue
 
                 for inputs in inputs_list:
@@ -2209,7 +2222,6 @@ def hybrid_parallel_scale_param_grad(paramlist, hcg):
                     self.callback_handler.on_optimizer_begin(
                         args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None
                     )
-
                     self.optimizer_step(args, model=model, parameters_list=parameters_list)
 
                     self.timers and self.timers("optimizer-step").stop()
@@ -2237,7 +2249,15 @@ def hybrid_parallel_scale_param_grad(paramlist, hcg):
 
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                     self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
+                    # Log data loading time for this global_step
+                    logger.info(
+                        f"[DataLoad global_step: {self.state.global_step}] "
+                        f"data_load_time: {_data_load_time_for_global_step * 1000:.2f} ms "
+                        f"(accumulated over {args.gradient_accumulation_steps} micro-batches)"
+                    )
                     self._print_timer()
+                    # Reset data loading timer for next global_step
+                    _data_load_time_for_global_step = 0.0
                     step_control = 0
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
@@ -2249,6 +2269,9 @@ def hybrid_parallel_scale_param_grad(paramlist, hcg):
                 if self.args.ignore_data_skip:
                     self.timers and self.timers("read-data").start()
 
+                # Reset start time for next iteration's data loading measurement
+                _data_load_start_time = time.time()
+
             if step < 0:
                 logger.warning(
                     f"There seems to be not a single sample in your epoch_iterator, stopping training at step"
