Skip to content

Commit 3b6f616

Browse files
committed
add profiler
1 parent 349b128 commit 3b6f616

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

main.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import contexttimer
55

66
from transformers import AutoTokenizer, AutoModelForCausalLM
7+
from torch.profiler import ProfilerActivity
78

89
from sampling import autoregressive_sampling, speculative_sampling, speculative_sampling_v2
910
from globals import Decoder
@@ -33,6 +34,29 @@ def parse_arguments():
3334
return args
3435

3536

37+
def benchmark(fn, print_prefix, use_profiler=True, *args, **kwargs):
    """Run `fn(*args, **kwargs)` repeatedly and report generation throughput.

    Args:
        fn: callable returning a sequence whose first element supports len()
            (e.g. a (batch, seq_len) token tensor from a sampling function).
        print_prefix: label used in the printed report and as the
            profiler trace sub-directory name under ./profile_logs/.
        use_profiler: when True, wrap the runs in torch.profiler.profile and
            emit a TensorBoard trace (CUDA activity only).
        *args, **kwargs: forwarded verbatim to `fn`.

    Side effects:
        Prints a one-line throughput summary; may write profiler traces
        to ./profile_logs/{print_prefix}.
    """
    import time  # local import: stdlib timing, no third-party timer needed

    TEST_TIME = 10
    profile_filename = f"./profile_logs/{print_prefix}"

    start = time.perf_counter()
    if use_profiler:
        with torch.profiler.profile(
            activities=[ProfilerActivity.CUDA],
            # profile only 2 active steps after 1 warmup to keep traces small
            schedule=torch.profiler.schedule(wait=0, warmup=1, active=2, repeat=1, skip_first=0),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_filename),
            record_shapes=False,
            profile_memory=False,
            # with_stack=True
        ) as prof:
            for _ in range(TEST_TIME):
                output = fn(*args, **kwargs)
                prof.step()
    else:
        for _ in range(TEST_TIME):
            output = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start

    # tokens/sec = tokens-per-call / seconds-per-call; the previous formula
    # divided by TEST_TIME instead of the per-call elapsed time, underreporting
    # throughput by a factor of TEST_TIME**2.
    tokens_per_call = len(output[0])
    sec_per_call = elapsed / TEST_TIME
    print(f"\n [benchmark] {print_prefix}, tokens/sec: {tokens_per_call / sec_per_call}, {sec_per_call} sec generates {tokens_per_call} tokens")
59+
3660
def generate(input_text, approx_model_name, target_model_name, num_tokens=40, random_seed = None, verbose = False, use_benchmark = True):
3761
# NOTE() approx_model_name and target_model_name should use the same tokenizer!
3862

@@ -59,23 +83,17 @@ def generate(input_text, approx_model_name, target_model_name, num_tokens=40, ra
5983

6084
TEST_TIME = 10
6185
if use_benchmark:
62-
with contexttimer.Timer() as t:
63-
for _ in range(TEST_TIME):
64-
output = autoregressive_sampling(input_ids, large_model, num_tokens, top_k = top_k, top_p=top_p)
65-
print(f"\n[benchmark] large (target) model autoregressive_sampling 10 times, tokens/sec: {len(output[0]) / t.elapsed / TEST_TIME}, {t.elapsed / TEST_TIME} sec generates {len(output[0])}")
66-
86+
benchmark(autoregressive_sampling, "AS_large", True,
87+
input_ids, large_model, num_tokens, top_k = top_k, top_p=top_p)
6788

6889
torch.manual_seed(123)
6990
output = autoregressive_sampling(input_ids, small_model, num_tokens, top_k = top_k, top_p=top_p)
7091
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
7192
print(f"small (approx) model autoregressive_sampling: {generated_text}")
7293

7394
if use_benchmark:
74-
with contexttimer.Timer() as t:
75-
for _ in range(TEST_TIME):
76-
output = autoregressive_sampling(input_ids, small_model, num_tokens, top_k = top_k, top_p=top_p)
77-
print(f"\n[benchmark] small (approx) model autoregressive_sampling 10 times, tokens/sec: {len(output[0]) / t.elapsed / TEST_TIME}, {t.elapsed / TEST_TIME} sec generates {len(output[0])} tokens")
78-
95+
benchmark(autoregressive_sampling, "AS_small", True,
96+
input_ids, small_model, num_tokens, top_k = top_k, top_p=top_p)
7997

8098
torch.manual_seed(123)
8199
output = speculative_sampling_v2(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed)
@@ -88,10 +106,8 @@ def generate(input_text, approx_model_name, target_model_name, num_tokens=40, ra
88106
print(f"google's speculative_sampling: {generated_text}")
89107

90108
if use_benchmark:
91-
with contexttimer.Timer() as t:
92-
for _ in range(TEST_TIME):
93-
output = speculative_sampling(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed)
94-
print(f"\n[benchmark] speculative_sampling 10 times, tokens/sec: {len(output[0]) / t.elapsed / TEST_TIME}, {t.elapsed / TEST_TIME} sec generates {len(output[0])} tokens")
109+
benchmark(speculative_sampling, "SP", True,
110+
input_ids, small_model, large_model, max_len = num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed)
95111

96112
if __name__ == "__main__":
97113
args = parse_arguments()

0 commit comments

Comments
 (0)