From 6752719b3bae1d674b55f10a91b1b91733f3e6b5 Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Wed, 22 Feb 2023 23:06:37 +0800 Subject: [PATCH] cache prompt variables --- promptify/prompts/nlp/prompter.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/promptify/prompts/nlp/prompter.py b/promptify/prompts/nlp/prompter.py index 504a6a9..f67f590 100644 --- a/promptify/prompts/nlp/prompter.py +++ b/promptify/prompts/nlp/prompter.py @@ -19,14 +19,18 @@ def __init__( self.allowed_missing_variables = allowed_missing_variables self.model_args_count = self.model.run.__code__.co_argcount self.model_variables = self.model.run.__code__.co_varnames[1 : self.model_args_count] + self.prompt_variables_map = {} def list_templates(self) -> List[str]: return self.environment.list_templates() def get_template_variables(self, template_name: str) -> List[str]: + if template_name in self.prompt_variables_map: + return self.prompt_variables_map[template_name] template_source = self.environment.loader.get_source(self.environment, template_name) parsed_content = self.environment.parse(template_source) undeclared_variables = meta.find_undeclared_variables(parsed_content) + self.prompt_variables_map[template_name] = undeclared_variables return undeclared_variables def generate_prompt(self, template_name, **kwargs) -> str: @@ -44,11 +48,11 @@ def fit(self, template_name, **kwargs): prompt_variables = self.get_template_variables(template_name) prompt_kwargs = {} model_kwargs = {} - for variable in kwargs: - if variable in prompt_variables: - prompt_kwargs[variable] = kwargs[variable] - elif variable in self.model_variables: - model_kwargs[variable] = kwargs[variable] + for var_name, val in kwargs.items(): + if var_name in prompt_variables: + prompt_kwargs[var_name] = val + elif var_name in self.model_variables: + model_kwargs[var_name] = val prompt = self.generate_prompt(template_name, **prompt_kwargs) output = self.model.run(prompts=[prompt], **model_kwargs) return output[0]