[fix] bugs in loading api-based models
Immortalise authored Jan 1, 2024
1 parent 5ca7aa3 commit 935ad6a
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions promptbench/models/__init__.py
@@ -81,8 +81,10 @@ def _create_model(self, max_new_tokens, temperature, device, dtype, model_dir, s
         if model_class:
             if model_class == LlamaModel or model_class == VicunaModel:
                 return model_class(self.model_name, max_new_tokens, temperature, device, dtype, system_prompt, model_dir)
-            elif model_class in [OpenAIModel, PaLMModel, GeminiModel]:
-                return model_class(self.model_name, max_new_tokens, temperature, device, dtype, system_prompt, api_key)
+            elif model_class in [OpenAIModel]:
+                return model_class(self.model_name, max_new_tokens, temperature, system_prompt, api_key)
+            elif model_class in [PaLMModel, GeminiModel]:
+                return model_class(self.model_name, max_new_tokens, temperature, api_key)
             else:
                 return model_class(self.model_name, max_new_tokens, temperature, device, dtype)
         else:
@@ -156,4 +158,4 @@ def _other_concat_prompts(self, prompt_list):
 
     def __call__(self, input_text, **kwargs):
         """Predicts the output based on the given input text using the loaded model."""
-        return self.model.predict(input_text, **kwargs)
\ No newline at end of file
+        return self.model.predict(input_text, **kwargs)
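
The fix splits the API-based branch: OpenAI, PaLM, and Gemini models no longer receive the device and dtype arguments meant for locally hosted models (which their constructors do not accept), and only the OpenAI constructor keeps system_prompt. A minimal usage sketch follows, assuming promptbench's LLMModel wrapper forwards the constructor parameters seen in _create_model above; the keyword names, model identifier, and api_key parameter here are assumptions for illustration, not confirmed by this diff:

import promptbench as pb

# Load an API-based model. After this fix, device/dtype are only forwarded
# to locally hosted models, so constructing an OpenAI-backed model no longer
# fails on unexpected arguments.
model = pb.LLMModel(model_name="gpt-3.5-turbo",                    # hypothetical model id
                    max_new_tokens=128,
                    temperature=0.0,
                    system_prompt="You are a helpful assistant.",  # OpenAI models only
                    api_key="sk-...")                              # your API key

# __call__ (second hunk above) delegates to the loaded model's predict method.
print(model("What is the capital of France?"))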
