[New Features] Add lorapro #9729

Open · wants to merge 1 commit into develop
12 changes: 11 additions & 1 deletion paddlenlp/trainer/trainer.py
@@ -1815,6 +1815,9 @@
if hasattr(optimizer_cls, "_create_master_weight") and self.args.fp16_opt_level == "O2":
optimizer_kwargs["multi_precision"] = True

if isinstance(self.model, LoRAModel) and self.args.use_lorapro:
optimizer_kwargs["scaling_factor"] = self.model.lora_config.scaling

self.optimizer = optimizer_cls(
learning_rate=self.lr_scheduler if lr_scheduler is None else lr_scheduler,
apply_decay_param_fun=apply_decay_param_fun,
@@ -1950,7 +1953,14 @@
"beta2": args.adam_beta2,
"epsilon": args.adam_epsilon,
}
if args.optim == OptimizerNames.ADAMW:
if args.use_lorapro:
# from ..utils import AdamWMini
# optimizer_cls = AdamWMini
from ..utils import LoRAPro

optimizer_cls = LoRAPro
optimizer_kwargs.update(adam_kwargs)
elif args.optim == OptimizerNames.ADAMW:
from paddle.optimizer import AdamW

optimizer_cls = AdamW
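For context, here is a minimal usage sketch (not part of this diff) of how the trainer changes above are exercised: when the model is a LoRAModel and `use_lorapro=True`, the trainer forwards `lora_config.scaling` to the optimizer as `scaling_factor` and selects `LoRAPro` instead of `AdamW`. The checkpoint name and LoRA hyperparameters below are placeholders.

```python
# Usage sketch, assuming the existing PaddleNLP APIs (LoRAConfig, LoRAModel, Trainer, TrainingArguments).
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.trainer import Trainer, TrainingArguments
from paddlenlp.transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("facebook/llama-7b")  # placeholder checkpoint name
lora_config = LoRAConfig(r=8, lora_alpha=16, target_modules=[".*q_proj.*", ".*v_proj.*"])
model = LoRAModel(base, lora_config)

args = TrainingArguments(output_dir="./checkpoints", use_lorapro=True)
trainer = Trainer(model=model, args=args)  # pass train_dataset/data_collator as in any LoRA fine-tuning script

# trainer.train() now builds LoRAPro with scaling_factor=lora_config.scaling
# instead of paddle.optimizer.AdamW.
```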
4 changes: 4 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -418,6 +418,10 @@ class TrainingArguments:
default=False,
metadata={"help": "When performing evaluation and predictions, only returns the loss."},
)
use_lorapro: bool = field(
default=False,
metadata={"help": "Whether to use the LoRA-Pro optimizer for LoRA fine-tuning."},
)

per_device_train_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU core/CPU for training."})
per_device_eval_batch_size: int = field(
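Because `use_lorapro` is an ordinary `TrainingArguments` field, it is also exposed as a command-line flag in any script that builds its arguments through `PdArgumentParser`. A small parsing sketch (the output directory is a placeholder, and the bool-flag syntax assumes the standard dataclass parser behavior):

```python
# CLI parsing sketch for the new flag.
from paddlenlp.trainer import PdArgumentParser, TrainingArguments

parser = PdArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "./checkpoints", "--use_lorapro", "True"]
)
print(training_args.use_lorapro)  # True
```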
127 changes: 127 additions & 0 deletions paddlenlp/utils/optimizer.py
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict

import paddle
import paddle.autograd as imperative_base
from paddle import pir
from paddle.base import core, framework
from paddle.base.framework import Variable, in_dynamic_or_pir_mode, in_pir_mode
@@ -149,3 +152,127 @@
beta1_pow[:], beta2_pow[:] = beta1 * beta1_pow[:], beta2 * beta2_pow[:]
# TODO: figure out how this should be updated
return


class LoRAPro(AdamW):
def __init__(
self,
learning_rate: float = 0.001,
beta1: float = 0.9,
beta2: float = 0.999,
epsilon: float = 1e-8,
parameters=None,
weight_decay: float = 0.01,
lr_ratio=None,
apply_decay_param_fun=None,
grad_clip=None,
lazy_mode: bool = False,
multi_precision: bool = False,
amsgrad: bool = False,
name=None,
scaling_factor: float = 1.0,
) -> None:
super().__init__(
learning_rate,
beta1,
beta2,
epsilon,
parameters,
weight_decay,
lr_ratio,
apply_decay_param_fun,
grad_clip,
lazy_mode,
multi_precision,
amsgrad,
name,
)
self.scaling_factor = scaling_factor

@imperative_base.no_grad()
@framework.non_static_only
def step(self) -> None:
"""
Execute the optimizer and update parameters once.

Returns:
None

Examples:
.. code-block:: python

>>> import paddle

>>> a = paddle.arange(26, dtype="float32").reshape([2, 13])
>>> linear = paddle.nn.Linear(13, 5)
>>> # This can be any optimizer supported by dygraph.
>>> adam = paddle.optimizer.Adam(learning_rate = 0.01,
... parameters = linear.parameters())
>>> out = linear(a)
>>> out.backward()
>>> adam.step()
>>> adam.clear_grad()
"""
if paddle.base.dygraph.base.in_to_static_mode():
self._declarative_step()
return
scaling_factor = self.scaling_factor
if not isinstance(self._param_groups[0], dict):
params_grads = []
lora_num = len(self._param_groups) // 2
for i in range(lora_num):
# Transpose first
A = self._param_groups[2 * i].detach().T
B = self._param_groups[2 * i + 1].detach().T
grad_A_orin = self._param_groups[2 * i]._grad_ivar().T
grad_B_orin = self._param_groups[2 * i + 1]._grad_ivar().T

# Keep the intermediate computation consistent with the torch implementation
delta = 1e-8
AA_T = A @ A.T
B_TB = B.T @ B
AA_T_inv = paddle.linalg.pinv(AA_T + delta * paddle.eye(A.shape[0]))
B_TB_inv = paddle.linalg.pinv(B_TB + delta * paddle.eye(A.shape[0]))

X = paddle.zeros((B_TB_inv.shape[0], B_TB_inv.shape[0])).cast(B.dtype)

grad_A = (1 / scaling_factor**2) * B_TB_inv @ grad_A_orin + X @ A
grad_B = (1 / scaling_factor**2) * (
(paddle.eye(B.shape[0]) - B @ B_TB_inv @ B.T) @ grad_B_orin @ AA_T_inv
) - B @ X

# Transpose back at the end
self._param_groups[2 * i]._grad_ivar()[:] = grad_A.T
self._param_groups[2 * i + 1]._grad_ivar()[:] = grad_B.T

for param in self._param_groups:
if param.stop_gradient:
continue
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
params_grads.append((param, grad_var))

self._apply_optimize(
loss=None,
startup_program=None,
params_grads=params_grads,
param_group_idx=0,
)

else:
# optimize parameters in groups
for idx, param_group in enumerate(self._param_groups):
params_grads = defaultdict(lambda: [])
for param in param_group["params"]:
if param.stop_gradient:
continue
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
params_grads["params"].append((param, grad_var))
params_grads.update({k: v for k, v in param_group.items() if k != "params"})
self._apply_optimize(
loss=None,
startup_program=None,
params_grads=params_grads,
param_group_idx=idx,
)
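
For reference, writing the transposed adapters as A (r x n) and B (m x r), with scaling factor s, raw gradients g_A and g_B, the auxiliary matrix X fixed to zero, and the pseudo-inverses damped by delta = 1e-8, the adjustment computed in `step()` above is:

```latex
\tilde{g}_A = \frac{1}{s^{2}}\,(B^{\top}B + \delta I)^{+}\, g_A + X A
\qquad
\tilde{g}_B = \frac{1}{s^{2}}\,\bigl(I - B\,(B^{\top}B + \delta I)^{+}\,B^{\top}\bigr)\, g_B \,(A A^{\top} + \delta I)^{+} - B X,
\qquad X = 0
```

The adjusted gradients are written back into the parameters' grad buffers, after which the inherited AdamW update runs unchanged through `_apply_optimize`.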