3 changes: 1 addition & 2 deletions paddleformers/cli/utils/llm_utils.py
@@ -199,8 +199,7 @@ def get_lora_target_modules(model):
"model.language_model.*k_proj.*",
"model.language_model.*v_proj.*",
"model.language_model.*o_proj.*",
"model.language_model.*gate_up_proj.*",
"model.language_model.*down_proj.*",
"model.language_model.*mlp.experts",
# Vision
"model.visual.blocks.*attn.qkv.*",
"model.visual.blocks.*attn.proj.*",
124 changes: 124 additions & 0 deletions paddleformers/peft/lora/lora_layers.py
@@ -148,6 +148,130 @@ def extra_repr(self):
return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


class LoRAExperts(nn.Layer):
def __init__(
self,
base_layer,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
rslora: bool = False,
lora_plus_scale: float = 1.0,
**kwargs
):
super().__init__(**kwargs)
self.base_layer = base_layer
self.num_experts = base_layer.num_experts
self.act_fn = base_layer.act_fn
self.r = r
self.lora_alpha = lora_alpha
self.merged = False
self.disable_lora = False
self.lora_plus_scale = lora_plus_scale

self.gate_up_proj, self.gate_up_proj_lora_A, self.gate_up_proj_lora_B = self._init_lora("gate_up_proj")
self.down_proj, self.down_proj_lora_A, self.down_proj_lora_B = self._init_lora("down_proj")

if not rslora:
self.scaling = self.lora_alpha / self.r
else:
self.scaling = self.lora_alpha / math.sqrt(self.r)

def _init_lora(self, parameter_name: str):
if not hasattr(self.base_layer, parameter_name):
raise ValueError(f"Parameter '{parameter_name}' does not exist in the base layer.")

parameter = getattr(self.base_layer, parameter_name)
parameter.stop_gradient = True
num_experts, in_features, out_features = parameter.shape
lora_A = self.create_parameter(
shape=[num_experts, in_features, self.r],
dtype=paddle.get_default_dtype(),
is_bias=False,
default_initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu"),
)
lora_B = self.create_parameter(
shape=[num_experts, self.r, out_features],
dtype=paddle.get_default_dtype(),
is_bias=False,
attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.0),
learning_rate=self.lora_plus_scale,
),
)

return (parameter, lora_A, lora_B)

    def get_delta_weight(self, lora_A, lora_B):
        # Batched over experts: [E, in, r] @ [E, r, out] -> [E, in, out], scaled by lora_alpha / r
        # (or lora_alpha / sqrt(r) when rslora is enabled).
        delta_weight = lora_A @ lora_B * self.scaling

return delta_weight

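    # merge() folds the scaled LoRA delta into the frozen expert weights so inference
    # can skip the extra low-rank matmuls; unmerge() restores the original weights.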
def merge(self):
if not self.merged:
delta_weight = self.get_delta_weight(self.gate_up_proj_lora_A, self.gate_up_proj_lora_B)
new_parameter = self.gate_up_proj + delta_weight
self.gate_up_proj.set_value(new_parameter)
delta_weight = self.get_delta_weight(self.down_proj_lora_A, self.down_proj_lora_B)
new_parameter = self.down_proj + delta_weight
self.down_proj.set_value(new_parameter)
self.merged = True

def unmerge(self):
if self.merged:
delta_weight = self.get_delta_weight(self.gate_up_proj_lora_A, self.gate_up_proj_lora_B)
new_parameter = self.gate_up_proj - delta_weight
self.gate_up_proj.set_value(new_parameter)
delta_weight = self.get_delta_weight(self.down_proj_lora_A, self.down_proj_lora_B)
new_parameter = self.down_proj - delta_weight
self.down_proj.set_value(new_parameter)
self.merged = False

def forward(self, hidden_states, top_k_index, top_k_weights):
final_hidden_states = paddle.zeros_like(hidden_states)
        with paddle.no_grad():
            # [num_tokens, top_k, num_experts] -> [num_experts, top_k, num_tokens]
            expert_mask = paddle.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.transpose([2, 1, 0])
            # Indices of the experts that received at least one routed token.
            expert_hit = (expert_mask.sum(axis=[-1, -2]) > 0).nonzero()

for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == self.num_experts:
continue
top_k_pos, token_idx = paddle.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
if not (self.disable_lora or self.merged):
delta_state = (
current_state
@ self.gate_up_proj_lora_A[expert_idx]
@ self.gate_up_proj_lora_B[expert_idx]
* self.scaling
)
current_state = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]) + delta_state
else:
current_state = nn.functional.linear(current_state, self.gate_up_proj[expert_idx])
            gate, up = current_state.chunk(2, axis=-1)
current_hidden_states = self.act_fn(gate) * up
if not (self.disable_lora or self.merged):
delta_states = (
current_hidden_states
@ self.down_proj_lora_A[expert_idx]
@ self.down_proj_lora_B[expert_idx]
* self.scaling
)
current_hidden_states = (
nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + delta_states
)
else:
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            # Paddle's in-place index_add_ takes (index, axis, value); scatter-add the routed tokens back.
            final_hidden_states.index_add_(token_idx, 0, current_hidden_states.astype(final_hidden_states.dtype))

return final_hidden_states


class FleetLoRALinear(LoRALinear):
def __init__(self, in_features, out_features, skip_bias_add, **kwargs):
super().__init__(in_features, out_features, **kwargs)
18 changes: 15 additions & 3 deletions paddleformers/peft/lora/lora_model.py
@@ -108,7 +108,7 @@ def get_lora_layers():
XPURowSequenceParallelLoRALinear as RowSequenceParallelLoRALinear,
)

from .lora_layers import LoRAConv2D
from .lora_layers import LoRAConv2D, LoRAExperts
else:
raise ImportError # Force to use the fallback if not XPU
except ImportError:
@@ -121,6 +121,7 @@ def get_lora_layers():
FleetRowParallelLoRALinear,
FleetRowSequenceParallelLoRALinear,
LoRAConv2D,
LoRAExperts,
LoRALinear,
RowParallelLoRALinear,
RowSequenceParallelLoRALinear,
@@ -131,6 +132,7 @@ def get_lora_layers():
"ColumnSequenceParallelLoRALinear": ColumnSequenceParallelLoRALinear,
"LoRAConv2D": LoRAConv2D,
"LoRALinear": LoRALinear,
"LoRAExperts": LoRAExperts,
"RowParallelLoRALinear": RowParallelLoRALinear,
"RowSequenceParallelLoRALinear": RowSequenceParallelLoRALinear,
"FleetLoRALinear": FleetLoRALinear,
@@ -145,6 +147,7 @@ def get_lora_layers():
ColumnParallelLoRALinear = lora_layers["ColumnParallelLoRALinear"]
ColumnSequenceParallelLoRALinear = lora_layers["ColumnSequenceParallelLoRALinear"]
LoRAConv2D = lora_layers["LoRAConv2D"]
LoRAExperts = lora_layers["LoRAExperts"]
LoRALinear = lora_layers["LoRALinear"]
RowParallelLoRALinear = lora_layers["RowParallelLoRALinear"]
RowSequenceParallelLoRALinear = lora_layers["RowSequenceParallelLoRALinear"]
@@ -904,6 +907,15 @@ def _find_and_replace_module(self, model, module_name, lora_config):
lora_module = RowParallelQuantizationLoRALinear(module, lora_config)
            # LoRA row parallel will split the lora_A matrix
self.add_lora_split_mapping(module_name + ".lora_A", is_column=False)
elif attribute_chain[-1] == "experts":
Collaborator:
1. Is this matching rule general enough, or could it accidentally replace modules in other existing models and cause problems?
2. If a model's expert implementation is unusual, an interface should be kept for adapting a custom LoRA expert.
3. Can it match PaddleFleet's expert layers?

Contributor Author:
The matching rule has been revised, and an interface is kept for adapting custom LoRA experts.

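For context on the matching rule, a minimal usage sketch of how a LoRA config could route the fused expert weights to LoRAExperts; the import path and the exact target patterns are assumptions for illustration, not necessarily what this PR ships:

```python
# Hypothetical usage sketch (import path assumed to mirror paddleformers.peft exports).
from paddleformers.peft import LoRAConfig, LoRAModel

lora_config = LoRAConfig(
    target_modules=[
        "model.language_model.*q_proj.*",
        "model.language_model.*k_proj.*",
        "model.language_model.*v_proj.*",
        "model.language_model.*o_proj.*",
        "model.language_model.*mlp.experts",  # ends in "experts" -> replaced by LoRAExperts
    ],
    r=8,
    lora_alpha=16,
)
model = LoRAModel(model, lora_config)  # matching expert sublayers become LoRAExperts
```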
lora_module = LoRAExperts(
module,
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
rslora=lora_config.rslora,
lora_plus_scale=lora_config.lora_plus_scale,
)
if lora_module is None:
raise ValueError(
                f"LoRA strategy only supports paddle.nn.Linear, paddle.distributed.fleet.meta_parallel.ColumnParallelLinear, or paddleformers.transformers.sequence_utils layers. {module} ({module_name}, {type(module).__name__}) is not supported."
@@ -968,6 +980,7 @@ def mark_only_lora_as_trainable(self) -> None:
or isinstance(layer, FleetColumnSequenceParallelLoRALinear)
or isinstance(layer, RowSequenceParallelLoRALinear)
or isinstance(layer, FleetRowSequenceParallelLoRALinear)
or isinstance(layer, LoRAExperts)
or (QuantizationLoRALinear is not None and isinstance(layer, QuantizationLoRALinear))
or (
ColumnParallelQuantizationLoRALinear is not None
@@ -1004,8 +1017,7 @@ def get_lora_model(self, model: Union[PretrainedModel, nn.Layer], lora_config: L
return model
if isinstance(lora_config.target_modules, str):
lora_config.target_modules = [lora_config.target_modules]
Collaborator:
Please add unit tests for this.

Contributor Author:
Unit tests have been added.

Collaborator:
The get_merge_state_dict function also needs to be adapted.

Contributor Author:
It has been adapted.

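A minimal sketch of the kind of merge/unmerge round-trip test the thread asks for; the DummyExperts fixture, shapes, and import path are illustrative assumptions, not the project's actual test code:

```python
import paddle
import paddle.nn as nn

from paddleformers.peft.lora.lora_layers import LoRAExperts  # module path taken from this diff


class DummyExperts(nn.Layer):
    """Toy stand-in for a fused-MoE expert layer (hypothetical test fixture)."""

    def __init__(self, num_experts=4, hidden=8, intermediate=16):
        super().__init__()
        self.num_experts = num_experts
        self.act_fn = nn.functional.silu
        self.gate_up_proj = self.create_parameter(shape=[num_experts, hidden, 2 * intermediate])
        self.down_proj = self.create_parameter(shape=[num_experts, intermediate, hidden])


def test_lora_experts_merge_unmerge():
    layer = LoRAExperts(DummyExperts(), r=4, lora_alpha=8)
    original = layer.gate_up_proj.clone()

    layer.merge()
    assert layer.merged
    layer.unmerge()
    assert not layer.merged
    # lora_B starts at zero, so the delta is zero and the weights must be unchanged;
    # after a training step this would instead check a numerical round-trip tolerance.
    assert paddle.allclose(layer.gate_up_proj, original, atol=1e-6).item()
```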
for i in model.named_sublayers():
module_name = i[0]
for module_name, module in model.named_sublayers():
Collaborator:
LoRA merge support needs to be developed for this layer as well.

Contributor Author:
merge_model has been adapted.


for target_module in lora_config.target_modules:
if re.fullmatch(target_module, module_name):
self._find_and_replace_module(model, module_name, lora_config)
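Following up on the merge discussion above, a rough sketch of how the LoRA deltas could be folded into the base expert weights before export; the traversal and function name below are illustrative assumptions, not the PR's actual merge_model / get_merge_state_dict implementation:

```python
from paddleformers.peft.lora.lora_layers import LoRAExperts  # path taken from this diff


def merge_lora_weights(model):
    """Illustrative sketch: fold LoRA deltas into the frozen base weights before export."""
    for name, layer in model.named_sublayers():
        if isinstance(layer, LoRAExperts) and not layer.merged:
            # Adds lora_A @ lora_B * scaling into gate_up_proj / down_proj in place.
            layer.merge()
    return model
```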