
Commit 69b27fa

supoort_profile_data_load
1 parent: d00de6b

1 file changed: 24 additions, 1 deletion

paddleformers/trainer/trainer.py

@@ -1924,7 +1924,15 @@ def _inner_training_loop(
 
             step = -1
 
+            # Data loading timing for global_step
+            _data_load_time_for_global_step = 0.0
+            _data_load_start_time = time.time()
+
             for step, inputs in enumerate(epoch_iterator):
+                # Record data loading time for this iteration
+                _data_load_end_time = time.time()
+                _data_load_time_for_global_step += _data_load_end_time - _data_load_start_time
+
                 if self.args.profile and step % self.args.gradient_accumulation_steps == 0:
                     perf_utils.switch_profile(
                         self.state.global_step,
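
The first hunk sets up the usual two-timestamp dataloader measurement: the wall-clock gap between the end of the previous loop body and the moment `enumerate(epoch_iterator)` yields the next batch is attributed to data loading and summed. A minimal standalone sketch of that pattern, assuming a toy `slow_loader` generator in place of the real `epoch_iterator`:

import time

def slow_loader(num_batches):
    # Toy stand-in for epoch_iterator; the sleep simulates I/O / preprocessing cost.
    for i in range(num_batches):
        time.sleep(0.01)
        yield i

data_load_time = 0.0
load_start = time.time()                          # clock starts before the first fetch
for step, batch in enumerate(slow_loader(8)):
    data_load_time += time.time() - load_start    # time spent waiting on the loader
    # ... forward / backward for this micro-batch would run here ...
    load_start = time.time()                      # restart the clock before the next fetch
print(f"data_load_time: {data_load_time * 1000:.2f} ms")

The restart at the bottom of the loop body is what the later hunks add on every path that reaches the next iteration (normal flow, skipped steps, skipped data), so compute time never leaks into the measurement.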
@@ -1959,6 +1967,8 @@ def _inner_training_loop(
                     if steps_trained_in_current_epoch == 0:
                         self._load_rng_state(resume_from_checkpoint)
                     self.timers and self.timers("read-data").start()
+                    # Reset data loading timer for skipped steps
+                    _data_load_start_time = time.time()
                     continue
                 elif steps_trained_progress_bar is not None:
                     steps_trained_progress_bar.close()
@@ -2020,6 +2030,9 @@ def _inner_training_loop(
                         break
 
                     self.timers and self.timers("read-data").start()
+                    # Reset data loading timer for skipped data
+                    _data_load_time_for_global_step = 0.0
+                    _data_load_start_time = time.time()
                     continue
 
                 for inputs in inputs_list:
@@ -2189,7 +2202,6 @@ def hybrid_parallel_scale_param_grad(paramlist, hcg):
                     self.callback_handler.on_optimizer_begin(
                         args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None
                     )
-
                     self.optimizer_step(args, model=model, parameters_list=parameters_list)
 
                     self.timers and self.timers("optimizer-step").stop()
@@ -2217,7 +2229,15 @@ def hybrid_parallel_scale_param_grad(paramlist, hcg):
 
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                     self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
+                    # Log data loading time for this global_step
+                    logger.info(
+                        f"[DataLoad global_step: {self.state.global_step}] "
+                        f"data_load_time: {_data_load_time_for_global_step * 1000:.2f} ms "
+                        f"(accumulated over {args.gradient_accumulation_steps} micro-batches)"
+                    )
                     self._print_timer()
+                    # Reset data loading timer for next global_step
+                    _data_load_time_for_global_step = 0.0
                     step_control = 0
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
@@ -2229,6 +2249,9 @@ def hybrid_parallel_scale_param_grad(paramlist, hcg):
                 if self.args.ignore_data_skip:
                     self.timers and self.timers("read-data").start()
 
+                # Reset start time for next iteration's data loading measurement
+                _data_load_start_time = time.time()
+
             if step < 0:
                 logger.warning(
                     f"There seems to be not a single sample in your epoch_iterator, stopping training at step"
