
Add a lora dense layer #1263

Merged: 6 commits into keras-team:master on Oct 12, 2023

Conversation

@mattdangerw (Member) commented Oct 4, 2023:

#1264 shows how this will eventually fit together.
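
For orientation, here is a minimal usage sketch of the layer this PR adds, based on the constructor arguments visible in the review snippets below (`inner_dense`, `rank`, `alpha`); the import path and call pattern are assumptions, not the merged API:

    import keras
    from keras_nlp.layers import LoraDense  # assumed export path

    # Wrap an existing (built) dense layer with a trainable low-rank update.
    inner_dense = keras.layers.Dense(64)
    inner_dense.build((None, 128))
    lora_dense = LoraDense(inner_dense, rank=8, alpha=32.0)
    outputs = lora_dense(keras.random.normal((2, 128)))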

@mattdangerw (Member Author):

Note that this is Keras 3/Keras Core only because that library will allow you to set trainable on individual parameters (tf.keras does not). It doesn't seem worth the effort to build backwards compat here, this can be a strictly forward looking Keras 3 API.
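
For readers outside the thread, a quick sketch of the Keras 3 capability referenced here, where `trainable` can be set on individual variables (illustrative code, not part of this PR):

    import keras

    dense = keras.layers.Dense(4)
    dense.build((None, 8))

    # Keras 3 variables expose a settable `trainable` flag, so just the kernel
    # can be frozen while other weights stay trainable.
    dense.kernel.trainable = False
    print([v.name for v in dense.trainable_weights])  # expect only the bias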

    not backend_config.backend() == "tensorflow",
    reason="tests only run on tf backend",
)
multi_backend_only = pytest.mark.skipif(

Contributor:

maybe keras_3_only?

Member Author:

Hmm, on second thought I'll leave it as multi-backend for now, because it can run with Keras 3 or Keras Core for the time being, and we will have zero coverage for the actual Keras 3 code path till we get that pip release.
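
For reference, the marker pair under discussion could be defined as below; only the first marker mirrors the snippet above, while the `multi_backend()` helper and the second reason string are assumptions for illustration:

    import pytest
    from keras_nlp.backend import config as backend_config  # assumed import path

    tf_only = pytest.mark.skipif(
        not backend_config.backend() == "tensorflow",
        reason="tests only run on tf backend",
    )
    multi_backend_only = pytest.mark.skipif(
        not backend_config.multi_backend(),  # assumed helper
        reason="tests only run with multi-backend Keras (Keras 3 / Keras Core)",
    )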

keras_nlp/layers/modeling/lora_dense.py — additional review threads, resolved.
    inner_dense,
    rank=8,
    alpha=32.0,
    lora_a_initializer="variance_scaling",

Contributor:

Does it make sense to include "lora" as a prefix for this variable name, given that it's already a LoraDense layer?

Collaborator:

I think it's okay, because the two dense layers in LoRA are called "lora_A" and "lora_B" (in the official code). Calling this a_initializer would look weird :P.

Member Author:

Yeah, I wanted to avoid layer.a and layer.b. We could come up with our own names, layer.inner_kernel_update and layer.outer_kernel_update, but I suspect lora_a and lora_b will be more recognizable to people.

Member:

Sounds fine to me.
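
For context on the naming discussion, the computation `lora_a`/`lora_b` refer to is, schematically, `output = inner_dense(x) + (alpha / rank) * lora_b(lora_a(x))`. A standalone sketch, not this PR's implementation:

    import keras

    rank, alpha = 8, 32.0
    inner_dense = keras.layers.Dense(64)  # the wrapped, typically frozen, layer
    lora_a = keras.layers.Dense(
        rank, use_bias=False, kernel_initializer="variance_scaling"
    )
    lora_b = keras.layers.Dense(64, use_bias=False, kernel_initializer="zeros")

    def lora_dense(x):
        # Low-rank update added onto the original dense output.
        return inner_dense(x) + (alpha / rank) * lora_b(lora_a(x))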

keras_nlp/layers/modeling/lora_dense_test.py — review thread, resolved.
    self,
    inner_dense,
    rank=8,
    alpha=32.0,

Collaborator:

Generally, alpha = rank. I know the guide has alpha = 32.0 and rank = 4, but that was an oversight on my part. The authors state that they went with alpha = rank and tuned the learning rate.

Source: Section 4.1 in https://openreview.net/forum?id=nZeVKeeFYf9

Member Author:

Thanks! Seems like we could:

  • Default alpha=8.0, so the type is obvious.
  • Default alpha=None and assign float(rank) if unset.
  • Just expose scale directly, and default scale=1.0.
  • Don't expose anything.

Maybe just alpha=8.0? Simple, and what people would expect?
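
A small sketch of the second option above, with illustrative names; in every option the effective multiplier on the LoRA update works out to `alpha / rank`:

    def resolve_scale(rank, alpha=None):
        # Option two above: alpha defaults to the rank, giving a scale of 1.0.
        if alpha is None:
            alpha = float(rank)
        return alpha / rank

    resolve_scale(8)        # -> 1.0, the alpha == rank default
    resolve_scale(4, 32.0)  # -> 8.0, the guide's alpha=32, rank=4 setting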

Contributor:

I think this approach looks good.

@abheesht17 (Collaborator) left a comment:

Was passing through, left some nits.

keras_nlp/layers/modeling/lora_dense.py — additional review threads, resolved.

@mattdangerw (Member Author):

/gcbrun

@mattdangerw (Member Author):

I think this is ready for another round!

@mattdangerw (Member Author):

/gcbrun

@mattdangerw (Member Author):

/gcbrun

@ianstenbit (Contributor) left a comment:

Looks good!

    lora_a_initializer: The initializer to use for the inner projection
        from layer inputs to the inner `rank` intermediate outputs.
    freeze_kernel: If true, the kernel of the inner dense layer will have
        `trainable` set to False.

Contributor:

IIRC we backtick False in these contexts?
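
As an aside, the `freeze_kernel` behavior documented above can be expressed with the per-variable `trainable` flag mentioned earlier in the thread; a minimal sketch with illustrative names, not the merged code:

    def maybe_freeze_kernel(inner_dense, freeze_kernel=True):
        # Freeze only the wrapped layer's kernel; the LoRA weights stay trainable.
        if freeze_kernel and hasattr(inner_dense, "kernel"):
            inner_dense.kernel.trainable = False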

@mattdangerw merged commit 07e1cc2 into keras-team:master on Oct 12, 2023. 3 of 7 checks passed.