Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions models/llama4/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
import sys
import time
from packaging import version
from pathlib import Path
from typing import Callable, Generator, List, Optional

Expand All @@ -33,6 +34,12 @@
torch.serialization.add_safe_globals([io.BytesIO, codecs.encode])


def is_xccl_available():
    """Return True when the XCCL distributed backend can be probed and is usable.

    XCCL (the Intel XPU collective backend) only exists in PyTorch 2.7+;
    on older releases the probe function is absent, so report False
    without touching it.
    """
    torch_release = version.parse(torch.__version__).release
    minimum_release = version.parse("2.7").release
    if torch_release < minimum_release:
        return False
    return torch.distributed.distributed_c10d.is_xccl_available()


class Llama4:
@staticmethod
def build(
Expand All @@ -42,17 +49,35 @@ def build(
world_size: Optional[int] = None,
quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1,
device: str = "cuda",
):
device = torch.device(device)
if (
device.type == "cuda"
and not torch.cuda.is_available()
or device.type == "xpu"
and not torch.xpu.is_available()
):
raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")

if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if device.type == "cuda":
torch.distributed.init_process_group("nccl")
elif device.type == "xpu" and is_xccl_available():
torch.distributed.init_process_group("xccl")
else:
torch.distributed.init_process_group("gloo")

if not model_parallel_is_initialized():
if world_size is None:
world_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(world_size)

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
if device.type == "cuda":
torch.cuda.set_device(local_rank)
elif device.type == "xpu":
torch.xpu.set_device(local_rank)

torch.manual_seed(seed)

Expand Down Expand Up @@ -96,15 +121,24 @@ def build(
print("Done...")
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
print(f"Setting default device to {device}")
torch.set_default_device(device)
if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)

model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
print("Done...")
model.to(device)
print(f"Loaded in {time.time() - start_time:.2f} seconds")

return Llama4(model, tokenizer, model_args)
Expand Down Expand Up @@ -152,13 +186,13 @@ def generate(
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)

pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

eos_reached = torch.tensor([False] * bsz, device="cuda")
eos_reached = torch.tensor([False] * bsz)
input_text_mask = tokens != pad_id

if echo:
Expand All @@ -178,7 +212,7 @@ def generate(
)
yield results

stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)

prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
Expand Down
4 changes: 2 additions & 2 deletions models/llama4/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,15 @@ def __init__(
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
)
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
)
self.norm_eps = args.norm_eps
self._register_load_state_dict_pre_hook(self.load_hook)

Expand Down
14 changes: 14 additions & 0 deletions models/llama4/scripts/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,22 @@
from models.datatypes import RawMediaItem, RawMessage, RawTextItem, StopReason
from models.llama4.generation import Llama4

import os
import torch

THIS_DIR = Path(__file__).parent


def get_device():
    """Pick the compute device string for inference.

    Resolution order:
      1. explicit override via the ``DEVICE`` environment variable,
      2. ``"cuda"`` when a CUDA device is available,
      3. ``"xpu"`` when this torch build ships the XPU backend and a
         device is present,
      4. ``"cpu"`` as the fallback.

    Returns:
        str: a device string accepted by ``torch.device`` (e.g. ``"cuda"``).
    """
    if "DEVICE" in os.environ:
        return os.environ["DEVICE"]
    if torch.cuda.is_available():
        return "cuda"
    # Guard with hasattr: older torch builds do not ship the torch.xpu
    # module at all, and probing it directly raises AttributeError.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"


def run_main(
checkpoint_dir: str,
world_size: int = 1,
Expand All @@ -36,6 +49,7 @@ def run_main(
max_batch_size=max_batch_size,
world_size=world_size,
quantization_mode=quantization_mode,
device=get_device(),
)

dialogs = [
Expand Down
14 changes: 14 additions & 0 deletions models/llama4/scripts/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,22 @@
from models.datatypes import RawMediaItem
from models.llama4.generation import Llama4

import os
import torch

THIS_DIR = Path(__file__).parent


def get_device():
    """Pick the compute device string for inference.

    Resolution order:
      1. explicit override via the ``DEVICE`` environment variable,
      2. ``"cuda"`` when a CUDA device is available,
      3. ``"xpu"`` when this torch build ships the XPU backend and a
         device is present,
      4. ``"cpu"`` as the fallback.

    Returns:
        str: a device string accepted by ``torch.device`` (e.g. ``"cuda"``).
    """
    if "DEVICE" in os.environ:
        return os.environ["DEVICE"]
    if torch.cuda.is_available():
        return "cuda"
    # Guard with hasattr: older torch builds do not ship the torch.xpu
    # module at all, and probing it directly raises AttributeError.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"


def run_main(
checkpoint_dir: str,
world_size: int = 1,
Expand All @@ -36,6 +49,7 @@ def run_main(
max_batch_size=max_batch_size,
world_size=world_size,
quantization_mode=quantization_mode,
device=get_device(),
)

with open(THIS_DIR / "../../resources/dog.jpg", "rb") as f:
Expand Down