Skip to content

Commit 1da363e

Browse files
authored
add share_gpt benchmarking results
2 parents 9da1033 + 76117cf commit 1da363e

File tree

8 files changed

+171
-35
lines changed

8 files changed

+171
-35
lines changed

README.md

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,10 @@ The speculative sampling is proposed by Google and Deepmind independently. So I
1717
You need prepare a pair of models using the same embedding and vocabulary. The approximation model should be smaller than the target model. Here are some
1818
tested model pairs.
1919

20-
<center>
21-
22-
| Approx Model | Target Model |
23-
|--------------|--------------|
24-
| [bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1/tree/main) | [bloom-560m](https://huggingface.co/bigscience/bloom-560m/tree/main) |
25-
| [TinyLlama-1.1B](https://huggingface.co/PY007/TinyLlama-1.1B-step-50K-105b) | llama-7b |
2620

2721
</center>
2822

29-
In the sample, I use [bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1/tree/main) as the target model, [bloom-560m](https://huggingface.co/bigscience/bloom-560m/tree/main) as the approximation model.
23+
In the sample, we demonstrate [bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1/tree/main) as the target model, [bloom-560m](https://huggingface.co/bigscience/bloom-560m/tree/main) as the approximation model.
3024

3125
```bash
3226
python main.py \
@@ -35,10 +29,21 @@ python main.py \
3529
--approx_model_name bigscience/bloom-560m
3630
```
3731

38-
You can also use `--v` args to see a token is generated by which model.
32+
You can also use the `-v` argument to see which model generated each token.
3933

4034
![example image](./imgs/sps.jpg "console output")
4135

36+
I recommend using llama2-7B and llama2-70B as the approximation and target models respectively. I did observe a speedup in this case, as shown in the following.
37+
Note the choice of approx model and target model are essential for the speedup. The speedup will not be observed in the following cases:
38+
If the models are both small ones, the speedup will not be observed since the speed differences are not significant.
39+
If the model size difference is too large, more rejections and resampling will occur.
40+
Also, the sampling logic is not efficient enough. I noticed substantial overhead in Softmax and Layernorm. I will try to optimize it in the future.
41+
Do not hesitate to open an issue with ideas on performance improvements.
42+
43+
| | llama2-7b | llama2-70b | Speculative |
44+
|--------------|:--------------:|:--------------:|:--------------:|
45+
| speed(tokens/sec) | 1084.86 | 329.83 | 427.02 |
46+
4247
### Serving
4348
Start an inference server.
4449
```bash

benchmark.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
2+
import torch
3+
import argparse
4+
import contexttimer
5+
from colorama import Fore, Style
6+
from transformers import AutoTokenizer, AutoModelForCausalLM
7+
8+
from sampling import autoregressive_sampling, speculative_sampling, speculative_sampling_v2
9+
from globals import Decoder
10+
import json
11+
from tqdm import tqdm
12+
13+
# my local models
MODELZOO = {
    # llama-1
    # https://huggingface.co/PY007/TinyLlama-1.1B-step-50K-105b
    "llama1b": "/share_nfs/fangjiarui/root/code/hf_models/TinyLlama-1.1B-step-50K-105b",
    "llama7b": "/share_nfs/tianzhi/code/llama-7b",
    "llama30b": "/share_nfs/fangjiarui/root/code/hf_models/llama-30b-hf",
    "llama2-7b": "/share_nfs/fangjiarui/root/code/hf_models/llama-2-7b-hf",
    "llama2-70b": "/share_nfs/fangjiarui/root/code/hf_models/llama-2-70b-hf",
    "bloom-560m": "/share_nfs/fangjiarui/root/code/hf_models/bloom-560m",
    "bloom7b": "/share_nfs/fangjiarui/root/code/hf_models/bloomz-7b1",
    "baichuan-7b": "/share_nfs/duanqiyuan/models/source_models/hf/baichuan-7B",
    "baichuan-13b": "/share_nfs/duanqiyuan/models/source_models/hf/Baichuan-13B-Base",
}


def parse_arguments():
    """Build and parse the command-line arguments for this script.

    Returns the argparse.Namespace with input text, model names/paths,
    verbosity/benchmark/profiling flags, token budget, gamma and seed.
    """
    cli = argparse.ArgumentParser(description='args for main.py')

    # prompt and model pair (defaults point at local NFS checkouts)
    cli.add_argument('--input', type=str, default="Suggest at least five related search terms to \"Mạng neural nhân tạo\".")
    cli.add_argument('--approx_model_name', type=str, default=MODELZOO["llama2-7b"])
    cli.add_argument('--target_model_name', type=str, default=MODELZOO["llama2-70b"])
    # runtime switches
    cli.add_argument('--verbose', '-v', action='store_true', default=False, help='enable verbose mode')
    cli.add_argument('--seed', '-s', type=int, default=None, help='set a random seed, which can makes the result reproducible')
    cli.add_argument('--benchmark', '-b', action='store_true', default=False, help='show benchmark results.')
    cli.add_argument('--profiling', '-p', action='store_true', default=False, help='collect torch profiler results.')
    # generation parameters
    cli.add_argument('--max_tokens', '-M', type=int, default=20, help='max token number generated.')
    cli.add_argument('--gamma', '-g', type=int, default=4, help='guess time.')
    return cli.parse_args()
42+
43+
44+
def benchmark(fn, info, *args, **kwargs):
    """Benchmark a sampling function on prompts from the ShareGPT dataset.

    Args:
        fn: sampling callable, invoked as fn(input_ids, *args, **kwargs);
            expected to return the generated token ids with the prompt prefix
            included (matches autoregressive_sampling / speculative_sampling).
        info: label shown in the progress bar and the printed summary.

    Reads up to `test_sample_num` prompts from a hard-coded ShareGPT jsonl
    file and prints the measured generation throughput in tokens/sec.
    """
    test_sample_num = 5
    processed = 0
    with contexttimer.Timer() as t:
        total_tokens = 0
        with open('/share_nfs/fangjiarui/root/code/datasets/share_gpt.jsonl', 'r') as file:
            with tqdm(total=test_sample_num, desc=f"{info} benchmarking") as pbar:
                for line in file:
                    data = json.loads(line)
                    for obj in data:
                        content = obj["content"]
                        input_ids = Decoder().encode(content, return_tensors='pt').to('cuda')
                        # skip prompts that exceed the model context window
                        if len(input_ids[0]) > 2048:
                            continue
                        output_ids = fn(input_ids, *args, **kwargs)
                        # BUGFIX: count generated *tokens*, not characters.
                        # The original did len(generated_text) - len(input_ids),
                        # which subtracts the batch size (1) from the decoded
                        # string's character count.
                        total_tokens += len(output_ids[0]) - len(input_ids[0])
                        processed += 1
                        # BUGFIX: advance the bar per sample (was per file line)
                        pbar.update(1)
                        # BUGFIX: stop at exactly test_sample_num samples (the
                        # original ran one extra) and stop reading further lines
                        if processed >= test_sample_num:
                            break
                    if processed >= test_sample_num:
                        break
    print(f"\n [benchmark] {info} tokens/sec: {total_tokens / t.elapsed}, {t.elapsed} sec generates {total_tokens} tokens")
72+
73+
def generate(input_text, approx_model_name, target_model_name, num_tokens=100, gamma = 4,
             random_seed = None):
    """Load the approx/target model pair and benchmark three sampling strategies.

    Args:
        input_text: kept for interface compatibility; the benchmark draws its
            prompts from the ShareGPT dataset, not from this text.
        approx_model_name: HF name/path of the small (draft) model.
        target_model_name: HF name/path of the large (target) model.
        num_tokens: max number of tokens to generate per prompt.
        gamma: number of draft tokens guessed per speculative step.
        random_seed: seed forwarded to speculative_sampling for reproducibility.
    """
    # NOTE() approx_model_name and target_model_name should use the same tokenizer!
    tokenizer = AutoTokenizer.from_pretrained(approx_model_name, trust_remote_code=True)
    Decoder().set_tokenizer(tokenizer)

    print(f"begin loading models: \n {approx_model_name} \n {target_model_name}")
    small_model = AutoModelForCausalLM.from_pretrained(approx_model_name,
                                                       torch_dtype=torch.float16,
                                                       device_map="auto",
                                                       trust_remote_code=True)
    large_model = AutoModelForCausalLM.from_pretrained(target_model_name,
                                                       torch_dtype=torch.float16,
                                                       device_map="auto",
                                                       trust_remote_code=True)
    print("finish loading models")

    # NOTE: the original also encoded input_text here, but the result was never
    # used (benchmark() encodes its own dataset prompts), so that dead code and
    # the torch_device it relied on were removed.

    top_k = 20
    top_p = 0.9

    # Re-seed before each run so all three strategies draw from the same
    # random stream and are comparable.
    torch.manual_seed(123)
    benchmark(autoregressive_sampling, "AS_large", large_model, num_tokens, top_k = top_k, top_p=top_p)

    torch.manual_seed(123)
    benchmark(autoregressive_sampling, "AS_small", small_model, num_tokens, top_k = top_k, top_p=top_p)

    torch.manual_seed(123)
    benchmark(speculative_sampling, "SP", small_model, large_model, max_len = num_tokens, gamma = gamma, top_k = top_k, top_p=top_p, random_seed = random_seed)
107+
108+
if __name__ == "__main__":
    args = parse_arguments()
    # BUGFIX(consistency with main.py): --seed was parsed but never forwarded,
    # so -s had no effect; pass it through to make runs reproducible.
    generate(args.input, args.approx_model_name, args.target_model_name,
             num_tokens=args.max_tokens, gamma=args.gamma, random_seed=args.seed)

globals.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,8 @@ def __init__(self):
1515
def set_tokenizer(self, tokenizer):
1616
self.tokenizer = tokenizer
1717

18+
def encode(self, s: str, return_tensors='pt') -> torch.Tensor:
19+
return self.tokenizer.encode(s, return_tensors=return_tensors)
20+
1821
def decode(self, t: torch.Tensor) -> str:
1922
return self.tokenizer.decode(t[0], skip_special_tokens=True)

main.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,24 @@
22
import torch
33
import argparse
44
import contexttimer
5-
5+
from colorama import Fore, Style
66
from transformers import AutoTokenizer, AutoModelForCausalLM
7-
from torch.profiler import ProfilerActivity
87

98
from sampling import autoregressive_sampling, speculative_sampling, speculative_sampling_v2
109
from globals import Decoder
1110

11+
12+
13+
1214
# my local models
1315
MODELZOO = {
16+
# llama-1
1417
# https://huggingface.co/PY007/TinyLlama-1.1B-step-50K-105b
1518
"llama1b": "/share_nfs/fangjiarui/root/code/hf_models/TinyLlama-1.1B-step-50K-105b",
1619
"llama7b": "/share_nfs/tianzhi/code/llama-7b",
17-
# https://huggingface.co/huggyllama/llama-13b
18-
"llama13b": None,
20+
"llama30b": "/share_nfs/fangjiarui/root/code/hf_models/llama-30b-hf",
21+
"llama2-7b" : "/share_nfs/fangjiarui/root/code/hf_models/llama-2-7b-hf",
22+
"llama2-70b" : "/share_nfs/fangjiarui/root/code/hf_models/llama-2-70b-hf",
1923
"bloom-560m": "/share_nfs/fangjiarui/root/code/hf_models/bloom-560m",
2024
"bloom7b": "/share_nfs/fangjiarui/root/code/hf_models/bloomz-7b1",
2125
"baichuan-7b": "/share_nfs/duanqiyuan/models/source_models/hf/baichuan-7B",
@@ -25,15 +29,22 @@
2529
def parse_arguments():
2630
parser = argparse.ArgumentParser(description='args for main.py')
2731

28-
parser.add_argument('--input', type=str, default="Suggest at least five related search terms to \"Mạng neural nhân tạo\".")
29-
parser.add_argument('--approx_model_name', type=str, default=MODELZOO["bloom-560m"])
30-
parser.add_argument('--target_model_name', type=str, default=MODELZOO["bloom7b"])
32+
parser.add_argument('--input', type=str, default="Any recommendations for my holidays in Abu Dhabi?")
33+
parser.add_argument('--approx_model_name', type=str, default=MODELZOO["llama2-7b"])
34+
parser.add_argument('--target_model_name', type=str, default=MODELZOO["llama2-70b"])
3135
parser.add_argument('--verbose', '-v', action='store_true', default=False, help='enable verbose mode')
32-
parser.add_argument('--seed', '-s', type=int, default=None, help='set a random seed')
36+
parser.add_argument('--seed', '-s', type=int, default=None, help='set a random seed, which can makes the result reproducible')
37+
parser.add_argument('--benchmark', '-b', action='store_true', default=False, help='show benchmark results.')
38+
parser.add_argument('--profiling', '-p', action='store_true', default=False, help='collect torch profiler results.')
39+
parser.add_argument('--max_tokens', '-M', type=int, default=20, help='max token number generated.')
40+
parser.add_argument('--gamma', '-g', type=int, default=4, help='guess time.')
3341
args = parser.parse_args()
3442
return args
3543

3644

45+
def color_print(text):
46+
print(Fore.RED + text + Style.RESET_ALL)
47+
3748
def benchmark(fn, print_prefix, use_profiler=True, *args, **kwargs):
3849
TEST_TIME = 10
3950
profile_filename = f"./profile_logs/{print_prefix}"
@@ -57,7 +68,8 @@ def benchmark(fn, print_prefix, use_profiler=True, *args, **kwargs):
5768

5869
print(f"\n [benchmark] {print_prefix}, tokens/sec: {len(output[0]) / t.elapsed / TEST_TIME}, {t.elapsed / TEST_TIME} sec generates {len(output[0])} tokens")
5970

60-
def generate(input_text, approx_model_name, target_model_name, num_tokens=40, random_seed = None, verbose = False, use_benchmark = False):
71+
def generate(input_text, approx_model_name, target_model_name, num_tokens=20, gamma = 4,
72+
random_seed = None, verbose = False, use_benchmark = False, use_profiling = False):
6173
# NOTE() approx_model_name and target_model_name should use the same tokenizer!
6274

6375
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -85,37 +97,37 @@ def generate(input_text, approx_model_name, target_model_name, num_tokens=40, ra
8597
torch.manual_seed(123)
8698
output = autoregressive_sampling(input_ids, large_model, num_tokens, top_k = top_k, top_p=top_p)
8799
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
88-
print(f"large (target) model autoregressive_sampling: {generated_text}")
100+
color_print(f"large (target) model autoregressive_sampling: {generated_text}")
89101

90-
TEST_TIME = 10
91102
if use_benchmark:
92-
benchmark(autoregressive_sampling, "AS_large", True,
103+
benchmark(autoregressive_sampling, "AS_large", use_profiling,
93104
input_ids, large_model, num_tokens, top_k = top_k, top_p=top_p)
94105

95106
torch.manual_seed(123)
96107
output = autoregressive_sampling(input_ids, small_model, num_tokens, top_k = top_k, top_p=top_p)
97108
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
98-
print(f"small (approx) model autoregressive_sampling: {generated_text}")
109+
color_print(f"small (approx) model autoregressive_sampling: {generated_text}")
99110

100111
if use_benchmark:
101-
benchmark(autoregressive_sampling, "AS_small", True,
112+
benchmark(autoregressive_sampling, "AS_small", use_profiling,
102113
input_ids, small_model, num_tokens, top_k = top_k, top_p=top_p)
103114

104115
torch.manual_seed(123)
105116
output = speculative_sampling_v2(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed)
106117
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
107-
print(f"deepmind's speculative_sampling: {generated_text}")
118+
color_print(f"deepmind's speculative_sampling: {generated_text}")
108119

109120
torch.manual_seed(123)
110-
output = speculative_sampling(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed, verbose = verbose)
121+
output = speculative_sampling(input_ids, small_model, large_model, num_tokens, gamma = gamma, top_k = top_k, top_p=top_p, random_seed = random_seed, verbose = verbose)
111122
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
112-
print(f"google's speculative_sampling: {generated_text}")
123+
color_print(f"google's speculative_sampling: {generated_text}")
113124

114125
if use_benchmark:
115-
benchmark(speculative_sampling, "SP", True,
116-
input_ids, small_model, large_model, max_len = num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed)
126+
benchmark(speculative_sampling, "SP", use_profiling,
127+
input_ids, small_model, large_model, max_len = num_tokens, gamma = gamma, top_k = top_k, top_p=top_p, random_seed = random_seed)
117128

118129
if __name__ == "__main__":
119130
args = parse_arguments()
120131

121-
generate(args.input, args.approx_model_name, args.target_model_name, random_seed = args.seed, verbose=args.verbose)
132+
generate(args.input, args.approx_model_name, args.target_model_name, num_tokens=args.max_tokens, gamma=args.gamma,
133+
random_seed = args.seed, verbose=args.verbose, use_benchmark = args.benchmark)

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ transformers==4.33.2
22
torch==2.0.1
33
contexttimer
44
flask
5-
transformers_stream_generator
5+
transformers_stream_generator
6+
colorama

sampling/kvcache_model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def _forward_with_kvcache(self, input_ids : torch.Tensor, use_debug = True) -> t
3030
self._prob_history = outputs.logits
3131
for i in range(self._prob_history.shape[-2]):
3232
self._prob_history[:, i, :] = norm_logits(self._prob_history[:, i, :], self._temperature, self._top_k, self._top_p)
33-
# self._prob_history[:, -1, :] = norm_logits(self._prob_history[:, -1, :], self._temperature, self._top_k, self._top_p)
3433
self._past_key_values = outputs.past_key_values
3534
last_q = self._prob_history[:, -1, :]
3635
else:

sampling/speculative_sampling.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module,
4141
approx_model_cache = KVCacheModel(approx_model, temperature, top_k, top_p)
4242
target_model_cache = KVCacheModel(target_model, temperature, top_k, top_p)
4343

44+
resample_count = 0
45+
target_sample_count = 0
46+
accepted_count = 0
47+
4448
while prefix.shape[1] < T:
4549
# q = M_q[prefix + x_0, x_1, .., x_(gamma-2)]
4650
prefix_len = prefix.shape[1]
@@ -64,6 +68,8 @@ def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module,
6468

6569
if verbose:
6670
print(f"approx guess accepted {j[0]}: \033[31m{Decoder().decode(torch.tensor([j]))}\033[0m")
71+
72+
accepted_count += 1
6773

6874
# print(f"n : {n}, i : {i}, prefix_len + gamma - 1: {prefix_len + gamma - 1}")
6975
assert n >= prefix_len - 1, f"n {n}, prefix_len {prefix_len}"
@@ -78,20 +84,22 @@ def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module,
7884
t = sample(max_fn(target_model_cache._prob_history[:, n, :] - approx_model_cache._prob_history[:, n, :]))
7985
if verbose:
8086
print(f"target resamples at position {n}: \033[34m{Decoder().decode(t)}\033[0m")
81-
87+
resample_count += 1
8288
target_model_cache.rollback(n+1)
8389
else:
8490
# all approx model decoding accepted
8591
assert n == target_model_cache._prob_history.shape[1] - 1
8692
t = sample(target_model_cache._prob_history[:, -1, :])
8793
if verbose:
8894
print(f"target samples {n}: \033[35m{Decoder().decode(t)}\033[0m")
95+
target_sample_count += 1
8996
target_model_cache.rollback(n+2)
9097

9198

9299
prefix = torch.cat((prefix, t), dim=1)
93100

94-
101+
if verbose:
102+
print(f"generated tokens numbers {prefix.shape[-1] - seq_len}, accepted_count {accepted_count}, target_sample_count {target_sample_count}, resample_count {resample_count}")
95103
return prefix
96104

97105

serving.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,6 @@ def predict():
5555
return jsonify(result)
5656

5757
if __name__ == '__main__':
58-
# Load the model
59-
# load_model("/share_nfs/fangjiarui/root/code/hf_models/bloom-560m")
60-
6158
GLOBAL_SERVER = Server(approx_model_name="/share_nfs/fangjiarui/root/code/hf_models/bloom-560m",
6259
target_model_name="/share_nfs/fangjiarui/root/code/hf_models/bloomz-7b1")
6360
# Start the Flask service

0 commit comments

Comments
 (0)