import contexttimer

from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.profiler import ProfilerActivity

from sampling import autoregressive_sampling, speculative_sampling, speculative_sampling_v2
from globals import Decoder
@@ -33,6 +34,29 @@ def parse_arguments():
3334 return args
3435
3536
def benchmark(fn, print_prefix, use_profiler=True, *args, **kwargs):
    """Time repeated calls of *fn* and print the average decoding throughput.

    Runs ``fn(*args, **kwargs)`` TEST_TIME times. When ``use_profiler`` is
    true, the runs are additionally traced with the PyTorch profiler and a
    TensorBoard trace is written to ``./profile_logs/<print_prefix>``.

    Args:
        fn: sampling callable; must return a sequence whose first element is
            the generated token sequence (e.g. ``autoregressive_sampling``).
        print_prefix: label used in the printed report and as the trace
            directory name under ``./profile_logs/``.
        use_profiler: trace the runs with ``torch.profiler`` when True.
        *args, **kwargs: forwarded verbatim to ``fn``.
    """
    # Stdlib timer instead of the third-party contexttimer dependency;
    # measures the same wall-clock interval.
    from time import perf_counter

    TEST_TIME = 10
    profile_filename = f"./profile_logs/{print_prefix}"

    start = perf_counter()
    if use_profiler:
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA],
            # Trace only 3 of the TEST_TIME runs: 1 warmup + 2 active.
            schedule=torch.profiler.schedule(wait=0, warmup=1, active=2, repeat=1, skip_first=0),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_filename),
            record_shapes=False,
            profile_memory=False,
            # with_stack=True
        ) as prof:
            for _ in range(TEST_TIME):
                output = fn(*args, **kwargs)
                prof.step()
    else:
        for _ in range(TEST_TIME):
            output = fn(*args, **kwargs)
    elapsed = perf_counter() - start

    # NOTE(review): with use_profiler=True the interval includes profiler
    # overhead, so the reported tokens/sec understates the true throughput.
    print(f"\n[benchmark] {print_prefix}, tokens/sec: {len(output[0]) / elapsed / TEST_TIME}, {elapsed / TEST_TIME} sec generates {len(output[0])} tokens")
59+
3660def generate (input_text , approx_model_name , target_model_name , num_tokens = 40 , random_seed = None , verbose = False , use_benchmark = True ):
3761 # NOTE() approx_model_name and target_model_name should use the same tokenizer!
3862
@@ -59,23 +83,17 @@ def generate(input_text, approx_model_name, target_model_name, num_tokens=40, ra
5983
6084 TEST_TIME = 10
6185 if use_benchmark :
62- with contexttimer .Timer () as t :
63- for _ in range (TEST_TIME ):
64- output = autoregressive_sampling (input_ids , large_model , num_tokens , top_k = top_k , top_p = top_p )
65- print (f"\n [benchmark] large (target) model autoregressive_sampling 10 times, tokens/sec: { len (output [0 ]) / t .elapsed / TEST_TIME } , { t .elapsed / TEST_TIME } sec generates { len (output [0 ])} " )
66-
86+ benchmark (autoregressive_sampling , "AS_large" , True ,
87+ input_ids , large_model , num_tokens , top_k = top_k , top_p = top_p )
6788
6889 torch .manual_seed (123 )
6990 output = autoregressive_sampling (input_ids , small_model , num_tokens , top_k = top_k , top_p = top_p )
7091 generated_text = tokenizer .decode (output [0 ], skip_special_tokens = True )
7192 print (f"small (approx) model autoregressive_sampling: { generated_text } " )
7293
7394 if use_benchmark :
74- with contexttimer .Timer () as t :
75- for _ in range (TEST_TIME ):
76- output = autoregressive_sampling (input_ids , small_model , num_tokens , top_k = top_k , top_p = top_p )
77- print (f"\n [benchmark] small (approx) model autoregressive_sampling 10 times, tokens/sec: { len (output [0 ]) / t .elapsed / TEST_TIME } , { t .elapsed / TEST_TIME } sec generates { len (output [0 ])} tokens" )
78-
95+ benchmark (autoregressive_sampling , "AS_small" , True ,
96+ input_ids , small_model , num_tokens , top_k = top_k , top_p = top_p )
7997
8098 torch .manual_seed (123 )
8199 output = speculative_sampling_v2 (input_ids , small_model , large_model , num_tokens , top_k = top_k , top_p = top_p , random_seed = random_seed )
@@ -88,10 +106,8 @@ def generate(input_text, approx_model_name, target_model_name, num_tokens=40, ra
88106 print (f"google's speculative_sampling: { generated_text } " )
89107
90108 if use_benchmark :
91- with contexttimer .Timer () as t :
92- for _ in range (TEST_TIME ):
93- output = speculative_sampling (input_ids , small_model , large_model , num_tokens , top_k = top_k , top_p = top_p , random_seed = random_seed )
94- print (f"\n [benchmark] speculative_sampling 10 times, tokens/sec: { len (output [0 ]) / t .elapsed / TEST_TIME } , { t .elapsed / TEST_TIME } sec generates { len (output [0 ])} tokens" )
109+ benchmark (speculative_sampling , "SP" , True ,
110+ input_ids , small_model , large_model , max_len = num_tokens , top_k = top_k , top_p = top_p , random_seed = random_seed )
95111
96112if __name__ == "__main__" :
97113 args = parse_arguments ()
0 commit comments