@@ -1924,7 +1924,15 @@ def _inner_training_loop(
 
             step = -1
 
+            # Data loading timing for global_step
+            _data_load_time_for_global_step = 0.0
+            _data_load_start_time = time.time()
+
             for step, inputs in enumerate(epoch_iterator):
+                # Record data loading time for this iteration
+                _data_load_end_time = time.time()
+                _data_load_time_for_global_step += _data_load_end_time - _data_load_start_time
+
                 if self.args.profile and step % self.args.gradient_accumulation_steps == 0:
                     perf_utils.switch_profile(
                         self.state.global_step,
@@ -1959,6 +1967,8 @@ def _inner_training_loop(
                     if steps_trained_in_current_epoch == 0:
                         self._load_rng_state(resume_from_checkpoint)
                     self.timers and self.timers("read-data").start()
+                    # Reset data loading timer for skipped steps
+                    _data_load_start_time = time.time()
                     continue
                 elif steps_trained_progress_bar is not None:
                     steps_trained_progress_bar.close()
@@ -2020,6 +2030,9 @@ def _inner_training_loop(
                         break
 
                     self.timers and self.timers("read-data").start()
+                    # Reset data loading timer for skipped data
+                    _data_load_time_for_global_step = 0.0
+                    _data_load_start_time = time.time()
                     continue
 
                 for inputs in inputs_list:
@@ -2189,7 +2202,6 @@ def hybrid_parallel_scale_param_grad(paramlist, hcg):
                     self.callback_handler.on_optimizer_begin(
                         args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None
                     )
-
                     self.optimizer_step(args, model=model, parameters_list=parameters_list)
 
                     self.timers and self.timers("optimizer-step").stop()
@@ -2217,7 +2229,15 @@ def hybrid_parallel_scale_param_grad(paramlist, hcg):
 
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                     self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
+                    # Log data loading time for this global_step
+                    logger.info(
+                        f"[DataLoad global_step: {self.state.global_step}] "
+                        f"data_load_time: {_data_load_time_for_global_step * 1000:.2f} ms "
+                        f"(accumulated over {args.gradient_accumulation_steps} micro-batches)"
+                    )
                     self._print_timer()
+                    # Reset data loading timer for next global_step
+                    _data_load_time_for_global_step = 0.0
                     step_control = 0
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
@@ -2229,6 +2249,9 @@ def hybrid_parallel_scale_param_grad(paramlist, hcg):
                 if self.args.ignore_data_skip:
                     self.timers and self.timers("read-data").start()
 
+                # Reset start time for next iteration's data loading measurement
+                _data_load_start_time = time.time()
+
             if step < 0:
                 logger.warning(
                     f"There seems to be not a single sample in your epoch_iterator, stopping training at step"
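
For reference, the timing pattern this patch introduces can be reproduced outside the Trainer. The sketch below is a minimal, standalone approximation and not part of the patch: `slow_loader`, `train_micro_batch`, and the `gradient_accumulation_steps = 4` value are hypothetical stand-ins for `epoch_iterator`, the per-micro-batch training step, and the real training argument. It shows the same idea: measure only the time spent waiting on the data pipeline, accumulate it across micro-batches, and log once per global step.

```python
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def slow_loader(num_batches, delay=0.01):
    """Hypothetical stand-in for epoch_iterator: each batch takes `delay` seconds to produce."""
    for i in range(num_batches):
        time.sleep(delay)  # simulate data loading / preprocessing cost
        yield {"batch_id": i}


def train_micro_batch(inputs):
    """Hypothetical stand-in for the forward/backward pass of one micro-batch."""
    time.sleep(0.005)


gradient_accumulation_steps = 4  # assumed value for the example
global_step = 0

data_load_time_for_global_step = 0.0
data_load_start_time = time.time()

for step, inputs in enumerate(slow_loader(num_batches=8)):
    # The gap between restarting the timer and entering the loop body is
    # the time spent waiting on the data pipeline for this micro-batch.
    data_load_time_for_global_step += time.time() - data_load_start_time

    train_micro_batch(inputs)

    if (step + 1) % gradient_accumulation_steps == 0:
        global_step += 1
        logger.info(
            f"[DataLoad global_step: {global_step}] "
            f"data_load_time: {data_load_time_for_global_step * 1000:.2f} ms "
            f"(accumulated over {gradient_accumulation_steps} micro-batches)"
        )
        data_load_time_for_global_step = 0.0

    # Restart the timer just before control returns to the dataloader,
    # so model compute time is excluded from the measurement.
    data_load_start_time = time.time()
```

The timer is deliberately restarted at the bottom of each iteration, mirroring the patch's placement of `_data_load_start_time = time.time()` at the end of the loop body, so that forward/backward compute is not counted toward the reported data loading time.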