44
55from sampling .kvcache_model import KVCacheModel
66from sampling .utils import norm_logits , sample , max_fn
7+ from globals import Decoder
78
89@torch .no_grad ()
910def speculative_sampling (prefix : torch .Tensor , approx_model : torch .nn .Module , target_model : torch .nn .Module ,
@@ -37,9 +38,6 @@ def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module,
3738
3839 device = target_model .device
3940
40- if verbose :
41- global DECODER
42-
4341 approx_model_cache = KVCacheModel (approx_model , temperature , top_k , top_p , random_seed )
4442 target_model_cache = KVCacheModel (target_model , temperature , top_k , top_p , random_seed )
4543
@@ -62,7 +60,7 @@ def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module,
6260 break
6361
6462 if verbose :
65- print (f"approx guess accepted { j [0 ]} : \033 [31m{ DECODER .decode (torch .tensor ([j ]))} \033 [0m" )
63+ print (f"approx guess accepted { j [0 ]} : \033 [31m{ Decoder () .decode (torch .tensor ([j ]))} \033 [0m" )
6664
6765 # print(f"n : {n}, i : {i}, prefix_len + gamma - 1: {prefix_len + gamma - 1}")
6866 assert n >= prefix_len - 1 , f"n { n } , prefix_len { prefix_len } "
@@ -76,15 +74,15 @@ def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module,
7674 # reject someone, sample from the pos n
7775 t = sample (max_fn (target_model_cache ._prob_history [:, n , :] - approx_model_cache ._prob_history [:, n , :]), random_seed = random_seed )
7876 if verbose :
79- print (f"target resamples at position { n } : \033 [34m{ DECODER .decode (t )} \033 [0m" )
77+ print (f"target resamples at position { n } : \033 [34m{ Decoder () .decode (t )} \033 [0m" )
8078
8179 target_model_cache .rollback (n + 1 )
8280 else :
8381 # all approx model decoding accepted
8482 assert n == target_model_cache ._prob_history .shape [1 ] - 1
8583 t = sample (target_model_cache ._prob_history [:, - 1 , :], random_seed = random_seed )
8684 if verbose :
87- print (f"target samples { n } : \033 [35m{ DECODER .decode (t )} \033 [0m" )
85+ print (f"target samples { n } : \033 [35m{ Decoder () .decode (t )} \033 [0m" )
8886 target_model_cache .rollback (n + 2 )
8987
9088
0 commit comments