Commit
Fix make_mlp_unpacked with lora (#1209)
- The check was done on the `mlp` module instead of the `gate_up_linear`
  module, so it always returned False (see the sketch below).
- `lora_B` needs to be accessed from `gate_up_linear`, not from `mlp`.
- The code referred to `q_proj` and `k_proj` instead of `gate_proj` and `up_proj`.
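For context, here is a minimal, self-contained sketch (not the builder's code) of the structure the fixed check expects: the packed gate-up projection, rather than the MLP module itself, is the PEFT-style LoRA wrapper that carries `base_layer`, `lora_A`, `lora_B`, and `scaling`. The class and attribute names (`PackedLoraLinear`, `Mlp`, `gate_up_proj`) are illustrative stand-ins, not the builder's actual classes.

import torch

hidden_size, intermediate_size, rank = 16, 32, 4

class PackedLoraLinear(torch.nn.Module):
    # Stand-in for a PEFT-style LoRA wrapper around the packed gate-up projection.
    def __init__(self):
        super().__init__()
        self.base_layer = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.lora_A = torch.nn.ModuleDict({"default": torch.nn.Linear(hidden_size, rank, bias=False)})
        self.lora_B = torch.nn.ModuleDict({"default": torch.nn.Linear(rank, 2 * intermediate_size, bias=False)})
        self.scaling = {"default": 2.0}

class Mlp(torch.nn.Module):
    # Stand-in for the model's MLP block that holds the packed projection.
    def __init__(self):
        super().__init__()
        self.gate_up_proj = PackedLoraLinear()

mlp = Mlp()
gate_up_linear = mlp.gate_up_proj
assert not hasattr(mlp, "base_layer")         # old check: always False, so the LoRA path was never taken
assert hasattr(gate_up_linear, "base_layer")  # fixed check: detects the LoRA-wrapped packed layer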
jambayk authored Feb 3, 2025
1 parent 728cc93 commit 61e8efb
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/python/py/models/builder.py
@@ -1677,7 +1677,7 @@ def make_mlp_unpacked(self, layer_id, mlp, root_input):
# Return early if there's nothing to unpack
return

- if hasattr(mlp, "base_layer"):
+ if hasattr(gate_up_linear, "base_layer"):
# For LoRA packed `MatMul`
return self.make_mlp_unpacked_lora(layer_id, mlp, root_input)
else:
@@ -1701,7 +1701,7 @@ def make_mlp_unpacked_lora(self, layer_id, mlp, root_input):
up_proj.bias = None if gate_up_linear.bias is None else torch.nn.Parameter(gate_up_linear.bias[self.intermediate_size :], requires_grad=False)

# Create GateProj/UpProj lora_B layers
- lora_B = mlp.lora_B.default
+ lora_B = gate_up_linear.lora_B.default

gate_proj_lora_B = torch.nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size)
gate_proj_lora_B.weight = torch.nn.Parameter(lora_B.weight[ : self.intermediate_size, :], requires_grad=False)
@@ -1712,12 +1712,12 @@ def make_mlp_unpacked_lora(self, layer_id, mlp, root_input):
up_proj_lora_B.bias = None if lora_B.bias is None else torch.nn.Parameter(lora_B.bias[self.intermediate_size :], requires_grad=False)

# Create GateProj/UpProj LoRA layers
- mlp.gate_proj = LoraLayer(q_proj)
+ mlp.gate_proj = LoraLayer(gate_proj)
mlp.gate_proj.lora_A = gate_up_linear.lora_A
mlp.gate_proj.lora_B.default = gate_proj_lora_B
mlp.gate_proj.scaling = gate_up_linear.scaling

- mlp.up_proj = LoraLayer(k_proj)
+ mlp.up_proj = LoraLayer(up_proj)
mlp.up_proj.lora_A = gate_up_linear.lora_A
mlp.up_proj.lora_B.default = up_proj_lora_B
mlp.up_proj.scaling = gate_up_linear.scaling
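As a quick sanity check on the row-wise split used above, the following sketch (sizes are arbitrary and not taken from the builder) shows that slicing the packed `lora_B` weight into its first and second `intermediate_size` rows reproduces the gate and up halves of the packed LoRA update:

import torch

intermediate_size, rank = 32, 4

# Packed lora_B maps rank -> 2 * intermediate_size (gate rows first, then up rows).
lora_B = torch.nn.Linear(in_features=rank, out_features=2 * intermediate_size, bias=False)

gate_proj_lora_B = torch.nn.Linear(in_features=rank, out_features=intermediate_size, bias=False)
gate_proj_lora_B.weight = torch.nn.Parameter(lora_B.weight[: intermediate_size, :], requires_grad=False)

up_proj_lora_B = torch.nn.Linear(in_features=rank, out_features=intermediate_size, bias=False)
up_proj_lora_B.weight = torch.nn.Parameter(lora_B.weight[intermediate_size :, :], requires_grad=False)

x = torch.randn(1, rank)
packed = lora_B(x)
assert torch.allclose(packed[:, : intermediate_size], gate_proj_lora_B(x))  # gate half matches
assert torch.allclose(packed[:, intermediate_size :], up_proj_lora_B(x))    # up half matches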