Commits (40)
fecfcb6  Fix (Giuseppe5, Jan 27, 2026)
195443c  Feat (vLLM): initial export support (Giuseppe5, Jan 24, 2026)
df68ed8  Cleanup (Giuseppe5, Jan 24, 2026)
19aa9c9  More cleanup (Giuseppe5, Jan 24, 2026)
aac450d  More bugfix, cleanup (Giuseppe5, Jan 24, 2026)
fb46fe6  More cleanup and fixes (Giuseppe5, Jan 24, 2026)
1244425  Removed too much stuff (Giuseppe5, Jan 24, 2026)
69b1d49  temp (Giuseppe5, Jan 26, 2026)
6f544c6  Temp 2 (Giuseppe5, Jan 26, 2026)
ed6b8f1  cleanup (Giuseppe5, Jan 27, 2026)
7225614  requirements (Giuseppe5, Jan 27, 2026)
2e94286  import (Giuseppe5, Jan 27, 2026)
0a0c062  import 2 (Giuseppe5, Jan 27, 2026)
fd5edcc  Fix init (Giuseppe5, Jan 27, 2026)
67be3f8  fix init 2 (Giuseppe5, Jan 27, 2026)
b9ae23a  Fix proxies (Giuseppe5, Jan 28, 2026)
399363e  Update quantize.py (Giuseppe5, Jan 28, 2026)
3a7ed83  Update main.py (Giuseppe5, Jan 28, 2026)
c8716a7  sync (Giuseppe5, Jan 28, 2026)
79cc073  Fix (Giuseppe5, Jan 29, 2026)
16d9e57  Fix (Giuseppe5, Feb 3, 2026)
07910d6  fixes (Giuseppe5, Feb 3, 2026)
30977f4  fix (Giuseppe5, Feb 4, 2026)
579101b  small fixes (Giuseppe5, Feb 4, 2026)
dbe37f0  item not needed (Giuseppe5, Feb 4, 2026)
709a59c  precommit (Giuseppe5, Feb 4, 2026)
f775ce3  Update (Giuseppe5, Feb 4, 2026)
7de7488  fix (Giuseppe5, Feb 4, 2026)
76cf2f4  temp (Giuseppe5, Feb 17, 2026)
4e93f6f  Simplified interface (Giuseppe5, Feb 18, 2026)
f04594b  refactor function (Giuseppe5, Feb 20, 2026)
05078e7  review (Giuseppe5, Feb 20, 2026)
71636c6  small fix (Giuseppe5, Feb 20, 2026)
1e577f2  missing layer (Giuseppe5, Feb 20, 2026)
40e87d3  remove useless code (Giuseppe5, Feb 20, 2026)
8488348  review (Giuseppe5, Feb 20, 2026)
2ad0206  fix tests (Giuseppe5, Feb 21, 2026)
54a403d  precommit (Giuseppe5, Feb 21, 2026)
ac313c3  fix import (Giuseppe5, Feb 21, 2026)
fba05cf  typing is hard (Giuseppe5, Feb 21, 2026)
1 change: 1 addition & 0 deletions requirements/requirements-llm.txt
@@ -10,3 +10,4 @@ pydantic
torch>=2.4
tqdm
transformers[sentencepiece]<5.0
vllm
Collaborator: I feel like vLLM should be an optional dependency.

Collaborator (author): Maybe we can do it in a similar way to what we did for lighteval/lm_eval.

Collaborator (author, @Giuseppe5, Jan 29, 2026): I'm leaving it for now so that tests run and I can see what other things I'm breaking in the process, but I'll remove it before this PR is merged.

Collaborator: I'm fine with doing it similarly as for lighteval/lm_eval.

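A guarded import along the lines of what is done for lighteval/lm_eval would keep vllm optional. A minimal sketch, assuming the usual try/except pattern; the flag name VLLM_AVAILABLE and the helper below are illustrative, not part of this PR:

try:
    from vllm.model_executor.layers.linear import LinearMethodBase
    VLLM_AVAILABLE = True  # illustrative flag name
except ImportError:
    LinearMethodBase = object  # keep the module importable without vllm
    VLLM_AVAILABLE = False


def _require_vllm():
    # Illustrative helper: fail clearly only when the vLLM export path is actually used.
    if not VLLM_AVAILABLE:
        raise ImportError("vLLM export requires the optional dependency vllm (pip install vllm).")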
1 change: 1 addition & 0 deletions src/brevitas/config.py
@@ -16,6 +16,7 @@ def env_to_bool(name, default):

REINIT_ON_STATE_DICT_LOAD = env_to_bool('BREVITAS_REINIT_ON_STATE_DICT_LOAD', True)
IGNORE_MISSING_KEYS = env_to_bool('BREVITAS_IGNORE_MISSING_KEYS', False)
IGNORE_EXPORT_KEYS = env_to_bool('BREVITAS_IGNORE_EXPORT_KEYS', True)
# JIT_ENABLED triggers NATIVE_STE_BACKEND_ENABLED to True, but not the other way around
JIT_ENABLED = env_to_bool('BREVITAS_JIT', False) and _enabled
NATIVE_STE_BACKEND_ENABLED = env_to_bool('BREVITAS_NATIVE_STE_BACKEND', False)
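The new flag follows the same env_to_bool convention as the flags around it, and the vLLM export manager later in this diff also flips it programmatically. A minimal sketch of both ways to toggle it (the exact strings accepted by env_to_bool are not shown in this hunk):

import os
os.environ['BREVITAS_IGNORE_EXPORT_KEYS'] = '0'  # must be set before brevitas is first imported

import brevitas.config as config
config.IGNORE_EXPORT_KEYS = False  # programmatic toggle, as done in vLLMExportManager.export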
338 changes: 240 additions & 98 deletions src/brevitas/export/inference/handler.py

Large diffs are not rendered by default.

8 changes: 3 additions & 5 deletions src/brevitas/export/inference/manager.py
@@ -3,12 +3,10 @@

from functools import partial

from packaging import version
import torch
from torch.nn import Module
import torch.nn as nn

from brevitas import torch_version
from brevitas.export.inference.handler import DynamicFloatInferenceHandler
from brevitas.export.inference.handler import DynamicIntInferenceHandler
from brevitas.export.inference.handler import FloatInferencetHandler
@@ -85,7 +83,7 @@ def __exit__(self, type, value, traceback):
# Disable all caching
# deactivate export mode
# restore return quant tensor
InferenceManager.set_export_mode(self.model, enabled=False)
self.export_manager.set_export_mode(self.model, enabled=False)
self.model.apply(
lambda m: _override_bias_caching_mode(m, enabled=False, metadata_only=False))
self.model.apply(
@@ -105,8 +103,8 @@ def hook(self, module, inp, out):
# - Disable return quant tensor since all quant metadata is cached
assert len(self.hook_list) == 1
self.hook_list[0].remove()
self.model.apply(InferenceManager.set_export_handler)
InferenceManager.set_export_mode(self.model, enabled=True)
self.model.apply(self.export_manager.set_export_handler)
self.export_manager.set_export_mode(self.model, enabled=True)
self.return_quant_tensor_state = QuantizationStatusManager.disable_return_quant_tensor(
self.model)
disable_quant_tensor = partial(_override_create_quant_tensor, state=True)
Empty file.
172 changes: 172 additions & 0 deletions src/brevitas/export/inference/vLLM/handler.py
@@ -0,0 +1,172 @@
from typing import List
from typing import Optional

import torch
from vllm.model_executor.layers.linear import LinearMethodBase

from brevitas.graph.hadamard import get_hadK
from brevitas.nn.equalized_layer import RotatedModule

from ..handler import FloatInferencetHandler
from ..handler import FloatWeightInferencetHandler
from ..handler import GroupwiseFloatInferenceHandler
from ..handler import GroupwiseFloatWeightInferenceHandler
from ..handler import IntInferencetHandler
from ..handler import IntWeightInferencetHandler

class_mapping = {
'GroupwiseFloatInferenceHandler': GroupwiseFloatInferenceHandler,
'GroupwiseFloatWeightInferenceHandler': GroupwiseFloatWeightInferenceHandler,
'FloatInferencetHandler': FloatInferencetHandler,
'FloatWeightInferencetHandler': FloatWeightInferencetHandler,
'IntWeightInferencetHandler': IntWeightInferencetHandler,
'IntInferencetHandler': IntInferencetHandler,}


class QuantLinear(LinearMethodBase):

def __init__(
self,
input_config=None,
weight_config=None,
bias_config=None,
output_config=None,
rotation_config=None):
self.input_quant = self.configure_proxy(input_config)
if isinstance(weight_config, list):
self.weight_quant = dict()
for i, config in enumerate(weight_config):
self.weight_quant[i] = self.configure_proxy(config)
else:
self.weight_quant = self.configure_proxy(weight_config)
self.bias_quant = self.configure_proxy(bias_config)
self.output_quant = self.configure_proxy(output_config)
self.rotation = self.configure_rotation(rotation_config)

def configure_rotation(self, rotation_config):
if rotation_config is None:
return torch.nn.Identity()
rot_mat_shape = rotation_config['rotation_size']['rot_mat_shape']
k = rotation_config['rotation_size']['k']
had_mat, _ = get_hadK(rot_mat_shape)
return RotatedModule(self, had_mat, k)

def configure_proxy(self, quant_config):
# No config, no quantizer
if quant_config is None:
return torch.nn.Identity()

# Extract elements that are not part of the state dict
quant_class_name = quant_config['class_type']
float_to_int_impl_type = quant_config['float_to_int_impl_type']
del quant_config['class_type']
del quant_config['float_to_int_impl_type']

# Scale and zero-point are the only float elements in the state dict
for k, v in quant_config.items():
if not isinstance(v, torch.Tensor):
if k == 'scale' or k == 'zero_point':
quant_config[k] = torch.tensor(v)
else:
quant_config[k] = torch.tensor(v, dtype=torch.int)

# Shapes must be set, otherwise state dict loading will fail
scale_shape = quant_config['scale'].shape
zero_point_shape = quant_config['zero_point'].shape
quant_class_type = class_mapping[quant_class_name]
quant_class = quant_class_type(scale_shape, zero_point_shape)

# Set the remaining attributes
quant_class.float_to_int_impl_type = float_to_int_impl_type
quant_class.load_state_dict(quant_config)
return quant_class

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
out_per_partition = sum(output_partition_sizes)
w = torch.empty(
(out_per_partition, input_size_per_partition),
device="cuda",
dtype=params_dtype,
)

layer.weight = torch.nn.Parameter(w, requires_grad=False)

# Handling the packed weights for loading
base_loader = extra_weight_attrs.get("weight_loader", None)

def packed_weight_loader(param, loaded_weight, loaded_shard_id=None, *args, **kwargs):

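# vLLM's fused layers call the loader once per shard: QKVParallelLinear passes
# loaded_shard_id as "q"/"k"/"v", while MergedColumnParallelLinear passes an integer
# index; both are mapped below onto an offset into output_partition_sizes.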
if loaded_shard_id is not None:
if isinstance(loaded_shard_id, int):
_loaded_shard_id = loaded_shard_id
else:
if loaded_shard_id == "q":
_loaded_shard_id = 0
elif loaded_shard_id == "k":
_loaded_shard_id = 1
elif loaded_shard_id == "v":
_loaded_shard_id = 2
else:
raise ValueError(f"Invalid loaded_shard_id: {loaded_shard_id}")

logical_widths = list(output_partition_sizes)
start_idx = sum(logical_widths[:_loaded_shard_id])
end_idx = start_idx + logical_widths[_loaded_shard_id]
weight_quant = self.weight_quant[_loaded_shard_id]
else:
start_idx = 0
end_idx = out_per_partition
weight_quant = self.weight_quant
if weight_quant is not None:
loaded_weight = weight_quant(loaded_weight.cuda())[0].cpu()

if base_loader is not None:
return base_loader(param[start_idx:end_idx], loaded_weight, *args, **kwargs)
param[start_idx:end_idx].data.copy_(loaded_weight)

setattr(layer.weight, "weight_loader", packed_weight_loader)

# If this layer has bias, allocate it
if getattr(layer, "bias", None) is not None:
b = torch.empty((out_per_partition,), device="cuda", dtype=params_dtype)
layer.bias = torch.nn.Parameter(b, requires_grad=False)
base_bias_loader = extra_weight_attrs.get("bias_loader", None)

def packed_bias_loader(param, loaded_bias, *args, **kwargs):
if isinstance(loaded_bias, (list, tuple)):
loaded_bias = torch.cat(list(loaded_bias), dim=0)
if base_bias_loader is not None:
return base_bias_loader(param, loaded_bias, *args, **kwargs)
param.data.copy_(loaded_bias)

setattr(layer.bias, "bias_loader", packed_bias_loader)

# Preserve attrs that vLLM weight loaders may attach
for k, v in extra_weight_attrs.items():
if k in ("weight_loader", "bias_loader"):
continue
setattr(layer.weight, k, v)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# x = self.rotation.rotation_forward(x)
x = self.input_quant(x)
bias = self.bias_quant(bias) if bias is not None else None
y = x.matmul(layer.weight.t())
if bias is not None:
y = y + bias
y = self.output_quant(y)
return y
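To make the loading path in configure_proxy easier to follow, this is roughly the shape of one quantizer entry in brevitas_config.json as written by vLLMExportManager.export further down: the handler state_dict plus the two extra keys that are stripped off before load_state_dict. The keys and values below are illustrative; the real state_dict contents depend on the exported handler class.

# Illustrative example only, values are made up.
example_weight_quant = {
    'class_type': 'IntWeightInferencetHandler',  # looked up in class_mapping above
    'float_to_int_impl_type': 'round',           # set back as an attribute, not loaded
    'scale': [[0.0123], [0.0456]],               # converted to torch.tensor (float)
    'zero_point': [[0.0], [0.0]],                # converted to torch.tensor (float)
}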
169 changes: 169 additions & 0 deletions src/brevitas/export/inference/vLLM/manager.py
@@ -0,0 +1,169 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from dataclasses import dataclass
from functools import partial
import json
import os
from typing import Any
from typing import List
from typing import Optional

import torch
from torch.nn import Module
import torch.nn as nn
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped

import brevitas.config as config
from brevitas.export.inference.vLLM.handler import QuantLinear
from brevitas.nn.equalized_layer import EqualizedModule
from brevitas.nn.equalized_layer import RotatedModule
from brevitas.nn.mixin import QuantLayerMixin
from brevitas.proxy.quant_proxy import QuantProxyFromInjector

from ..manager import _override_act_caching_mode
from ..manager import _override_bias_caching_mode
from ..manager import _override_create_quant_tensor
from ..manager import _override_weight_caching_mode


@register_quantization_config("quant_brevitas")
@dataclass
class QuantConfigBrevitas(QuantizationConfig):

def __init__(self, ignored_layers: list[str] | None = None, config: str | None = None):
Collaborator: Is config: str | None = None the correct typing here?

Collaborator (author): The actual typing is much more complicated than Dict, not sure if we should be super explicit.

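# Note on the typing discussion above: from_config passes in the parsed
# brevitas_config.json, i.e. a nested dict (layer name -> quantizer name -> exported
# state, plus an optional rotation_config), so a fully explicit annotation would be
# roughly dict[str, dict[str, Any]] | None rather than str | None.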
super().__init__()
self.ignored_layers = ignored_layers
self.config = config

@classmethod
def from_config(cls, config: dict[str, Any]) -> "QuantConfigBrevitas":
return cls(config=config)

@classmethod
def get_min_capability(cls) -> int:
# Minimum GPU compute capability needed for the kernel.
return 0

@classmethod
def get_name(cls) -> QuantizationMethods:
return "quant_brevitas"

@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16, torch.float32]

@staticmethod
def get_config_filenames() -> list[str]:
return ["brevitas_config.json"]

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, (RowParallelLinear, MergedColumnParallelLinear, QKVParallelLinear)):
if self.ignored_layers and is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedLinearMethod()
else:
if prefix in self.config:
base_config = self.config[prefix]
input_config = base_config.get('input_quant', None)
bias_config = base_config.get('bias_quant', None)
output_config = base_config.get('output_quant', None)
weight_config = base_config.get('weight_quant', None)
else:
base = prefix.split('.')[:-1]
base = '.'.join(base)
suffix = prefix.split('.')[-1]
layers_to_merge = self.packed_modules_mapping[suffix]
layers_to_merge = [base + '.' + x for x in layers_to_merge]

base_config = self.config[layers_to_merge[0]]
input_config = base_config.get('input_quant', None)
bias_config = base_config.get('bias_quant', None)
output_config = base_config.get('output_quant', None)
weight_config = [
self.config[layer].get('weight_quant', None) for layer in layers_to_merge]
# base_config = combine_configs(self.config, *layers_to_merge)

return QuantLinear(
input_config=input_config,
bias_config=bias_config,
output_config=output_config,
weight_config=weight_config)

elif isinstance(layer, LinearBase):
return UnquantizedLinearMethod()

return None


def combine_configs(config, *names):
base_config = config[names[0]]
scale = None #base_config['scale']
for n in names:
if scale is None:
scale = torch.tensor(config[n]['weight_quant']['scale'])
else:
v = torch.tensor(config[n]['weight_quant']['scale'])
scale = torch.cat((scale, v), 0)
base_config['weight_quant']['scale'] = scale
return base_config


from json import JSONEncoder

from torch.utils.data import Dataset


class EncodeTensor(JSONEncoder, Dataset):

def default(self, obj):
if isinstance(obj, torch.Tensor):
if obj.dtype == torch.bfloat16:
obj = obj.to(torch.float32)
return obj.cpu().detach().numpy().tolist()
return super(EncodeTensor, self).default(obj)


class vLLMExportManager():

wrap_layers = (EqualizedModule, RotatedModule)

def export(self, model, filepath):
json_filename = os.path.join(filepath, 'brevitas_config.json')
config.IGNORE_EXPORT_KEYS = False
json_to_save = dict()
proxies_ckpts = os.path.join(filepath, 'brevitas_proxies')
os.makedirs(proxies_ckpts, exist_ok=True)
for name, module in model.named_modules():
if isinstance(module, QuantLayerMixin) or isinstance(module, self.wrap_layers):
layer_dict = dict()
json_to_save[name] = layer_dict
for subname, submodule in module.named_children():
if isinstance(submodule, QuantProxyFromInjector) and submodule.is_quant_enabled:
proxy_dict = dict()
json_to_save[name][subname] = proxy_dict
export_handler = submodule.export_handler
# torch.save(export_handler.state_dict(), ckpt_path)
proxy_dict.update(export_handler.state_dict())
proxy_dict['float_to_int_impl_type'] = export_handler.float_to_int_impl_type
proxy_dict['class_type'] = export_handler.__class__.__name__
if isinstance(module, self.wrap_layers):
layer_dict['rotation_config'] = dict()
layer_dict['rotation_config']['rot_mat_shape'] = module.had_mat.shape[
0] if module.had_mat is not None else None
layer_dict['rotation_config']['k'] = module.k

with open(json_filename, 'w') as f:
json.dump(json_to_save, f, cls=EncodeTensor)
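Putting the two new files together, a minimal end-to-end sketch of the intended flow. The weight-saving step and the export_dir path are assumptions (this diff only shows how brevitas_config.json is produced and consumed); "quant_brevitas" and the JSON filename come from the registration above.

# Export side: write brevitas_config.json into the model directory.
from brevitas.export.inference.vLLM.manager import vLLMExportManager

export_dir = 'exported-model'  # illustrative path
# quant_model: a Brevitas-quantized model whose proxies already have export handlers attached.
vLLMExportManager().export(quant_model, export_dir)
quant_model.save_pretrained(export_dir)  # assumption: weights saved HF-style, not shown in this PR

# Load side: importing the manager module registers the "quant_brevitas" method with vLLM.
from vllm import LLM

llm = LLM(model=export_dir, quantization='quant_brevitas')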