@@ -1944,7 +1944,15 @@ def _inner_training_loop(
 
             step = -1
 
+            # Data loading timing for global_step
+            _data_load_time_for_global_step = 0.0
+            _data_load_start_time = time.time()
+
             for step, inputs in enumerate(epoch_iterator):
+                # Record data loading time for this iteration
+                _data_load_end_time = time.time()
+                _data_load_time_for_global_step += _data_load_end_time - _data_load_start_time
+
                 if self.args.profile and step % self.args.gradient_accumulation_steps == 0:
                     perf_utils.switch_profile(
                         self.state.global_step,
@@ -1979,6 +1987,8 @@ def _inner_training_loop(
                     if steps_trained_in_current_epoch == 0:
                         self._load_rng_state(resume_from_checkpoint)
                     self.timers and self.timers("read-data").start()
+                    # Reset data loading timer for skipped steps
+                    _data_load_start_time = time.time()
                     continue
                 elif steps_trained_progress_bar is not None:
                     steps_trained_progress_bar.close()
@@ -2040,6 +2050,9 @@ def _inner_training_loop(
                     break
 
                     self.timers and self.timers("read-data").start()
+                    # Reset data loading timer for skipped data
+                    _data_load_time_for_global_step = 0.0
+                    _data_load_start_time = time.time()
                     continue
 
                 for inputs in inputs_list:
@@ -2209,7 +2222,6 @@ def hybrid_parallel_scale_param_grad(paramlist, hcg):
                     self.callback_handler.on_optimizer_begin(
                         args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None
                     )
-
                     self.optimizer_step(args, model=model, parameters_list=parameters_list)
 
                     self.timers and self.timers("optimizer-step").stop()
@@ -2237,7 +2249,15 @@ def hybrid_parallel_scale_param_grad(paramlist, hcg):
 
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                     self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
+                    # Log data loading time for this global_step
+                    logger.info(
+                        f"[DataLoad global_step: {self.state.global_step}] "
+                        f"data_load_time: {_data_load_time_for_global_step * 1000:.2f} ms "
+                        f"(accumulated over {args.gradient_accumulation_steps} micro-batches)"
+                    )
                     self._print_timer()
+                    # Reset data loading timer for next global_step
+                    _data_load_time_for_global_step = 0.0
                     step_control = 0
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
@@ -2249,6 +2269,9 @@ def hybrid_parallel_scale_param_grad(paramlist, hcg):
                 if self.args.ignore_data_skip:
                     self.timers and self.timers("read-data").start()
 
+                # Reset start time for next iteration's data loading measurement
+                _data_load_start_time = time.time()
+
             if step < 0:
                 logger.warning(
                     f"There seems to be not a single sample in your epoch_iterator, stopping training at step"
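
To make the measurement pattern easier to follow outside the Trainer, here is a minimal standalone sketch of the same idea: the dataloader wait time is the gap between the end of one iteration and the arrival of the next batch, accumulated across the micro-batches of a single optimizer step and reset after it is logged. The `batches` and `train_micro_batch` names below are placeholders for illustration, not part of the Trainer API.

```python
import time

def run_epoch(batches, gradient_accumulation_steps, train_micro_batch):
    """Accumulate dataloader wait time per optimizer step, mirroring the patch above."""
    data_load_time = 0.0           # total wait across the current accumulation window
    data_load_start = time.time()  # timestamp taken just before the next batch is requested

    for step, batch in enumerate(batches):
        # Time spent waiting for this micro-batch to arrive.
        data_load_time += time.time() - data_load_start

        train_micro_batch(batch)   # forward/backward for one micro-batch

        if (step + 1) % gradient_accumulation_steps == 0:
            # One "global step" finished: report the accumulated wait, then reset it.
            print(f"data_load_time: {data_load_time * 1000:.2f} ms "
                  f"(over {gradient_accumulation_steps} micro-batches)")
            data_load_time = 0.0

        # Restart the clock right before the loop pulls the next batch.
        data_load_start = time.time()
```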