From 51b0e7c84e8925dc690bb6da8517288be2acc5c0 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 16 Oct 2024 10:31:05 +0200 Subject: [PATCH] 2x speed up for qwen2-vl (#89) --- mlx_vlm/models/qwen2_vl/qwen2_vl.py | 6 +++--- mlx_vlm/models/qwen2_vl/vision.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py index 98f8b20..fe4ce6c 100644 --- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py +++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py @@ -77,12 +77,12 @@ def _merge_input_ids_with_image_features( # Positions of tokens in input_ids, assuming batch size is 1 image_positions = input_ids == image_token_index - inputs_embeds = np.array(inputs_embeds.astype(mx.float32)) - inputs_embeds[image_positions] = image_features + image_indices = np.where(image_positions)[1].tolist() + inputs_embeds[:, image_indices, :] = image_features.astype(mx.float32) # TODO: Add video features - return mx.array(inputs_embeds) + return inputs_embeds def __call__( self, diff --git a/mlx_vlm/models/qwen2_vl/vision.py b/mlx_vlm/models/qwen2_vl/vision.py index 7b78447..29b4799 100644 --- a/mlx_vlm/models/qwen2_vl/vision.py +++ b/mlx_vlm/models/qwen2_vl/vision.py @@ -336,7 +336,7 @@ def __call__( # Concatenate the cu_seqlens for all items in the batch cu_seqlens = mx.concatenate(cu_seqlens) - cu_seqlens = mx.cumsum(cu_seqlens.astype(mx.int32)) + cu_seqlens = mx.cumsum(cu_seqlens.astype(mx.int32), axis=0) cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0) encoder_states = (hidden_states,) if output_hidden_states else None