Add tests for mamba and mamba2 variants #75

Open
AleksKnezevic opened this issue Nov 27, 2024 · 1 comment
@AleksKnezevic
Contributor

Please add support/tests for mamba and mamba2 variants from here

@ddilbazTT
Contributor

I am working on this. Currently, the test is:

```python
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
# Reference: https://huggingface.co/state-spaces/mamba-2.8b-hf

from transformers import MambaForCausalLM, AutoTokenizer, GenerationConfig
import pytest
from tests.utils import ModelTester
import torch


class ThisTester(ModelTester):
    def _load_model(self):
        model = MambaForCausalLM.from_pretrained(
            self.model_name, torch_dtype=torch.bfloat16
        )
        # Tokenizers have no dtype, so torch_dtype is not passed here.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # Return the bound generate method so the tester exercises
        # generation rather than a single forward pass.
        return model.generate

    def _load_inputs(self):
        prompt = "Hey how are you doing?"
        input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
        generation_config = GenerationConfig(max_new_tokens=10)
        arguments = {"input_ids": input_ids, "generation_config": generation_config}
        return arguments

    def set_model_eval(self, model):
        # `model` is the bound method model.generate, which has no .eval();
        # return it unchanged (from_pretrained already puts the model in eval mode).
        return model


@pytest.mark.parametrize(
    "mode",
    ["eval"],
)
@pytest.mark.parametrize(
    "model_name",
    [
        "state-spaces/mamba-790m-hf",
    ],
)
def test_mamba(record_property, mode, model_name):
    record_property("model_name", model_name)
    record_property("mode", mode)

    tester = ThisTester(model_name, mode)
    results = tester.test_model()
    if mode == "eval":
        # results is assumed to be the token ids returned by model.generate.
        gen_text = tester.tokenizer.batch_decode(results)
        print("Generated text: ", gen_text)

    record_property("torch_ttnn", (tester, results))
```

However, I get the following error:

```text
E           AssertionError: Attempt to trace forbidden callable <function mark_static_address at 0x744a0d11d3a0>
E
E           from user code:
E             File "/localdev/ddilbaz/tt-torch/env/venv/lib/python3.11/site-packages/transformers/models/mamba/modeling_mamba.py", line 775, in forward
E               mamba_outputs = self.backbone(
E             File "/localdev/ddilbaz/tt-torch/env/venv/lib/python3.11/site-packages/transformers/models/mamba/modeling_mamba.py", line 603, in forward
E               cache_params = MambaCache(
E             File "/localdev/ddilbaz/tt-torch/env/venv/lib/python3.11/site-packages/transformers/cache_utils.py", line 1829, in __init__
E               torch._dynamo.mark_static_address(self.conv_states)
E
E
E           You can suppress this exception and fall back to eager by setting:
E               import torch._dynamo
E               torch._dynamo.config.suppress_errors = True

env/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py:829: AssertionError
```

I can suppress these errors and run the model; however, I do not know the root cause of this error and would appreciate any tips. I have tried updating torch/torchvision, but the issue persists.
@AleksKnezevic @mmanzoorTT

ddilbazTT added a commit that referenced this issue Dec 16, 2024
Add tests for mamba and mamba2. Tests are marked xfail because they yield 'AssertionError: Attempt to
trace forbidden callable <function mark_static_address at
0x744a0d11d3a0>' error. However, the tests generate a graph.
ddilbazTT added a commit that referenced this issue Jan 13, 2025
Add tests for mamba and mamba2. Tests are marked xfail because they yield 'AssertionError: Attempt to
trace forbidden callable <function mark_static_address at
0x744a0d11d3a0>' error. However, the tests generate a graph.
ddilbazTT added a commit that referenced this issue Jan 20, 2025
Add tests for mamba and mamba2. Tests are marked xfail because they yield 'AssertionError: Attempt to
trace forbidden callable <function mark_static_address at
0x744a0d11d3a0>' error. However, the tests generate a graph.
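The commits above mark the tests xfail. A minimal sketch of attaching that mark to an individual model parameter instead of the whole test, using pytest's standard `pytest.param` and `pytest.mark.xfail(reason=..., raises=...)`. The list name `MAMBA_MODELS` is hypothetical, and `raises=AssertionError` assumes the failure keeps surfacing as the `AssertionError` shown in the traceback.

```python
import pytest

# Hypothetical parameter list: the model name is the one from the test above.
# Marking the parameter (not the whole test) means variants added to this list
# later still run normally unless they are marked too.
MAMBA_MODELS = [
    pytest.param(
        "state-spaces/mamba-790m-hf",
        marks=pytest.mark.xfail(
            reason="Attempt to trace forbidden callable mark_static_address",
            # Only this exception type counts as the expected failure.
            raises=AssertionError,
        ),
    ),
]
```

Passing `MAMBA_MODELS` in place of the literal list in `@pytest.mark.parametrize("model_name", ...)` keeps the rest of the test unchanged. With `raises=AssertionError`, a failure of any other type is reported as a real error rather than being hidden behind the xfail.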
ddilbazTT added a commit that referenced this issue Jan 20, 2025
Add tests for mamba and mamba2. Disabled cache because the tests yield 'AssertionError: Attempt to
trace forbidden callable <function mark_static_address at
0x744a0d11d3a0>' otherwise.
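The later commits disable the cache instead. In the traceback, the forbidden `mark_static_address` call comes from `MambaCache.__init__`, which `MambaModel.forward` reaches while creating `cache_params`. A minimal sketch, assuming (as in the transformers version from the traceback) that the cache is only created when `use_cache` is true. The helper name `build_generate_kwargs` is hypothetical:

```python
def build_generate_kwargs(input_ids, max_new_tokens=10):
    """Keyword arguments for model.generate() with the cache disabled.

    Assumption: with use_cache=False, MambaModel.forward does not build a
    MambaCache, so the torch._dynamo.mark_static_address call made in
    MambaCache.__init__ is never reached while tracing.
    """
    return {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "use_cache": False,
    }
```

`model.generate(**build_generate_kwargs(input_ids))` applies these as generation settings; inside `ThisTester._load_inputs`, the equivalent change is `GenerationConfig(max_new_tokens=10, use_cache=False)`. Without the cache, each generation step reprocesses the whole sequence, so generation is slower, but the untraceable call is avoided.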