From 64927a0dbf883a5a444394ca1d3d37ee80d90693 Mon Sep 17 00:00:00 2001
From: Prince Canuma
Date: Sun, 29 Sep 2024 18:00:55 +0200
Subject: [PATCH] formatting

---
 mlx_vlm/models/idefics2/idefics2.py           |  7 ++-
 mlx_vlm/models/llava/llava.py                 |  7 ++-
 mlx_vlm/models/llava_next/llava_next.py       |  7 ++-
 .../models/multi_modality/multi_modality.py   |  7 ++-
 mlx_vlm/models/pixtral/language.py            |  1 +
 mlx_vlm/models/pixtral/pixtral.py             | 11 ++---
 mlx_vlm/models/pixtral/vision.py              | 43 +++++++++++++------
 7 files changed, 62 insertions(+), 21 deletions(-)

diff --git a/mlx_vlm/models/idefics2/idefics2.py b/mlx_vlm/models/idefics2/idefics2.py
index 53bb165..1c78365 100644
--- a/mlx_vlm/models/idefics2/idefics2.py
+++ b/mlx_vlm/models/idefics2/idefics2.py
@@ -251,7 +251,12 @@ def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_id
         return mx.concatenate(final_embeddings, axis=1)
 
     def __call__(
-        self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None, **kwargs
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
     ):
         input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
         logits = self.language_model(
diff --git a/mlx_vlm/models/llava/llava.py b/mlx_vlm/models/llava/llava.py
index c8ae1a4..39aae4a 100644
--- a/mlx_vlm/models/llava/llava.py
+++ b/mlx_vlm/models/llava/llava.py
@@ -132,7 +132,12 @@ def _merge_input_ids_with_image_features(
         return mx.concatenate(final_embeddings, axis=1)
 
     def __call__(
-        self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None, **kwargs
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
     ):
         input_embddings = self.get_input_embeddings(input_ids, pixel_values)
         logits = self.language_model(
diff --git a/mlx_vlm/models/llava_next/llava_next.py b/mlx_vlm/models/llava_next/llava_next.py
index d1e26db..878d7ca 100644
--- a/mlx_vlm/models/llava_next/llava_next.py
+++ b/mlx_vlm/models/llava_next/llava_next.py
@@ -136,7 +136,12 @@ def _merge_input_ids_with_image_features(
         return mx.concatenate(final_embeddings, axis=1)
 
     def __call__(
-        self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None, **kwargs
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
     ):
         input_embddings = self.get_input_embeddings(input_ids, pixel_values)
 
diff --git a/mlx_vlm/models/multi_modality/multi_modality.py b/mlx_vlm/models/multi_modality/multi_modality.py
index 062e90b..52a0bc9 100644
--- a/mlx_vlm/models/multi_modality/multi_modality.py
+++ b/mlx_vlm/models/multi_modality/multi_modality.py
@@ -360,7 +360,12 @@ def _merge_input_ids_with_image_features(
         return mx.concatenate(final_embeddings, axis=1)
 
     def __call__(
-        self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None, **kwargs
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
     ):
         input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
 
diff --git a/mlx_vlm/models/pixtral/language.py b/mlx_vlm/models/pixtral/language.py
index e517be2..da8482c 100644
--- a/mlx_vlm/models/pixtral/language.py
+++ b/mlx_vlm/models/pixtral/language.py
@@ -4,6 +4,7 @@
 
 import mlx.core as mx
 import mlx.nn as nn
+
 from ..base import KVCache, create_attention_mask
 
 
diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py
index 836a3ca..b49397b 100644
--- a/mlx_vlm/models/pixtral/pixtral.py
+++ b/mlx_vlm/models/pixtral/pixtral.py
@@ -82,7 +82,6 @@ def get_input_embeddings(
         # Select the hidden states from the desired layer
         selected_image_feature = hidden_states[self.vision_feature_layer]
 
-
         if self.vision_feature_select_strategy == "default":
             selected_image_feature = selected_image_feature[:, 1:]
         elif self.vision_feature_select_strategy == "full":
@@ -93,7 +92,6 @@ def get_input_embeddings(
                 f"{self.vision_feature_select_strategy}"
             )
 
-
         # Pass image features through the multi-modal projector
         image_features = self.multi_modal_projector(selected_image_feature)
 
@@ -128,7 +126,12 @@ def _merge_input_ids_with_image_features(
         return mx.concatenate(final_embeddings, axis=1)
 
     def __call__(
-        self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None, **kwargs
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
     ):
         input_embddings = self.get_input_embeddings(input_ids, pixel_values)
         logits = self.language_model(
@@ -188,5 +191,3 @@ def transform_key(key):
             return key
 
         return {transform_key(k): v for k, v in weights.items()}
-
-
diff --git a/mlx_vlm/models/pixtral/vision.py b/mlx_vlm/models/pixtral/vision.py
index df3adcf..8243b71 100644
--- a/mlx_vlm/models/pixtral/vision.py
+++ b/mlx_vlm/models/pixtral/vision.py
@@ -75,7 +75,9 @@ def generate_block_attention_mask(patch_embeds_list, tensor):
         start, end = int(start), int(end)  # Convert to integers for indexing
         causal_mask[start:end, start:end] = 0
 
-    causal_mask = mx.broadcast_to(causal_mask[None, None, :, :], (tensor.shape[0], 1, seq_len, seq_len))
+    causal_mask = mx.broadcast_to(
+        causal_mask[None, None, :, :], (tensor.shape[0], 1, seq_len, seq_len)
+    )
     return causal_mask
 
 
@@ -84,6 +86,7 @@ def rotate_half(x):
     x2 = x[..., x.shape[-1] // 2 :]
     return mx.concatenate((-x2, x1), axis=-1)
 
+
 def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     cos = mx.expand_dims(cos, axis=unsqueeze_dim)
     sin = mx.expand_dims(sin, axis=unsqueeze_dim)
@@ -91,6 +94,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
+
 class Attention(nn.Module):
     def __init__(
         self,
@@ -140,7 +144,6 @@ def __call__(self, queries, keys, values, position_embeddings, mask=None):
         keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
 
-
         cos, sin = position_embeddings
         queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, unsqueeze_dim=0)
 
@@ -181,7 +184,12 @@ def __init__(self, config: VisionConfig):
         self.feed_forward = MLP(config)
         self.ffn_norm = nn.RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
 
-    def __call__(self, x: mx.array, position_embeddings: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+    def __call__(
+        self,
+        x: mx.array,
+        position_embeddings: mx.array,
+        mask: Optional[mx.array] = None,
+    ) -> mx.array:
         y = self.attention_norm(x)
         y = self.attention(y, y, y, position_embeddings, mask)
         x = x + y
@@ -196,22 +204,27 @@ def __init__(self, config: VisionConfig):
         self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
 
 
-class PixtralRotaryEmbedding():
+class PixtralRotaryEmbedding:
     def __init__(self, config):
         self.dim = config.head_dim
         self.base = config.rope_theta
         max_patches_per_side = config.image_size // config.patch_size
-        freqs = 1.0 / (self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim))
+        freqs = 1.0 / (
+            self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
+        )
 
         h = mx.arange(max_patches_per_side)
         w = mx.arange(max_patches_per_side)
 
         freqs_h = mx.outer(h, freqs[::2]).astype(mx.float32)
         freqs_w = mx.outer(w, freqs[1::2]).astype(mx.float32)
-        inv_freq = mx.concatenate([
-            mx.tile(freqs_h[:, None, :], (1, max_patches_per_side, 1)),
-            mx.tile(freqs_w[None, :, :], (max_patches_per_side, 1, 1))
-        ], axis=-1).reshape(-1, self.dim // 2)
+        inv_freq = mx.concatenate(
+            [
+                mx.tile(freqs_h[:, None, :], (1, max_patches_per_side, 1)),
+                mx.tile(freqs_w[None, :, :], (max_patches_per_side, 1, 1)),
+            ],
+            axis=-1,
+        ).reshape(-1, self.dim // 2)
 
         self.inv_freq = mx.concatenate((inv_freq, inv_freq), axis=-1)
 
@@ -222,6 +235,7 @@ def __call__(self, x, position_ids):
         sin = mx.sin(emb)
         return cos.astype(x.dtype), sin.astype(x.dtype)
 
+
 class ClipVisionModel(nn.Module):
     def __init__(self, config: VisionConfig):
         super().__init__()
@@ -245,12 +259,15 @@ def __call__(
         B, H, W, C = x.shape
         patch_embeds_list = [self.patch_conv(img[None, :]) for img in x]
 
-        patch_embeds = mx.concatenate([p.reshape(B, -1, p.shape[-1]) for p in patch_embeds_list], axis=1)
+        patch_embeds = mx.concatenate(
+            [p.reshape(B, -1, p.shape[-1]) for p in patch_embeds_list], axis=1
+        )
 
         patch_embeds = self.ln_pre(patch_embeds)
 
         position_ids = position_ids_in_meshgrid(
-            patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
+            patch_embeds_list,
+            max_width=self.config.image_size // self.config.patch_size,
         )
         position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
 
@@ -262,7 +279,9 @@ def __call__(
 
         encoder_states = (patch_embeds,) if output_hidden_states else None
         for l in self.transformer.layers:
-            patch_embeds = l(patch_embeds, mask=mask, position_embeddings=position_embedding)
+            patch_embeds = l(
+                patch_embeds, mask=mask, position_embeddings=position_embedding
+            )
             if output_hidden_states:
                 encoder_states = encoder_states + (patch_embeds,)
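
Reviewer note, not part of the patch: the hunks above only re-wrap lines, but the re-wrapped mx.broadcast_to call in generate_block_attention_mask is easier to sanity-check with a tiny standalone sketch. The patch counts and the -1e9 fill value below are hypothetical and only meant to show the block-diagonal structure the function builds before broadcasting to (batch, 1, seq_len, seq_len); only mlx is assumed to be installed.

# Standalone sketch; patch_counts and the -1e9 fill value are illustrative.
import mlx.core as mx

patch_counts = [4, 2, 3]  # patches contributed by three hypothetical images
seq_len = sum(patch_counts)

mask = mx.full((seq_len, seq_len), -1e9)
start = 0
for count in patch_counts:
    end = start + count
    mask[start:end, start:end] = 0  # patches of the same image may attend to each other
    start = end

# broadcast to (batch, 1, seq_len, seq_len), as the patched call does
mask = mx.broadcast_to(mask[None, None, :, :], (1, 1, seq_len, seq_len))
print(mask.shape)  # (1, 1, 9, 9)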
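
The re-indented PixtralRotaryEmbedding.__init__ is the densest hunk in this patch, so here is the same frequency-table construction re-run with tiny hypothetical sizes (head_dim=4, a 3x3 patch grid) to make the shapes visible. The position_ids values and the comment about flattened grid coordinates are assumptions about how the table is indexed elsewhere, not something this patch shows.

# Standalone sketch; head_dim, max_side, base and position_ids are hypothetical.
import mlx.core as mx

head_dim, max_side, base = 4, 3, 10000.0
freqs = 1.0 / (base ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))

h = mx.arange(max_side)
w = mx.arange(max_side)
freqs_h = mx.outer(h, freqs[::2]).astype(mx.float32)   # height frequencies
freqs_w = mx.outer(w, freqs[1::2]).astype(mx.float32)  # width frequencies

# interleave height/width frequencies into one row per (row, col) grid position
inv_freq = mx.concatenate(
    [
        mx.tile(freqs_h[:, None, :], (1, max_side, 1)),
        mx.tile(freqs_w[None, :, :], (max_side, 1, 1)),
    ],
    axis=-1,
).reshape(-1, head_dim // 2)
table = mx.concatenate((inv_freq, inv_freq), axis=-1)  # (max_side * max_side, head_dim)

# Assumed: position ids are flattened grid coordinates (row * max_side + col),
# so [0, 1, 3, 4] would be a hypothetical 2x2 crop in the top-left of the grid.
position_ids = mx.array([0, 1, 3, 4])
emb = mx.take(table, position_ids, axis=0)
cos, sin = mx.cos(emb), mx.sin(emb)
print(cos.shape, sin.shape)  # (4, 4) (4, 4)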