From 9421149304d73b62523962227855a91c4e4c0256 Mon Sep 17 00:00:00 2001
From: Prince Canuma
Date: Thu, 4 Jul 2024 16:07:51 +0200
Subject: [PATCH] Add Dolphin-vision and bunny (#50)

* add Dolphin-vision

* add bunny

* bump version
---
 .../{nanoLlava => llava_bunny}/__init__.py |  2 +-
 .../{nanoLlava => llava_bunny}/language.py | 12 +++-
 .../llava_bunny.py}                        |  0
 .../{nanoLlava => llava_bunny}/vision.py   |  0
 mlx_vlm/prompt_utils.py                    |  4 +-
 mlx_vlm/tests/test_models.py               | 12 ++--
 mlx_vlm/utils.py                           | 70 +++++++++----------
 mlx_vlm/version.py                         |  2 +-
 8 files changed, 51 insertions(+), 51 deletions(-)
 rename mlx_vlm/models/{nanoLlava => llava_bunny}/__init__.py (81%)
 rename mlx_vlm/models/{nanoLlava => llava_bunny}/language.py (96%)
 rename mlx_vlm/models/{nanoLlava/nanoLlava.py => llava_bunny/llava_bunny.py} (100%)
 rename mlx_vlm/models/{nanoLlava => llava_bunny}/vision.py (100%)

diff --git a/mlx_vlm/models/nanoLlava/__init__.py b/mlx_vlm/models/llava_bunny/__init__.py
similarity index 81%
rename from mlx_vlm/models/nanoLlava/__init__.py
rename to mlx_vlm/models/llava_bunny/__init__.py
index 55485d8..02abf58 100644
--- a/mlx_vlm/models/nanoLlava/__init__.py
+++ b/mlx_vlm/models/llava_bunny/__init__.py
@@ -1,4 +1,4 @@
-from .nanoLlava import (
+from .llava_bunny import (
     ImageProcessor,
     LanguageModel,
     Model,
diff --git a/mlx_vlm/models/nanoLlava/language.py b/mlx_vlm/models/llava_bunny/language.py
similarity index 96%
rename from mlx_vlm/models/nanoLlava/language.py
rename to mlx_vlm/models/llava_bunny/language.py
index 1b931f9..b82c0eb 100644
--- a/mlx_vlm/models/nanoLlava/language.py
+++ b/mlx_vlm/models/llava_bunny/language.py
@@ -15,6 +15,7 @@ class TextConfig:
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
+    attention_bias: bool = True
     num_key_value_heads: int = None
     rope_theta: float = 1000000
     rope_traditional: bool = False
@@ -55,9 +56,14 @@ def __init__(self, args: TextConfig):
         head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5

-        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
-        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
-        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
+        if hasattr(args, "attention_bias"):
+            attention_bias = args.attention_bias
+        else:
+            attention_bias = False
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

         rope_scale = (
diff --git a/mlx_vlm/models/nanoLlava/nanoLlava.py b/mlx_vlm/models/llava_bunny/llava_bunny.py
similarity index 100%
rename from mlx_vlm/models/nanoLlava/nanoLlava.py
rename to mlx_vlm/models/llava_bunny/llava_bunny.py
diff --git a/mlx_vlm/models/nanoLlava/vision.py b/mlx_vlm/models/llava_bunny/vision.py
similarity index 100%
rename from mlx_vlm/models/nanoLlava/vision.py
rename to mlx_vlm/models/llava_bunny/vision.py
diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py
index 3d3f08b..3374689 100644
--- a/mlx_vlm/prompt_utils.py
+++ b/mlx_vlm/prompt_utils.py
@@ -3,7 +3,7 @@ def get_message_json(model_name, prompt):
     Get the appropriate JSON message based on the specified model.

     Args:
-        model_name (str): The model for which to generate the message. Options: 'Idefics 2', 'nanollava', 'llava'.
+        model_name (str): The model for which to generate the message.
         prompt (str): The text prompt to be included in the message.
         *args: Additional positional arguments (unused).
         **kwargs: Additional keyword arguments (unused).
@@ -16,7 +16,7 @@ def get_message_json(model_name, prompt):
             "role": "user",
             "content": [{"type": "image"}, {"type": "text", "text": prompt}],
         }
-    elif model_name.lower() in ["llava-qwen2", "llava", "llava_next"]:
+    elif model_name.lower() in ["llava-qwen2", "llava", "llava_next", "bunny-llama"]:
         message = {"role": "user", "content": f"\n{prompt}"}
     elif model_name.lower() == "phi3_v":
         message = {"role": "user", "content": f"<|image_1|>\n{prompt}"}
diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py
index 58e0c7e..856cf26 100644
--- a/mlx_vlm/tests/test_models.py
+++ b/mlx_vlm/tests/test_models.py
@@ -70,10 +70,10 @@ def vision_test_runner(
             hidden_states[vision_feature_layer][-1][-1].shape, (vision_hidden_size,)
         )

-    def test_nanoLlava(self):
-        from mlx_vlm.models import nanoLlava
+    def test_llava_bunny(self):
+        from mlx_vlm.models import llava_bunny

-        text_config = nanoLlava.TextConfig(
+        text_config = llava_bunny.TextConfig(
             model_type="qwen2",
             hidden_size=4096,
             num_hidden_layers=32,
@@ -87,7 +87,7 @@ def test_nanoLlava(self):
             rope_scaling=None,
         )

-        vision_config = nanoLlava.VisionConfig(
+        vision_config = llava_bunny.VisionConfig(
             model_type="siglip_vision_model",
             num_hidden_layers=27,
             hidden_size=1152,
@@ -101,7 +101,7 @@ def test_nanoLlava(self):
             layer_norm_eps=1e-6,
         )

-        args = nanoLlava.ModelConfig(
+        args = llava_bunny.ModelConfig(
             text_config=text_config,
             vision_config=vision_config,
             model_type="llava-qwen2",
@@ -118,7 +118,7 @@ def test_nanoLlava(self):
             vocab_size=151936,
         )

-        model = nanoLlava.Model(args)
+        model = llava_bunny.Model(args)

         self.language_test_runner(
             model.language_model,
diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py
index dd6d80a..441ee37 100644
--- a/mlx_vlm/utils.py
+++ b/mlx_vlm/utils.py
@@ -29,9 +29,7 @@ from .tokenizer_utils import load_tokenizer

 # Constants
-MODEL_REMAPPING = {
-    "llava-qwen2": "nanoLlava",
-}
+MODEL_REMAPPING = {"llava-qwen2": "llava_bunny", "bunny-llama": "llava_bunny"}

 MAX_FILE_SIZE_GB = 5

@@ -150,12 +148,15 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:

     model_class, model_type = get_model_and_args(config=config)

-    if model_type == "nanoLlava":
+    if model_type == "llava_bunny":
         vision_config = AutoConfig.from_pretrained(config["mm_vision_tower"])
         text_config = AutoConfig.from_pretrained(config["language_model"])
         vision_config = vision_config.to_dict()
         text_config = text_config.to_dict()
-        config["vision_config"] = vision_config["vision_config"]
+        config["vision_config"] = {
+            **vision_config["vision_config"],
+            **config.get("vision_config", {}),
+        }
         config["text_config"] = text_config
     if model_type == "idefics2":
         config = AutoConfig.from_pretrained(model_path).to_dict()
@@ -194,7 +195,6 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
         weights = model_class.LanguageModel(model_config.text_config).sanitize(
             weights=weights
         )
-
     if (quantization := config.get("quantization", None)) is not None:
         # Handle legacy models which may not have everything quantized
         class_predicate = (
@@ -502,10 +502,8 @@ def quantize_model(
     divisor = 64
     if any(vision_intermediate_size % size != 0 for size in [64, 128]):
         for name, module in model.named_modules():
-            if (
-                isinstance(module, nn.Linear)
-                or isinstance(module, nn.Embedding)
-                and ("vision_model" in name or "vision_tower" in name)
+            if isinstance(module, nn.Linear) and (
+                "vision_model" in name or "vision_tower" in name
             ):
                 out_features, in_features = module.weight.shape

@@ -520,34 +518,30 @@ def quantize_model(
                     if in_features % divisor != 0
                     else in_features
                 )
-                if (
-                    out_features == vision_intermediate_size
-                    or in_features == vision_intermediate_size
-                ):
-
-                    # If padding is needed, proceed
-                    if (
-                        new_out_features != out_features
-                        or new_in_features != in_features
-                    ):
-                        # Create new weight and bias tensors
-                        new_weight = mx.zeros((new_out_features, new_in_features))
-                        new_bias = mx.zeros((new_out_features))
-
-                        # Copy existing weights and biases to the new tensors
-                        new_weight[:out_features, :in_features] = module.weight
-                        module.weight = new_weight
-
-                        if hasattr(module, "bias"):
-                            new_bias[:out_features] = module.bias
-                            module.bias = new_bias
-
-        if "vision_config" in quantized_config:
-            quantized_config["vision_config"]["intermediate_size"] = (
-                ((vision_intermediate_size // divisor) + 1) * divisor
-                if vision_intermediate_size % divisor != 0
-                else vision_intermediate_size
-            )
+
+                # If padding is needed, proceed
+                if new_out_features != out_features or new_in_features != in_features:
+                    # Create new weight and bias tensors
+                    new_weight = mx.zeros((new_out_features, new_in_features))
+                    new_bias = mx.zeros((new_out_features))
+
+                    # Copy existing weights and biases to the new tensors
+                    new_weight[:out_features, :in_features] = module.weight
+                    module.weight = new_weight
+
+                    if hasattr(module, "bias"):
+                        new_bias[:out_features] = module.bias
+                        module.bias = new_bias
+
+        # Ensure vision_config exists in quantized_config
+        quantized_config.setdefault("vision_config", {})
+
+        # Update intermediate_size
+        quantized_config["vision_config"]["intermediate_size"] = (
+            ((vision_intermediate_size // divisor) + 1) * divisor
+            if vision_intermediate_size % divisor != 0
+            else vision_intermediate_size
+        )

     nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
diff --git a/mlx_vlm/version.py b/mlx_vlm/version.py
index 9b36b86..b2f0155 100644
--- a/mlx_vlm/version.py
+++ b/mlx_vlm/version.py
@@ -1 +1 @@
-__version__ = "0.0.10"
+__version__ = "0.0.11"
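
Reviewer note (not part of the commit): a minimal usage sketch of what this patch changes, based only on the code visible in the hunks above. `MODEL_REMAPPING` and `get_message_json` are the objects edited here; any behaviour beyond what the hunks show is an assumption.

# Sketch of the post-patch behaviour; assumes only what the diff above shows.
from mlx_vlm.prompt_utils import get_message_json
from mlx_vlm.utils import MODEL_REMAPPING

# Both the old "llava-qwen2" model_type and the new "bunny-llama" one are
# remapped to the renamed llava_bunny package when a checkpoint is loaded.
assert MODEL_REMAPPING["llava-qwen2"] == "llava_bunny"
assert MODEL_REMAPPING["bunny-llama"] == "llava_bunny"

# "bunny-llama" now takes the same llava-style chat-message branch as
# "llava", "llava-qwen2" and "llava_next".
message = get_message_json("bunny-llama", "Describe this image.")
print(message["role"], message["content"])

One design observation on the `language.py` hunk: since `TextConfig` now declares `attention_bias: bool = True`, the `hasattr(args, "attention_bias")` check will always succeed for a `TextConfig` instance; the `False` fallback only matters if a config object without that field is passed in.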