diff --git a/gemma/layers.py b/gemma/layers.py
index f6179cb..189dad0 100644
--- a/gemma/layers.py
+++ b/gemma/layers.py
@@ -22,10 +22,12 @@ class Einsum(nn.Module):
   """Einsum is a convenience module for parameterized tensor multiplication."""
 
   shape: tuple[int, ...]
+  # Name under which the weight is registered with `self.param`.
+  weight_name: str = 'w'
 
   @nn.compact
   def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
-    w = self.param('w', nn.initializers.normal(), self.shape)
+    w = self.param(self.weight_name, nn.initializers.normal(), self.shape)
     return jnp.einsum(eqn, x, w)
 
 
diff --git a/gemma/modules.py b/gemma/modules.py
index 7a1fa67..9f8b399 100644
--- a/gemma/modules.py
+++ b/gemma/modules.py
@@ -248,32 +248,33 @@ def __call__(self, x):
     # Some versions use an alternate parameter ordering that
     # transposes hidden_dim and features.
     if self.transpose_gating_einsum:
-      w_gating = self.param(
-          'gating_einsum',
-          nn.initializers.normal(),
-          ((2, self.hidden_dim, self.features)),
+      # Weight is stored as (2, hidden_dim, features); contract over F directly.
+      eq = '...F,NHF->...NH'
+      gating = layers.Einsum(
+          shape=(2, self.hidden_dim, self.features),
+          weight_name='gating_einsum',
       )
-      w_gating = w_gating.transpose((0, 2, 1))
     else:
-      w_gating = self.param(
-          'gating_einsum',
-          nn.initializers.normal(),
-          ((2, self.features, self.hidden_dim)),
+      eq = '...F,NFH->...NH'
+      gating = layers.Einsum(
+          shape=(2, self.features, self.hidden_dim),
+          weight_name='gating_einsum',
       )
-    ff_gate = jnp.dot(x, w_gating[0])
-    gate_value = nn.gelu(ff_gate)
 
-    # Up projection
-    ff1 = jnp.dot(x, w_gating[1])
-    activations = gate_value * ff1
+    # Use the same scope for backwards compatibility with existing checkpoints
+    # created before using `layers.Einsum` here.
+    nn.share_scope(self, gating)
+    gate = gating(eq, x)
+    # Index 0 along N is the gate projection, index 1 is the up projection.
+    activations = nn.gelu(gate[..., 0, :]) * gate[..., 1, :]
 
     # Down projection
-    w_linear = self.param(
-        'linear',
-        nn.initializers.zeros_init(),
-        (self.hidden_dim, self.features),
+    linear = layers.Einsum(
+        shape=(self.hidden_dim, self.features),
+        weight_name='linear',
     )
-    outputs = jnp.dot(activations, w_linear)
+    nn.share_scope(self, linear)
+    outputs = linear('...H,HF->...F', activations)
 
     return outputs
 