Skip to content

Commit

Permalink
Add mamba tests [#75]
Browse files Browse the repository at this point in the history
Add tests for mamba and mamba2. Disabled cache because the tests yield 'AssertionError: Attempt to
trace forbidden callable <function mark_static_address at
0x744a0d11d3a0>' otherwise.
  • Loading branch information
ddilbazTT committed Jan 20, 2025
1 parent 1971a82 commit 3653358
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/run-model-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
matrix:
build: [
{runs-on: n150, name: "run1", test_names: "stable_diffusion, Qwen, MobileNetV2, clip, flan_t5, mlpmixer, resnet, vilt, albert, codegen, glpn_kitti, mnist, resnet50, RMBG, unet_carvana, mgp-str-base, musicgen_small, segformer, torchvision, yolos"},
{runs-on: n150, name: "run2", test_names: "t5, whisper, autoencoder_conv, deit, gpt2, mobilenet_ssd, roberta, timm, xglm, autoencoder_linear, detr, beit, distilbert, hand_landmark, openpose, segment_anything, unet, yolov3, bert, dpr, hardnet, opt, speecht5_tts, unet_brain, yolov5, bloom, falcon, llama, perceiver_io, squeeze_bert, gpt_neo"},
{runs-on: n150, name: "run2", test_names: "t5, whisper, autoencoder_conv, deit, gpt2, mobilenet_ssd, roberta, timm, xglm, autoencoder_linear, detr, beit, distilbert, hand_landmark, openpose, segment_anything, unet, yolov3, bert, dpr, hardnet, opt, speecht5_tts, unet_brain, yolov5, bloom, falcon, llama, perceiver_io, squeeze_bert, gpt_neo, mamba"},
]
runs-on:
- ${{ matrix.build.runs-on }}
Expand Down
66 changes: 66 additions & 0 deletions tests/models/mamba/test_mamba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
# Reference: https://huggingface.co/state-spaces/mamba-2.8b-hf

from transformers import MambaForCausalLM, AutoTokenizer, GenerationConfig
import pytest
from tests.utils import ModelTester
import torch
import types


class ThisTester(ModelTester):
def _load_model(self):
model = MambaForCausalLM.from_pretrained(
self.model_name, torch_dtype=torch.bfloat16
)

model.generate = lambda **kwargs: type(model).generate(
model, **{**kwargs, "use_cache": False}
)

self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name, torch_dtype=torch.bfloat16
)

return model.generate

def _load_inputs(self):
prompt = "Hey how are you doing?"
input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
generation_config = GenerationConfig(max_new_tokens=10, use_cache=False)
arguments = {
"input_ids": input_ids,
"generation_config": generation_config,
"use_cache": False,
}
return arguments

def set_model_eval(self, model):
return model


@pytest.mark.parametrize(
"mode",
["eval"],
)
@pytest.mark.parametrize(
"model_name",
[
"state-spaces/mamba-790m-hf",
"state-spaces/mamba-2.8b-hf",
"state-spaces/mamba-1.4b-hf",
"state-spaces/mamba-370m-hf",
],
)
def test_mamba(record_property, mode, model_name):
record_property("model_name", model_name)
record_property("mode", mode)
tester = ThisTester(model_name, mode)
results = tester.test_model()
if mode == "eval":
gen_text = tester.tokenizer.batch_decode(results)
print("Generated text: ", gen_text)

record_property("torch_ttnn", (tester, results))

0 comments on commit 3653358

Please sign in to comment.