diff --git a/models/llama3/model.py b/models/llama3/model.py
index 57bb33799..c7601c63a 100644
--- a/models/llama3/model.py
+++ b/models/llama3/model.py
@@ -144,7 +144,7 @@ def __init__(self, args: ModelArgs):
             init_method=lambda x: x,
         )
 
-        self.cache_k = torch.zeros(
+        cache_k = torch.zeros(
             (
                 args.max_batch_size,
                 args.max_seq_len,
@@ -152,7 +152,7 @@ def __init__(self, args: ModelArgs):
                 self.head_dim,
             )
         )
-        self.cache_v = torch.zeros(
+        cache_v = torch.zeros(
             (
                 args.max_batch_size,
                 args.max_seq_len,
@@ -160,6 +160,8 @@ def __init__(self, args: ModelArgs):
                 self.head_dim,
             )
         )
+        self.register_buffer("cache_k", cache_k, persistent=False)
+        self.register_buffer("cache_v", cache_v, persistent=False)
 
     def forward(
         self,
diff --git a/models/llama4/model.py b/models/llama4/model.py
index 75f281d5c..306365acc 100644
--- a/models/llama4/model.py
+++ b/models/llama4/model.py
@@ -154,7 +154,7 @@ def __init__(
             init_method=lambda x: x,
         )
 
-        self.cache_k = torch.zeros(
+        cache_k = torch.zeros(
             (
                 args.max_batch_size,
                 args.max_seq_len,
@@ -162,7 +162,7 @@ def __init__(
                 self.head_dim,
             )
         ).cuda()
-        self.cache_v = torch.zeros(
+        cache_v = torch.zeros(
             (
                 args.max_batch_size,
                 args.max_seq_len,
@@ -170,6 +170,8 @@ def __init__(
                 self.head_dim,
             )
         ).cuda()
+        self.register_buffer("cache_k", cache_k, persistent=False)
+        self.register_buffer("cache_v", cache_v, persistent=False)
 
         self.norm_eps = args.norm_eps
         self._register_load_state_dict_pre_hook(self.load_hook)
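
For context, a minimal standalone sketch (not part of the diff) of the behavior this change appears to rely on: a tensor registered via register_buffer with persistent=False follows module-level device/dtype moves such as .to() or .cuda(), but is excluded from state_dict(), whereas a plain tensor attribute does neither. Class names below are illustrative only; the only assumption is PyTorch >= 1.6, where register_buffer accepts the persistent argument.

    import torch
    import torch.nn as nn


    class PlainCache(nn.Module):
        def __init__(self):
            super().__init__()
            # Plain attribute: ignored by .to()/.cuda() and by state_dict().
            self.cache_k = torch.zeros(2, 4)


    class BufferCache(nn.Module):
        def __init__(self):
            super().__init__()
            cache_k = torch.zeros(2, 4)
            # Non-persistent buffer: moved by .to()/.cuda(), excluded from state_dict().
            self.register_buffer("cache_k", cache_k, persistent=False)


    plain, buffered = PlainCache(), BufferCache()
    plain.to(torch.float16)
    buffered.to(torch.float16)
    print(plain.cache_k.dtype)                 # torch.float32 -- attribute did not follow the move
    print(buffered.cache_k.dtype)              # torch.float16 -- buffer followed the module
    print("cache_k" in buffered.state_dict())  # False -- not serialized into checkpoints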