diff --git a/promptbench/models/__init__.py b/promptbench/models/__init__.py
index 892694a..7dacf3c 100644
--- a/promptbench/models/__init__.py
+++ b/promptbench/models/__init__.py
@@ -81,8 +81,10 @@ def _create_model(self, max_new_tokens, temperature, device, dtype, model_dir, s
         if model_class:
             if model_class == LlamaModel or model_class == VicunaModel:
                 return model_class(self.model_name, max_new_tokens, temperature, device, dtype, system_prompt, model_dir)
-            elif model_class in [OpenAIModel, PaLMModel, GeminiModel]:
-                return model_class(self.model_name, max_new_tokens, temperature, device, dtype, system_prompt, api_key)
+            elif model_class in [OpenAIModel]:
+                return model_class(self.model_name, max_new_tokens, temperature, system_prompt, api_key)
+            elif model_class in [PaLMModel, GeminiModel]:
+                return model_class(self.model_name, max_new_tokens, temperature, api_key)
             else:
                 return model_class(self.model_name, max_new_tokens, temperature, device, dtype)
         else:
@@ -156,4 +158,4 @@ def _other_concat_prompts(self, prompt_list):
 
     def __call__(self, input_text, **kwargs):
         """Predicts the output based on the given input text using the loaded model."""
-        return self.model.predict(input_text, **kwargs)
\ No newline at end of file
+        return self.model.predict(input_text, **kwargs)
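
For context, the first hunk splits the API-backed models into two branches because their constructors take different arguments: OpenAIModel accepts a system_prompt, PaLMModel and GeminiModel do not, and none of the three needs device or dtype. A minimal, self-contained sketch of this dispatch pattern follows; the stub classes and the create_api_model helper are illustrative stand-ins inferred from the call sites in the diff, not the real promptbench implementations.

# Illustrative stand-ins: only the constructor signatures matter here,
# and they mirror the call sites introduced by the patch (an assumption,
# not the confirmed promptbench API).
class OpenAIModel:
    def __init__(self, model_name, max_new_tokens, temperature, system_prompt, api_key):
        self.model_name = model_name
        self.system_prompt = system_prompt

class PaLMModel:
    def __init__(self, model_name, max_new_tokens, temperature, api_key):
        self.model_name = model_name

class GeminiModel(PaLMModel):
    pass

def create_api_model(model_class, model_name, max_new_tokens, temperature,
                     system_prompt, api_key):
    # Hypothetical helper mirroring the patched _create_model branches.
    if model_class is OpenAIModel:
        # OpenAI chat models take a system prompt but no device/dtype.
        return model_class(model_name, max_new_tokens, temperature, system_prompt, api_key)
    elif model_class in (PaLMModel, GeminiModel):
        # PaLM / Gemini constructors take neither system prompt nor device/dtype.
        return model_class(model_name, max_new_tokens, temperature, api_key)
    raise ValueError(f"unsupported API model class: {model_class!r}")

# Example: the model name and key are placeholders.
m = create_api_model(GeminiModel, "gemini-pro", 128, 0.0, None, "YOUR_API_KEY")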