Commit 349b128

fix global vars bug
1 parent b8c7d83 commit 349b128

3 files changed: +28 -21 lines changed

globals.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+import torch
+
+class Singleton(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super().__call__(*args, **kwargs)
+        return cls._instances[cls]
+
+class Decoder(metaclass=Singleton):
+    def __init__(self):
+        self.tokenizer = None
+
+    def set_tokenizer(self, tokenizer):
+        self.tokenizer = tokenizer
+
+    def decode(self, t: torch.Tensor) -> str:
+        return self.tokenizer.decode(t[0], skip_special_tokens=True)
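
A minimal usage sketch (not part of the commit) of the Singleton-backed Decoder above; FakeTokenizer is a hypothetical stand-in for a Hugging Face tokenizer, used only to show that a tokenizer set through one Decoder() call is visible through every other call:

import torch

from globals import Decoder

class FakeTokenizer:
    # hypothetical stand-in; the real code passes an AutoTokenizer instance
    def decode(self, ids, skip_special_tokens=True):
        return " ".join(str(int(i)) for i in ids)

# The Singleton metaclass caches the instance, so every call returns the same object.
assert Decoder() is Decoder()

Decoder().set_tokenizer(FakeTokenizer())
# Any later call, e.g. inside sampling/speculative_sampling.py, reuses that tokenizer.
print(Decoder().decode(torch.tensor([[1, 2, 3]])))  # -> "1 2 3"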

main.py

Lines changed: 5 additions & 15 deletions
@@ -6,16 +6,7 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 from sampling import autoregressive_sampling, speculative_sampling, speculative_sampling_v2
-
-class Decoder:
-    def __init__(self, tokenizer) -> None:
-        self.tokenizer = tokenizer
-
-    def decode(self, t : torch.Tensor) -> str:
-        # assert t.dim == 2, "t must be 2d tensor"
-        return self.tokenizer.decode(t[0], skip_special_tokens=True)
-
-DECODER : Decoder = None
+from globals import Decoder
 
 # my local models
 MODELZOO = {
@@ -34,8 +25,8 @@ def parse_arguments():
     parser = argparse.ArgumentParser(description='args for main.py')
 
     parser.add_argument('--input', type=str, default="Suggest at least five related search terms to \"Mạng neural nhân tạo\".")
-    parser.add_argument('--approx_model_name', type=str, default=MODELZOO["llama1b"])
-    parser.add_argument('--target_model_name', type=str, default=MODELZOO["llama7b"])
+    parser.add_argument('--approx_model_name', type=str, default=MODELZOO["bloom-560m"])
+    parser.add_argument('--target_model_name', type=str, default=MODELZOO["bloom7b"])
     parser.add_argument('--verbose', '-v', action='store_true', default=False, help='enable verbose mode')
     parser.add_argument('--seed', '-s', type=int, default=None, help='set a random seed')
     args = parser.parse_args()
@@ -48,9 +39,8 @@ def generate(input_text, approx_model_name, target_model_name, num_tokens=40, ra
     torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
     tokenizer = AutoTokenizer.from_pretrained(approx_model_name, trust_remote_code=True)
-
-    global DECODER
-    DECODER = Decoder(tokenizer)
+
+    Decoder().set_tokenizer(tokenizer)
 
     print(f"begin loading models: \n {approx_model_name} \n {target_model_name}")
     small_model = AutoModelForCausalLM.from_pretrained(approx_model_name, trust_remote_code=True).to(torch_device)

sampling/speculative_sampling.py

Lines changed: 4 additions & 6 deletions
@@ -4,6 +4,7 @@
 
 from sampling.kvcache_model import KVCacheModel
 from sampling.utils import norm_logits, sample, max_fn
+from globals import Decoder
 
 @torch.no_grad()
 def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module, target_model : torch.nn.Module,
@@ -37,9 +38,6 @@ def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module,
 
     device = target_model.device
 
-    if verbose:
-        global DECODER
-
     approx_model_cache = KVCacheModel(approx_model, temperature, top_k, top_p, random_seed)
     target_model_cache = KVCacheModel(target_model, temperature, top_k, top_p, random_seed)
 
@@ -62,7 +60,7 @@ def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module,
                 break
 
             if verbose:
-                print(f"approx guess accepted {j[0]}: \033[31m{DECODER.decode(torch.tensor([j]))}\033[0m")
+                print(f"approx guess accepted {j[0]}: \033[31m{Decoder().decode(torch.tensor([j]))}\033[0m")
 
         # print(f"n : {n}, i : {i}, prefix_len + gamma - 1: {prefix_len + gamma - 1}")
         assert n >= prefix_len - 1, f"n {n}, prefix_len {prefix_len}"
@@ -76,15 +74,15 @@ def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module,
             # reject someone, sample from the pos n
             t = sample(max_fn(target_model_cache._prob_history[:, n, :] - approx_model_cache._prob_history[:, n, :]), random_seed=random_seed)
             if verbose:
-                print(f"target resamples at position {n}: \033[34m{DECODER.decode(t)}\033[0m")
+                print(f"target resamples at position {n}: \033[34m{Decoder().decode(t)}\033[0m")
 
             target_model_cache.rollback(n+1)
         else:
             # all approx model decoding accepted
             assert n == target_model_cache._prob_history.shape[1] - 1
             t = sample(target_model_cache._prob_history[:, -1, :], random_seed=random_seed)
             if verbose:
-                print(f"target samples {n}: \033[35m{DECODER.decode(t)}\033[0m")
+                print(f"target samples {n}: \033[35m{Decoder().decode(t)}\033[0m")
             target_model_cache.rollback(n+2)
 
 
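
For context on the commit message: before this change, speculative_sampling declared "global DECODER", but Python globals are per-module, so that name resolved in sampling/speculative_sampling.py's own namespace rather than in main.py, where DECODER was actually assigned; verbose mode would therefore hit a NameError (presumably the "global vars bug" the title refers to). The singleton Decoder avoids this because every module imports and calls the same cached instance. A standalone sketch of the pitfall, with hypothetical module names built in memory purely for illustration:

import types

# Mimic the old layout: one module defines a function that says "global DECODER",
# while another module (like main.py) assigns a DECODER of its own.
fake_sampling = types.ModuleType("fake_sampling")
exec(
    "def decode_verbose():\n"
    "    global DECODER              # resolves in fake_sampling's globals\n"
    "    return DECODER\n",
    fake_sampling.__dict__,
)

DECODER = "set in the importing module, as main.py used to do"

try:
    fake_sampling.decode_verbose()   # never sees this module's DECODER
except NameError as err:
    print(err)                       # name 'DECODER' is not defined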
