Add a lora dense layer (#1263)
* Add a lora dense layer

Co-authored-by: Abheesht <[email protected]>

* address comments

* Fix merge conflict

* minor fix

* another einsum restriction

* Last doc nit from Ian

---------

Co-authored-by: Abheesht <[email protected]>
mattdangerw and abheesht17 authored Oct 12, 2023
1 parent 8cab8ef commit 07e1cc2
Showing 4 changed files with 378 additions and 2 deletions.
10 changes: 8 additions & 2 deletions keras_nlp/conftest.py
@@ -69,17 +69,23 @@ def pytest_collection_modifyitems(config, items):
not run_extra_large_tests,
reason="need --run_extra_large option to run",
)
skip_tf_only = pytest.mark.skipif(
tf_only = pytest.mark.skipif(
not backend_config.backend() == "tensorflow",
reason="tests only run on tf backend",
)
multi_backend_only = pytest.mark.skipif(
not backend_config.multi_backend(),
reason="tests only run on with multi-backend keras",
)
for item in items:
if "large" in item.keywords:
item.add_marker(skip_large)
if "extra_large" in item.keywords:
item.add_marker(skip_extra_large)
if "tf_only" in item.keywords:
item.add_marker(skip_tf_only)
item.add_marker(tf_only)
if "multi_backend_only" in item.keywords:
item.add_marker(multi_backend_only)
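# Illustrative usage of these markers in a test module (hypothetical test
# names, shown here only for context):
#
#   @pytest.mark.tf_only
#   def test_something_tf_specific():
#       ...
#
#   @pytest.mark.multi_backend_only
#   def test_something_multi_backend():
#       ...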


# Disable traceback filtering for quicker debugging of test failures.
1 change: 1 addition & 0 deletions keras_nlp/layers/__init__.py
@@ -16,6 +16,7 @@
CachedMultiHeadAttention,
)
from keras_nlp.layers.modeling.f_net_encoder import FNetEncoder
from keras_nlp.layers.modeling.lora_dense import LoraDense
from keras_nlp.layers.modeling.masked_lm_head import MaskedLMHead
from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
from keras_nlp.layers.modeling.rotary_embedding import RotaryEmbedding
234 changes: 234 additions & 0 deletions keras_nlp/layers/modeling/lora_dense.py
@@ -0,0 +1,234 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.backend import ops


def validate_einsum_equation(equation):
# For simplicity, we greatly restrict possible einsum equations. The final
# axis of the input must be the first axis of our kernel, and must not
# appear in our output.
left, right, output = re.split(",|->", equation)
valid = (
left[-1] == right[0]
and left[-1] not in output
and set(left[:-1]).isdisjoint(set(right[1:]))
)
if not valid:
raise ValueError(
"When passing a `EinsumDense` layer to a `LoraDense` layer, the "
"einsum `equation` must always have the form `*x,x*->*`, where "
"each `*` can be any sequence. Conceptually, the `equation` should "
"always represent a dense matmul on the last axis of the input. "
f"Received invalid equation `'{equation}'`."
)
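# Illustrative equations: "ab,bc->ac" and "abc,cde->abde" pass the check above,
# while "abc,dc->abd" (last input axis does not lead the kernel) and
# "abc,cb->ab" (kernel and input share a non-contracted axis) would raise the
# `ValueError`.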


@keras_nlp_export("keras_nlp.layers.LoraDense")
class LoraDense(keras.layers.Layer):
"""A LoRA adapter layer for a dense input layer.
This layer implements a low-rank decomposition of a dense transformation, as
described in [LoRA: Low-Rank Adaptation Of Large Language Models](https://arxiv.org/pdf/2106.09685.pdf)
This layer can be used to replace a dense layer with a layer whose
parameters are mostly frozen.
By default, this layer takes in an `inner_dense` layer, freezes its
parameters, and builds a low-rank decomposed update to sum with the original
`inner_dense` output. These update parameters can be merged back into the
`inner_dense` kernel by calling `merge_weights()`.
Args:
inner_dense: A `keras.layers.Dense` or `keras.layers.EinsumDense`.
The inner dense layer to freeze and wrap with the `LoraDense`
layer. Note that for `EinsumDense` layers, the einsum equation must
represent a dense transformation on the last axis of the input,
though adding new axes to the output (e.g. a multi-head axis) is
allowed.
rank: int. The inner rank of the decomposed dense transformation. The
lower this number, the fewer trainable parameters the layer will
have.
alpha: float. A constant value used for scaling the lora update. The
lora update to the original dense transformation will be scaled by
`alpha / rank`.
lora_a_initializer: The initializer to use for the inner projection
from layer inputs to the inner `rank` intermediate outputs.
freeze_kernel: If true, the kernel of the inner dense layer will have
`trainable` set to `False`.
freeze_bias: If true, the bias of the inner dense layer will have
`trainable` set to `False`.
**kwargs: other keyword arguments.

Examples:

Wrap a `Dense` layer.
```python
batch_size, feature_size = 4, 16
rank = 4
inputs = np.random.uniform(size=(batch_size, feature_size))
inner_dense = keras.layers.Dense(feature_size)
lora_dense = keras_nlp.layers.LoraDense(inner_dense, rank=4)
# Initially, the lora_dense output is equal to the inner_dense output.
assert np.allclose(inner_dense(inputs), lora_dense(inputs))
# Add some random updates to the lora parameters.
lora_dense.lora_a.assign(np.random.uniform(size=(feature_size, rank)))
lora_dense.lora_b.assign(np.random.uniform(size=(rank, feature_size)))
assert not np.allclose(inner_dense(inputs), lora_dense(inputs))
# Merge the lora updates into the inner dense kernel; outputs match again.
lora_dense.merge_weights()
assert np.allclose(inner_dense(inputs), lora_dense(inputs))
```

Wrap an `EinsumDense` layer with a multi-head projection.
```python
batch_size, sequence_length, feature_size = 4, 10, 16
num_heads = 2
rank = 4
inputs = np.random.uniform(size=(batch_size, sequence_length, feature_size))
inner_dense = keras.layers.EinsumDense(
"abc,cde->abde",
output_shape=(sequence_length, num_heads, feature_size // num_heads),
)
lora_dense = keras_nlp.layers.LoraDense(inner_dense, rank=4)
# Output shape (4, 10, 2, 8)
lora_dense(inputs)
```
"""

def __init__(
self,
inner_dense,
rank=8,
alpha=8.0,
lora_a_initializer="variance_scaling",
freeze_kernel=True,
freeze_bias=True,
**kwargs,
):
# Default to the same dtype as our inner layer.
if "dtype" not in kwargs:
kwargs["dtype"] = inner_dense.dtype_policy
super().__init__(**kwargs)

if not config.multi_backend():
raise ValueError(
"Lora only works with multi-backend Keras 3. Please set the "
"`KERAS_BACKEND` environment variable to use this API."
)

if isinstance(inner_dense, keras.layers.Dense):
self.inner_dense = inner_dense
elif isinstance(inner_dense, keras.layers.EinsumDense):
self.inner_dense = inner_dense
validate_einsum_equation(inner_dense.equation)
else:
raise ValueError(
"Only `Dense` and `EinsumDense` inner layers are supported. "
f"Received: inner_dense={inner_dense}"
)

self.rank = rank
self.alpha = alpha
self.scale = alpha / rank
self.freeze_kernel = freeze_kernel
self.freeze_bias = freeze_bias
self.lora_a_initializer = keras.initializers.get(lora_a_initializer)

if inner_dense.built:
self.build_from_config(inner_dense.get_build_config())

def build(self, inputs_shape):
if not self.inner_dense.built:
self.inner_dense.build(inputs_shape)

if self.freeze_kernel and self.inner_dense.kernel is not None:
self.inner_dense.kernel.trainable = False

if self.freeze_bias and self.inner_dense.bias is not None:
self.inner_dense.bias.trainable = False

input_dim = inputs_shape[-1]
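# lora_a maps the input down to `rank`; lora_b maps back up to the inner
# kernel's output axes, so their product is a kernel-sized update.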
self.lora_a = self.add_weight(
name="lora_a",
shape=(input_dim, self.rank),
initializer=self.lora_a_initializer,
)
kernel_shape = self.inner_dense.kernel.shape
self.lora_b = self.add_weight(
name="lora_b",
shape=(self.rank,) + kernel_shape[1:],
initializer="zeros",
)
self.built = True

def merge_weights(self):
"""Merge lora updates into the wrapped dense layer.
This function should only be called outside of any compiled context
(e.g. not during `fit()`, `predict()` or `evaluate()`). It will merge
the updates from the lora layers into the original dense layer, and
re-initialize the lora variables.
"""
if not self.built:
return

# Compute matmul of lora_a and lora_b to get a kernel sized update.
update = ops.tensordot(self.lora_a, self.lora_b, axes=([-1], [0]))
update = update * ops.cast(self.scale, update.dtype)
# Add lora updates back into the inner dense kernel.
self.inner_dense.kernel.assign_add(update)
# Re-initialize lora weights.
self.lora_a.assign(
self.lora_a_initializer(self.lora_a.shape, self.lora_a.dtype)
)
self.lora_b.assign(ops.zeros_like(self.lora_b))

def call(self, inputs):
original_output = self.inner_dense(inputs)
# Compute the low-rank intermediate output.
update = ops.matmul(inputs, self.lora_a)
# Use the matching dense computation for a Dense or EinsumDense.
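# `validate_einsum_equation` guarantees the kernel's first axis is the
# contracted one, so `lora_b` (shape `(rank,) + kernel_shape[1:]`) can reuse
# the same einsum equation, with `rank` standing in for the input feature axis.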
if isinstance(self.inner_dense, keras.layers.Dense):
update = ops.matmul(update, self.lora_b)
else:
update = ops.einsum(self.inner_dense.equation, update, self.lora_b)
# Scale and sum the lora update with the original frozen output.
return original_output + update * ops.cast(self.scale, update.dtype)

@classmethod
def from_config(cls, config):
config["inner_dense"] = keras.layers.deserialize(config["inner_dense"])
return super().from_config(config)

def get_config(self):
config = super().get_config()
config.update(
{
"inner_dense": keras.layers.serialize(self.inner_dense),
"rank": self.rank,
"alpha": self.alpha,
"lora_a_initializer": keras.initializers.serialize(
self.lora_a_initializer
),
"freeze_kernel": self.freeze_kernel,
"freeze_bias": self.freeze_bias,
}
)
return config
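For context, a minimal end-to-end sketch of the workflow described in the docstring above: wrap a frozen dense projection, train only the low-rank weights, then merge. It assumes multi-backend Keras 3 is installed (so `import keras` resolves to it) and uses toy data; only `keras_nlp.layers.LoraDense` and `merge_weights()` come from this commit, everything else is illustrative.

```python
import keras
import numpy as np

import keras_nlp

feature_size, rank = 16, 4

# Projection to adapt; LoraDense freezes its kernel and bias by default.
inner_dense = keras.layers.Dense(feature_size)
lora_dense = keras_nlp.layers.LoraDense(inner_dense, rank=rank, alpha=8.0)

# Only lora_a and lora_b are trainable in this model.
inputs = keras.Input(shape=(feature_size,))
model = keras.Model(inputs, lora_dense(inputs))
model.compile(loss="mse", optimizer="adam")

# Fit on toy data, then fold the low-rank update back into the frozen kernel.
x = np.random.uniform(size=(32, feature_size))
y = np.random.uniform(size=(32, feature_size))
model.fit(x, y, epochs=1, verbose=0)
lora_dense.merge_weights()
```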