109 changes: 95 additions & 14 deletions networks/sdxl_merge_lora.py
@@ -4,6 +4,7 @@
import os
import time
import concurrent.futures
import re
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
@@ -110,15 +111,56 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws,
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")

if method == "LoRA":
def convert_diffusers_labels_to_unet(name: str) -> str:
if "_attentions_" not in name:
return name # attention-schema names only

# Normalize the stage tokens (diffusers up/down/mid -> UNet output/input/middle)
s = (name.replace("unet_up", "unet_output")
.replace("unet_down", "unet_input")
.replace("unet_mid", "unet_middle"))

# Middle: ...middle_block_attentions_X_* -> ...middle_block_{X+1}_*
if "unet_middle" in s and "middle_block_attentions_" in s:
return re.sub(
r"middle_block_attentions_(\d+)_",
lambda m: f"middle_block_{int(m.group(1)) + 1}_",
s,
)

left, right = s.split("_attentions_", 1)
L = left.split("_") # lora_unet_[input|output]_blocks_{X}
stage = L[2]
X = int(L[-1]) # down/up block index in the diffusers ("wrong") naming
Y_str, *rest = right.split("_")
Y = int(Y_str) # attentions index in the diffusers ("wrong") naming
# Remainder includes transformer/proj tail (kept as-is)
# Map to the UNet ("right") indices:
if stage == "input": # from "down"
if X == 1: i = 4 + Y # → input_blocks_{4|5}_1
elif X == 2: i = 7 + Y # → input_blocks_{7|8}_1
else: return s # no attentions elsewhere
j = 1
elif stage == "output": # from "up"
i = 3 * X + Y # up block X=0→0..2, X=1→3..5, X=2→6..8
j = 1
else:
return s

L[-1] = str(i)
return "_".join(L + [str(j)] + rest)

for key in tqdm(lora_sd.keys()):
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
key_base = key[: key.index("lora_down")]
up_key = key_base + "lora_up.weight"
dora_key = key_base + "dora_scale"
alpha_key = key_base + "alpha"

# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
module_name = ".".join(convert_diffusers_labels_to_unet(key).split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
logger.info(f"no module found for LoRA weight: {key}")
logger.info(f"no module found for LoRA weight: {module_name}, from({key})")
continue
module = name_to_module[module_name]
# logger.info(f"apply {key} to {module}")
@@ -127,7 +169,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws,
up_weight = lora_sd[up_key]

dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
alpha = lora_sd.get(alpha_key, 1.0)
scale = alpha / dim

if lbw:
@@ -138,23 +180,62 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws,

# W <- W + U * D
weight = module.weight
lora_diff = None
# logger.info(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
lora_diff = (up_weight @ down_weight)
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
lora_diff = (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
lora_diff = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale

dora_scale = lora_sd.get(dora_key, None)

# Algorithm/math taken from reForge
if dora_scale is None:
# -------- Plain LoRA (mirrors the original merge math) --------
# W <- W + ratio * (lora_diff * scale)
weight = weight + (ratio * (lora_diff * scale)).to(dtype=weight.dtype, device=weight.device)

else:
# -------- DoRA (literal reForge semantics) --------
# cast dora_scale as reForge does (to the intermediate merge_dtype; ops below use weight.dtype)
ds = dora_scale.to(device=weight.device, dtype=merge_dtype)

# lora_diff gets 'alpha' (scale == alpha/rank) BEFORE magnitude; strength applied AFTER magnitude
lora_diff_scaled = (lora_diff * scale).to(dtype=weight.dtype, device=weight.device)

# weight_calc = weight + function(lora_diff_scaled); function is identity here
weight_calc = weight + lora_diff_scaled

wd_on_output_axis = (ds.shape[0] == weight_calc.shape[0])
if wd_on_output_axis:
# per-OUT norm taken from ORIGINAL weight (matches reForge)
weight_norm = (
weight.reshape(weight.shape[0], -1)
.norm(dim=1, keepdim=True)
.reshape(weight.shape[0], *[1] * (weight.dim() - 1))
)
else:
# per-IN norm from weight_calc^T (matches reForge)
wc = weight_calc.transpose(0, 1)
weight_norm = (
wc.reshape(wc.shape[0], -1)
.norm(dim=1, keepdim=True)
.reshape(wc.shape[0], *[1] * (wc.dim() - 1))
.transpose(0, 1)
)

weight_norm = weight_norm + torch.finfo(weight.dtype).eps

# Apply magnitude: weight_calc *= (dora_scale / weight_norm)
# (Do NOT reshape ds; rely on its stored shape for broadcasting)
weight_calc = weight_calc * (ds.to(dtype=weight.dtype) / weight_norm)
weight = torch.lerp(weight, weight_calc, float(ratio))

module.weight = torch.nn.Parameter(weight)
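
# Recap of the merge math above (illustrative notation):
#   plain LoRA:  W' = W + ratio * lora_diff * (alpha / rank)
#   DoRA:        W_calc = W + lora_diff * (alpha / rank)
#                norm   = per-output-channel ||W|| (+ eps)       if dora_scale indexes dim 0,
#                         per-input-channel  ||W_calc|| (+ eps)  otherwise
#                W'     = lerp(W, W_calc * dora_scale / norm, ratio)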
