Commit

formatting
Blaizzy committed Sep 29, 2024
1 parent e2e300d commit 64927a0
Showing 7 changed files with 62 additions and 21 deletions.
7 changes: 6 additions & 1 deletion mlx_vlm/models/idefics2/idefics2.py
@@ -251,7 +251,12 @@ def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_id
         return mx.concatenate(final_embeddings, axis=1)
 
     def __call__(
-        self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None, **kwargs
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
     ):
         input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
         logits = self.language_model(
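
The reflowed signature above (and the matching changes in the files below) follows the multi-line layout an autoformatter such as Black produces for signatures that exceed the line-length limit. Assuming an automated pass was used (the commit message only says "formatting", so the exact tool is an assumption), a change like this is typically reproduced with:

    # assumes Black is the formatter; the commit does not name a tool
    pip install black
    black mlx_vlm/
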
7 changes: 6 additions & 1 deletion mlx_vlm/models/llava/llava.py
@@ -132,7 +132,12 @@ def _merge_input_ids_with_image_features(
         return mx.concatenate(final_embeddings, axis=1)
 
     def __call__(
-        self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None, **kwargs
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
     ):
         input_embddings = self.get_input_embeddings(input_ids, pixel_values)
         logits = self.language_model(
7 changes: 6 additions & 1 deletion mlx_vlm/models/llava_next/llava_next.py
@@ -136,7 +136,12 @@ def _merge_input_ids_with_image_features(
         return mx.concatenate(final_embeddings, axis=1)
 
     def __call__(
-        self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None, **kwargs
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
     ):
 
         input_embddings = self.get_input_embeddings(input_ids, pixel_values)
7 changes: 6 additions & 1 deletion mlx_vlm/models/multi_modality/multi_modality.py
@@ -360,7 +360,12 @@ def _merge_input_ids_with_image_features(
         return mx.concatenate(final_embeddings, axis=1)
 
     def __call__(
-        self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None, **kwargs
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
     ):
 
         input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
1 change: 1 addition & 0 deletions mlx_vlm/models/pixtral/language.py
@@ -4,6 +4,7 @@
 
 import mlx.core as mx
 import mlx.nn as nn
+
 from ..base import KVCache, create_attention_mask
 
 
11 changes: 6 additions & 5 deletions mlx_vlm/models/pixtral/pixtral.py
@@ -82,7 +82,6 @@ def get_input_embeddings(
         # Select the hidden states from the desired layer
         selected_image_feature = hidden_states[self.vision_feature_layer]
 
-
         if self.vision_feature_select_strategy == "default":
             selected_image_feature = selected_image_feature[:, 1:]
         elif self.vision_feature_select_strategy == "full":
@@ -93,7 +92,6 @@ def get_input_embeddings(
                 f"{self.vision_feature_select_strategy}"
             )
 
-
         # Pass image features through the multi-modal projector
         image_features = self.multi_modal_projector(selected_image_feature)
 
@@ -128,7 +126,12 @@ def _merge_input_ids_with_image_features(
         return mx.concatenate(final_embeddings, axis=1)
 
     def __call__(
-        self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None, **kwargs
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
     ):
         input_embddings = self.get_input_embeddings(input_ids, pixel_values)
         logits = self.language_model(
@@ -188,5 +191,3 @@ def transform_key(key):
             return key
 
         return {transform_key(k): v for k, v in weights.items()}
-
-
43 changes: 31 additions & 12 deletions mlx_vlm/models/pixtral/vision.py
@@ -75,7 +75,9 @@ def generate_block_attention_mask(patch_embeds_list, tensor):
         start, end = int(start), int(end)  # Convert to integers for indexing
         causal_mask[start:end, start:end] = 0
 
-    causal_mask = mx.broadcast_to(causal_mask[None, None, :, :], (tensor.shape[0], 1, seq_len, seq_len))
+    causal_mask = mx.broadcast_to(
+        causal_mask[None, None, :, :], (tensor.shape[0], 1, seq_len, seq_len)
+    )
     return causal_mask
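
In the hunk above, generate_block_attention_mask zeroes square blocks on the diagonal so that patches attend only to patches from the same image, then broadcasts the mask to (batch, 1, seq_len, seq_len). A minimal standalone sketch of the block-diagonal idea, using a hypothetical helper and toy patch counts rather than the repository's function:

    import mlx.core as mx

    def block_mask_sketch(patch_counts, seq_len, neg=-1e9):
        # Start fully masked, then open one square block per image's patches.
        mask = mx.full((seq_len, seq_len), neg, dtype=mx.float32)
        start = 0
        for count in patch_counts:
            end = start + count
            mask[start:end, start:end] = 0  # patches of the same image see each other
            start = end
        return mask

    # Two images contributing 4 and 2 patches -> two zero blocks on the diagonal.
    print(block_mask_sketch([4, 2], seq_len=6))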


@@ -84,13 +86,15 @@ def rotate_half(x):
     x2 = x[..., x.shape[-1] // 2 :]
     return mx.concatenate((-x2, x1), axis=-1)
 
+
 def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     cos = mx.expand_dims(cos, axis=unsqueeze_dim)
     sin = mx.expand_dims(sin, axis=unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
+
 class Attention(nn.Module):
     def __init__(
         self,
@@ -140,7 +144,6 @@ def __call__(self, queries, keys, values, position_embeddings, mask=None):
         keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
 
-
         cos, sin = position_embeddings
         queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, unsqueeze_dim=0)
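
The call above applies rotary position embeddings with unsqueeze_dim=0, so the per-position cos/sin tables broadcast across both the batch and head axes of the (B, num_heads, S, head_dim) queries and keys. A toy shape check that reuses the two helpers defined earlier in this file, with random angles standing in for the real rotary angle table:

    import mlx.core as mx

    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return mx.concatenate((-x2, x1), axis=-1)

    def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
        cos = mx.expand_dims(cos, axis=unsqueeze_dim)
        sin = mx.expand_dims(sin, axis=unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

    # Toy shapes: batch=1, heads=2, seq=3 patches, head_dim=4.
    q = mx.random.normal((1, 2, 3, 4))
    k = mx.random.normal((1, 2, 3, 4))
    angles = mx.random.uniform(shape=(3, 4))  # stand-in for the real rotary angle table
    q_rot, k_rot = apply_rotary_pos_emb(q, k, mx.cos(angles), mx.sin(angles), unsqueeze_dim=0)
    print(q_rot.shape, k_rot.shape)  # (1, 2, 3, 4) (1, 2, 3, 4)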

@@ -181,7 +184,12 @@ def __init__(self, config: VisionConfig):
         self.feed_forward = MLP(config)
         self.ffn_norm = nn.RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
 
-    def __call__(self, x: mx.array, position_embeddings: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+    def __call__(
+        self,
+        x: mx.array,
+        position_embeddings: mx.array,
+        mask: Optional[mx.array] = None,
+    ) -> mx.array:
         y = self.attention_norm(x)
         y = self.attention(y, y, y, position_embeddings, mask)
         x = x + y
@@ -196,22 +204,27 @@ def __init__(self, config: VisionConfig):
         self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
 
 
-class PixtralRotaryEmbedding():
+class PixtralRotaryEmbedding:
     def __init__(self, config):
         self.dim = config.head_dim
         self.base = config.rope_theta
         max_patches_per_side = config.image_size // config.patch_size
-        freqs = 1.0 / (self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim))
+        freqs = 1.0 / (
+            self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
+        )
 
         h = mx.arange(max_patches_per_side)
         w = mx.arange(max_patches_per_side)
 
         freqs_h = mx.outer(h, freqs[::2]).astype(mx.float32)
         freqs_w = mx.outer(w, freqs[1::2]).astype(mx.float32)
-        inv_freq = mx.concatenate([
-            mx.tile(freqs_h[:, None, :], (1, max_patches_per_side, 1)),
-            mx.tile(freqs_w[None, :, :], (max_patches_per_side, 1, 1))
-        ], axis=-1).reshape(-1, self.dim // 2)
+        inv_freq = mx.concatenate(
+            [
+                mx.tile(freqs_h[:, None, :], (1, max_patches_per_side, 1)),
+                mx.tile(freqs_w[None, :, :], (max_patches_per_side, 1, 1)),
+            ],
+            axis=-1,
+        ).reshape(-1, self.dim // 2)
 
         self.inv_freq = mx.concatenate((inv_freq, inv_freq), axis=-1)
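
In the constructor above, the even-indexed frequencies encode the patch row and the odd-indexed ones the patch column; the two grids are tiled over a (side x side) mesh and flattened to one frequency row per patch. A small shape walk-through with toy stand-ins for the config values (head_dim=8, 3 patches per side):

    import mlx.core as mx

    dim, base, max_patches_per_side = 8, 10000.0, 3  # illustrative values, not the real config

    freqs = 1.0 / (base ** (mx.arange(0, dim, 2).astype(mx.float32) / dim))  # (4,)
    h = mx.arange(max_patches_per_side)
    w = mx.arange(max_patches_per_side)

    freqs_h = mx.outer(h, freqs[::2]).astype(mx.float32)   # (3, 2): row angles
    freqs_w = mx.outer(w, freqs[1::2]).astype(mx.float32)  # (3, 2): column angles

    inv_freq = mx.concatenate(
        [
            mx.tile(freqs_h[:, None, :], (1, max_patches_per_side, 1)),  # (3, 3, 2)
            mx.tile(freqs_w[None, :, :], (max_patches_per_side, 1, 1)),  # (3, 3, 2)
        ],
        axis=-1,
    ).reshape(-1, dim // 2)  # (9, 4): one frequency row per patch

    print(mx.concatenate((inv_freq, inv_freq), axis=-1).shape)  # (9, 8) == (patches, head_dim)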

@@ -222,6 +235,7 @@ def __call__(self, x, position_ids):
         sin = mx.sin(emb)
         return cos.astype(x.dtype), sin.astype(x.dtype)
 
+
 class ClipVisionModel(nn.Module):
     def __init__(self, config: VisionConfig):
         super().__init__()
@@ -245,12 +259,15 @@ def __call__(
         B, H, W, C = x.shape
         patch_embeds_list = [self.patch_conv(img[None, :]) for img in x]
 
-        patch_embeds = mx.concatenate([p.reshape(B, -1, p.shape[-1]) for p in patch_embeds_list], axis=1)
+        patch_embeds = mx.concatenate(
+            [p.reshape(B, -1, p.shape[-1]) for p in patch_embeds_list], axis=1
+        )
 
         patch_embeds = self.ln_pre(patch_embeds)
 
         position_ids = position_ids_in_meshgrid(
-            patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
+            patch_embeds_list,
+            max_width=self.config.image_size // self.config.patch_size,
         )
 
         position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
@@ -262,7 +279,9 @@ def __call__(
         encoder_states = (patch_embeds,) if output_hidden_states else None
 
         for l in self.transformer.layers:
-            patch_embeds = l(patch_embeds, mask=mask, position_embeddings=position_embedding)
+            patch_embeds = l(
+                patch_embeds, mask=mask, position_embeddings=position_embedding
+            )
             if output_hidden_states:
                 encoder_states = encoder_states + (patch_embeds,)
 
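
In the last two hunks, each image goes through the patch conv separately and the resulting grids are flattened and joined into one patch sequence before ln_pre, the positional embedding, and the masked transformer layers. A minimal sketch of just that flattening step, with made-up grid sizes and hidden width:

    import mlx.core as mx

    # Hypothetical conv-stem outputs: a 2x2 grid and a 1x3 grid of patches, hidden size 16.
    patch_embeds_list = [mx.random.normal((1, 2, 2, 16)), mx.random.normal((1, 1, 3, 16))]

    # Flatten each grid to (1, num_patches, hidden) and join along the sequence axis,
    # mirroring the concatenation in the diff above.
    patch_embeds = mx.concatenate(
        [p.reshape(1, -1, p.shape[-1]) for p in patch_embeds_list], axis=1
    )
    print(patch_embeds.shape)  # (1, 7, 16): 4 + 3 patches in one sequence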
