From 2df326fdce6f4e9369381c045c16d9afda62b1ce Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 11 Jun 2024 11:43:51 +0000 Subject: [PATCH] updated test model logprobs --- tests/models/test_models_logprobs.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/tests/models/test_models_logprobs.py b/tests/models/test_models_logprobs.py index 04c172e0a7942..a07d4e1e5d89e 100644 --- a/tests/models/test_models_logprobs.py +++ b/tests/models/test_models_logprobs.py @@ -13,7 +13,7 @@ "meta-llama/Llama-2-7b-hf", "mistralai/Mistral-7B-v0.1", "Deci/DeciLM-7b", - "tiiuae/falcon-7b", + "gpt2", "bigcode/tiny_starcoder_py", "EleutherAI/gpt-j-6b", @@ -33,8 +33,9 @@ "bigcode/starcoder2-3b", ] -SKIPPED_MODELS_OOM = [ - "EleutherAI/gpt-j-6b", +SKIPPED_MODELS_CI = [ + "EleutherAI/gpt-j-6b", # OOM on CPU RAM + "tiiuae/falcon-7b", # Fails in vllm if trust_remote_code=True ] @@ -54,9 +55,9 @@ def test_models( if model in SKIPPED_MODELS_ACC: pytest.skip(reason="Low priority models not currently passing. " "We need to re-enable these.") - if model in SKIPPED_MODELS_OOM: - pytest.skip(reason="These models cause OOM issue on the CPU" - "because it is a fp32 checkpoint.") + if model in SKIPPED_MODELS_CI: + pytest.skip(reason="These models cause some CI issue unrelated " + "to the correctness of the implementation.") hf_model = hf_runner_nm(model, dtype=dtype) hf_outputs = hf_model.generate_greedy_logprobs_nm(example_prompts, @@ -64,13 +65,8 @@ def test_models( del hf_model - trust_remote_code = True - # Falcon fails if trust_remote_code = True - # https://github.com/vllm-project/vllm/issues/5363 - trust_remote_code = model != "tiiuae/falcon-7b" vllm_model = vllm_runner_nm(model, dtype=dtype, - trust_remote_code=trust_remote_code, max_model_len=MODEL_MAX_LEN) vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts, max_tokens,