Feat (vLLM): initial export support #1444
base: dev
Requirements file diff (adds `vllm` as a dependency):

```diff
@@ -10,3 +10,4 @@ pydantic
 torch>=2.4
 tqdm
 transformers[sentencepiece]<5.0
+vllm
```
Large diffs are not rendered by default.
New file `brevitas/export/inference/vLLM/handler.py` (+172 lines):

```python
from typing import List
from typing import Optional

import torch
from vllm.model_executor.layers.linear import LinearMethodBase

from brevitas.graph.hadamard import get_hadK
from brevitas.nn.equalized_layer import RotatedModule

from ..handler import FloatInferencetHandler
from ..handler import FloatWeightInferencetHandler
from ..handler import GroupwiseFloatInferenceHandler
from ..handler import GroupwiseFloatWeightInferenceHandler
from ..handler import IntInferencetHandler
from ..handler import IntWeightInferencetHandler

class_mapping = {
    'GroupwiseFloatInferenceHandler': GroupwiseFloatInferenceHandler,
    'GroupwiseFloatWeightInferenceHandler': GroupwiseFloatWeightInferenceHandler,
    'FloatInferencetHandler': FloatInferencetHandler,
    'FloatWeightInferencetHandler': FloatWeightInferencetHandler,
    'IntWeightInferencetHandler': IntWeightInferencetHandler,
    'IntInferencetHandler': IntInferencetHandler,}


class QuantLinear(LinearMethodBase):
    """vLLM linear method that rebuilds Brevitas inference handlers from the exported config."""

    def __init__(
            self,
            input_config=None,
            weight_config=None,
            bias_config=None,
            output_config=None,
            rotation_config=None):
        self.input_quant = self.configure_proxy(input_config)
        if isinstance(weight_config, list):
            self.weight_quant = dict()
            for i, config in enumerate(weight_config):
                self.weight_quant[i] = self.configure_proxy(config)
        else:
            self.weight_quant = self.configure_proxy(weight_config)
        self.bias_quant = self.configure_proxy(bias_config)
        self.output_quant = self.configure_proxy(output_config)
        self.rotation = self.configure_rotation(rotation_config)

    def configure_rotation(self, rotation_config):
        if rotation_config is None:
            return torch.nn.Identity()
        rot_mat_shape = rotation_config['rotation_size']['rot_mat_shape']
        k = rotation_config['rotation_size']['k']
        had_mat, _ = get_hadK(rot_mat_shape)
        return RotatedModule(self, had_mat, k)

    def configure_proxy(self, quant_config):
        # No config, no quantizer
        if quant_config is None:
            return torch.nn.Identity()

        # Extract elements that are not part of the state dict
        quant_class_name = quant_config['class_type']
        float_to_int_impl_type = quant_config['float_to_int_impl_type']
        del quant_config['class_type']
        del quant_config['float_to_int_impl_type']

        # Scale and zero-point are the only float elements in the state dict
        for k, v in quant_config.items():
            if not isinstance(v, torch.Tensor):
                if k == 'scale' or k == 'zero_point':
                    quant_config[k] = torch.tensor(v)
                else:
                    quant_config[k] = torch.tensor(v, dtype=torch.int)

        # Shapes must be set, otherwise the state dict loading will fail
        scale_shape = quant_config['scale'].shape
        zero_point_shape = quant_config['zero_point'].shape
        quant_class_type = class_mapping[quant_class_name]
        quant_class = quant_class_type(scale_shape, zero_point_shape)

        # Set the remaining attributes
        quant_class.float_to_int_impl_type = float_to_int_impl_type
        quant_class.load_state_dict(quant_config)
        return quant_class

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        out_per_partition = sum(output_partition_sizes)
        w = torch.empty(
            (out_per_partition, input_size_per_partition),
            device="cuda",
            dtype=params_dtype,
        )

        layer.weight = torch.nn.Parameter(w, requires_grad=False)

        # Handling the packed weights for loading
        base_loader = extra_weight_attrs.get("weight_loader", None)

        def packed_weight_loader(param, loaded_weight, loaded_shard_id=None, *args, **kwargs):

            if loaded_shard_id is not None:
                if isinstance(loaded_shard_id, int):
                    _loaded_shard_id = loaded_shard_id
                else:
                    if loaded_shard_id == "q":
                        _loaded_shard_id = 0
                    elif loaded_shard_id == "k":
                        _loaded_shard_id = 1
                    elif loaded_shard_id == "v":
                        _loaded_shard_id = 2
                    else:
                        raise ValueError(f"Invalid loaded_shard_id: {loaded_shard_id}")

                logical_widths = list(output_partition_sizes)
                start_idx = sum(logical_widths[:_loaded_shard_id])
                end_idx = start_idx + logical_widths[_loaded_shard_id]
                weight_quant = self.weight_quant[_loaded_shard_id]
            else:
                start_idx = 0
                end_idx = out_per_partition
                weight_quant = self.weight_quant
            if weight_quant is not None:
                loaded_weight = weight_quant(loaded_weight.cuda())[0].cpu()

            if base_loader is not None:
                return base_loader(param[start_idx:end_idx], loaded_weight, *args, **kwargs)
            param[start_idx:end_idx].data.copy_(loaded_weight)

        setattr(layer.weight, "weight_loader", packed_weight_loader)

        # If this layer has bias, allocate it
        if getattr(layer, "bias", None) is not None:
            b = torch.empty((out_per_partition,), device="cuda", dtype=params_dtype)
            layer.bias = torch.nn.Parameter(b, requires_grad=False)
            base_bias_loader = extra_weight_attrs.get("bias_loader", None)

            def packed_bias_loader(param, loaded_bias, *args, **kwargs):
                if isinstance(loaded_bias, (list, tuple)):
                    loaded_bias = torch.cat(list(loaded_bias), dim=0)
                if base_bias_loader is not None:
                    return base_bias_loader(param, loaded_bias, *args, **kwargs)
                param.data.copy_(loaded_bias)

            setattr(layer.bias, "bias_loader", packed_bias_loader)

        # Preserve attrs that vLLM weight loaders may attach
        for k, v in extra_weight_attrs.items():
            if k in ("weight_loader", "bias_loader"):
                continue
            setattr(layer.weight, k, v)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # x = self.rotation.rotation_forward(x)
        x = self.input_quant(x)
        bias = self.bias_quant(bias) if bias is not None else None
        y = x.matmul(layer.weight.t())
        if bias is not None:
            y = y + bias
        y = self.output_quant(y)
        return y
```
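For orientation, `configure_proxy` expects one dictionary per proxy: the handler state dict serialized by the export manager, plus two bookkeeping keys (`class_type` and `float_to_int_impl_type`) that are stripped before `load_state_dict`. A minimal, purely illustrative sketch of such an entry follows; the values and any keys beyond `scale`/`zero_point` are assumptions, since the exact state-dict contents depend on the concrete handler:

```python
# Hypothetical per-proxy entry as it would be read back from brevitas_config.json.
weight_quant_entry = {
    'class_type': 'IntWeightInferencetHandler',  # looked up in class_mapping above
    'float_to_int_impl_type': 'ROUND',           # illustrative value only
    'scale': [[0.05]],                           # converted to a float tensor by configure_proxy
    'zero_point': [[0.0]],                       # likewise converted to a float tensor
}
```

Given such an entry, `configure_proxy` instantiates `class_mapping['IntWeightInferencetHandler'](scale_shape, zero_point_shape)` and then loads the remaining keys as its state dict.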
New file (+169 lines, presumably under `brevitas/export/inference/vLLM/`; the path is not shown in this view) with the vLLM quantization config and the export manager:

```python
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from dataclasses import dataclass
from functools import partial
import json
import os
from typing import Any
from typing import List
from typing import Optional

import torch
from torch.nn import Module
import torch.nn as nn
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
# Needed by get_quant_method below; not imported in the original diff
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped

import brevitas.config as config
from brevitas.export.inference.vLLM.handler import QuantLinear
from brevitas.nn.equalized_layer import EqualizedModule
from brevitas.nn.equalized_layer import RotatedModule
from brevitas.nn.mixin import QuantLayerMixin
from brevitas.proxy.quant_proxy import QuantProxyFromInjector

from ..manager import _override_act_caching_mode
from ..manager import _override_bias_caching_mode
from ..manager import _override_create_quant_tensor
from ..manager import _override_weight_caching_mode


@register_quantization_config("quant_brevitas")
@dataclass
class QuantConfigBrevitas(QuantizationConfig):
    """vLLM QuantizationConfig backed by the Brevitas metadata in brevitas_config.json."""

    def __init__(self, ignored_layers: list[str] | None = None, config: str | None = None):
        super().__init__()
        self.ignored_layers = ignored_layers
        self.config = config

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "QuantConfigBrevitas":
        return cls(config=config)

    @classmethod
    def get_min_capability(cls) -> int:
        # Minimum GPU compute capability needed for the kernel.
        return 0

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "quant_brevitas"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16, torch.float32]

    @staticmethod
    def get_config_filenames() -> list[str]:
        return ["brevitas_config.json"]

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, RowParallelLinear) or isinstance(
                layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
            if self.ignored_layers and is_layer_skipped(
                    prefix=prefix,
                    ignored_layers=self.ignored_layers,
                    fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            else:
                if prefix in self.config:
                    base_config = self.config[prefix]
                    input_config = base_config.get('input_quant', None)
                    bias_config = base_config.get('bias_quant', None)
                    output_config = base_config.get('output_quant', None)
                    weight_config = base_config.get('weight_quant', None)
                else:
                    # Fused layer (e.g. qkv_proj): collect the per-shard weight configs
                    base = prefix.split('.')[:-1]
                    base = '.'.join(base)
                    suffix = prefix.split('.')[-1]
                    layers_to_merge = self.packed_modules_mapping[suffix]
                    layers_to_merge = [base + '.' + x for x in layers_to_merge]

                    base_config = self.config[layers_to_merge[0]]
                    input_config = base_config.get('input_quant', None)
                    bias_config = base_config.get('bias_quant', None)
                    output_config = base_config.get('output_quant', None)
                    weight_config = [
                        self.config[layer].get('weight_quant', None) for layer in layers_to_merge]
                    # base_config = combine_configs(self.config, *layers_to_merge)

                return QuantLinear(
                    input_config=input_config,
                    bias_config=bias_config,
                    output_config=output_config,
                    weight_config=weight_config)

        elif isinstance(layer, LinearBase):
            return UnquantizedLinearMethod()

        return None


def combine_configs(config, *names):
    base_config = config[names[0]]
    scale = None  # base_config['scale']
    for n in names:
        if scale is None:
            scale = torch.tensor(config[n]['weight_quant']['scale'])
        else:
            v = torch.tensor(config[n]['weight_quant']['scale'])
            scale = torch.cat((scale, v), 0)
    base_config['weight_quant']['scale'] = scale
    return base_config


from json import JSONEncoder

from torch.utils.data import Dataset


class EncodeTensor(JSONEncoder, Dataset):

    def default(self, obj):
        if isinstance(obj, torch.Tensor):
            if obj.dtype == torch.bfloat16:
                obj = obj.to(torch.float32)
            return obj.cpu().detach().numpy().tolist()
        return super(EncodeTensor, self).default(obj)


class vLLMExportManager():
    """Serializes per-proxy quantization metadata of a Brevitas model to brevitas_config.json."""

    wrap_layers = (EqualizedModule, RotatedModule)

    def export(self, model, filepath):
        json_filename = os.path.join(filepath, 'brevitas_config.json')
        config.IGNORE_EXPORT_KEYS = False
        json_to_save = dict()
        proxies_ckpts = os.path.join(filepath, 'brevitas_proxies')
        os.makedirs(proxies_ckpts, exist_ok=True)
        for name, module in model.named_modules():
            if isinstance(module, QuantLayerMixin) or isinstance(module, self.wrap_layers):
                layer_dict = dict()
                json_to_save[name] = layer_dict
                for subname, submodule in module.named_children():
                    if isinstance(submodule, QuantProxyFromInjector) and submodule.is_quant_enabled:
                        proxy_dict = dict()
                        json_to_save[name][subname] = proxy_dict
                        export_handler = submodule.export_handler
                        # torch.save(export_handler.state_dict(), ckpt_path)
                        proxy_dict.update(export_handler.state_dict())
                        proxy_dict['float_to_int_impl_type'] = export_handler.float_to_int_impl_type
                        proxy_dict['class_type'] = export_handler.__class__.__name__
                if isinstance(module, self.wrap_layers):
                    layer_dict['rotation_config'] = dict()
                    layer_dict['rotation_config']['rot_mat_shape'] = module.had_mat.shape[
                        0] if module.had_mat is not None else None
                    layer_dict['rotation_config']['k'] = module.k

        with open(json_filename, 'w') as f:
            json.dump(json_to_save, f, cls=EncodeTensor)
```
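Read together with the handler file, the intended flow appears to be: export the quantized model's proxy metadata to `brevitas_config.json`, then let vLLM pick it up through the registered `quant_brevitas` method. A rough sketch of that flow; the import path of `vLLMExportManager`, the directory, and `quant_model` (an already-quantized Brevitas model) are placeholders and assumptions, not confirmed by this PR:

```python
# Sketch only: module path, directories, and model layout are assumptions.
from vllm import LLM

from brevitas.export.inference.vLLM.manager import vLLMExportManager  # path assumed

# 1) After PTQ with Brevitas, dump per-proxy metadata (scales, zero-points,
#    handler class names) next to the checkpoint.
vLLMExportManager().export(quant_model, "/path/to/exported_model")

# 2) Serve the exported checkpoint; QuantConfigBrevitas reads
#    brevitas_config.json and wires QuantLinear into matching layers.
llm = LLM(model="/path/to/exported_model", quantization="quant_brevitas")
```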
I feel like vLLM should be an optional dependency.
Maybe we can do it in a similar way to what we did for lighteval/lm_eval
I'm leaving it for now so that the tests run and I can see what else I'm breaking in the process, but I'll remove it before this PR is merged.
I'm fine with doing it similarly to what we did for lighteval/lm_eval.
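For the optional-dependency point, one common pattern (a sketch of the general idea, not necessarily how the lighteval/lm_eval extras are actually gated in Brevitas) is to guard the vLLM import and fail lazily only when the vLLM export path is used:

```python
# Illustrative optional-import guard; names and the error message are placeholders.
try:
    import vllm  # noqa: F401
    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False


def require_vllm():
    # Call this at the entry points of the vLLM export path.
    if not VLLM_AVAILABLE:
        raise ImportError(
            "vLLM export support requires the optional 'vllm' package "
            "(e.g. `pip install vllm`).")
```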