# Reconstructed from a multi-file patch whose line breaks were lost.
# Package wiring carried by the same patch:
#   flaml/__init__.py             adds `from flaml.autogen import icl`
#   flaml/autogen/icl/__init__.py `from flaml.autogen.icl.exemplar_selector import ExemplarSelector`
#                                 and `__all__ = ["ExemplarSelector"]`
#   flaml/autogen/oai/__init__.py unchanged except for a dropped trailing newline
from abc import ABC, abstractmethod
from functools import partial
import random

import numpy as np


# --- flaml/autogen/icl/selection_methods.py ---------------------------------


class SelectMethod(ABC):
    """Interface implemented by every exemplar-selection strategy."""

    @abstractmethod
    def __init__(self, data):
        """Receive the pool of candidate exemplars."""

    @abstractmethod
    def select(self, context):
        """Return the exemplars to present alongside ``context``."""


class RandomSelect(SelectMethod):
    """Pick exemplars uniformly at random, never returning the query itself."""

    def __init__(self, data, k=None):
        """Remember the candidate pool and the sample size.

        Args:
            data: Iterable of candidate exemplars.
            k: Number of exemplars to draw; ``None`` means "all candidates".
        """
        self.data = data
        self.k = k

    def select(self, context):
        """Return up to ``k`` random exemplars that are not equal to ``context``.

        When fewer than ``k`` candidates remain, all of them are returned
        instead of letting ``random.sample`` raise; this mirrors the
        ``method=None`` path of ``ExemplarSelector.construct_template``.
        """
        candidates = [item for item in self.data if item != context]
        if self.k is None:
            return candidates
        return random.sample(candidates, min(self.k, len(candidates)))


# --- flaml/autogen/icl/exemplar_selector.py ---------------------------------


class ExemplarSelector:
    """Build few-shot prompts from training exemplars and a query context."""

    # Maps the string accepted as ``method`` to the class implementing it.
    METHOD_MAPPING = {
        "random": RandomSelect,
        # You can add more methods here...
    }

    @classmethod
    def get_few_shot_template(
        cls,
        train_data,
        method=None,
        few_shot_template=None,
        method_params=None,
        template_params=None,
    ):
        """Return a callable ``f(context) -> prompt``.

        Args:
            train_data: Pool of exemplars.
            method: Name of a registered selection method, a callable
                ``method(context) -> exemplars``, or ``None`` to take the
                first ``k`` items of ``train_data``.
            few_shot_template: Optional callable
                ``few_shot_template(context, exemplars=...)``; when omitted
                the default template is used.
            method_params: Keyword arguments for the selection method
                (``k`` is also read when ``method`` is ``None``).
            template_params: Options for the default template; must contain
                ``key_order`` when ``few_shot_template`` is omitted.

        Raises:
            ValueError: If ``method`` is a string that is not registered.
        """
        # Normalise before use: the original unpacked ``**method_params``
        # while it could still be None.
        method_params = method_params or {}
        template_params = template_params or {}

        if isinstance(method, str):
            method_class = cls.METHOD_MAPPING.get(method)
            if method_class is None:
                raise ValueError(f"The specified method '{method}' is not recognized.")
            method = method_class(train_data, **method_params).select

        return partial(
            cls.construct_template,
            train_data=train_data,
            method=method,
            few_shot_template=few_shot_template,
            method_params=method_params,
            template_params=template_params,
        )

    @staticmethod
    def construct_template(
        context,
        train_data,
        method=None,
        few_shot_template=None,
        method_params=None,
        template_params=None,
    ):
        """Select exemplars for ``context`` and render the prompt.

        Raises:
            ValueError: If the default template is needed but
                ``template_params`` has no ``key_order``.
        """
        # Direct callers may rely on the None defaults.
        method_params = method_params or {}
        template_params = template_params or {}

        if method is None:
            k = method_params.get("k")
            if k is None:
                k = np.inf
            # Fewer than k items available: use them all, untouched.
            exemplars = train_data[:k] if len(train_data) >= k else train_data
        else:
            exemplars = method(context)

        if few_shot_template is not None:
            return few_shot_template(context, exemplars=exemplars)

        if "key_order" not in template_params:
            raise ValueError(
                "No 'key_order' found in 'template_params'. "
                "'key_order' is required when no 'few_shot_template' is provided."
            )
        return ExemplarSelector.default_template(
            context, exemplars, template_params["key_order"]
        )

    @staticmethod
    def default_template(context, exemplars, key_order):
        """Render exemplars, then the context with its last key left blank.

        Each exemplar contributes one ``key: value`` line per entry of
        ``key_order``. The context contributes the same lines for every key
        except the last, which is emitted as ``key: `` for the model to fill.

        Raises:
            ValueError: If ``key_order`` is empty.
        """
        if not key_order:
            raise ValueError("'key_order' must contain at least one key.")

        pieces = [
            "\n".join(f"{key}: {exemplar[key]!s}" for key in key_order) + "\n"
            for exemplar in exemplars
        ]
        pieces.append(
            "\n".join(f"{key}: {context[key]!s}" for key in key_order[:-1])
        )
        pieces.append(f"\n{key_order[-1]}: \n")
        return "".join(pieces)
# --- test/autogen/icl/test_selection.py --------------------------------------
import unittest
import datasets
from flaml import icl


class TestExemplarSelector(unittest.TestCase):
    """Exercise ExemplarSelector across template and method combinations.

    template: the default one (needs ``key_order``) or a user-supplied callable.
    method: a registered name (str), a user-supplied callable, or None
    (exemplars are then taken in ``train_data`` order).
    """

    @classmethod
    def setUpClass(cls):
        """Load and shuffle piqa once, keeping 8 exemplars and one query."""
        seed = 41
        cls.data = datasets.load_dataset("piqa")
        cls.train_data = cls.data["train"].shuffle(seed=seed)
        cls.test_data = cls.data["test"].shuffle(seed=seed)
        cls.key_order = ["goal", "sol1", "sol2", "label"]
        cls.exemplars = list(cls.train_data)[:8]
        cls.context = list(cls.test_data)[0]

    def test_case_existing_method_default_template(self):
        """Registered method name combined with the default template."""
        prompt_fn = icl.ExemplarSelector.get_few_shot_template(
            self.exemplars,
            method="random",
            method_params={"k": 3},
            template_params={"key_order": self.key_order},
        )
        output = prompt_fn(self.context)
        self.assertIsInstance(output, str)
        self.assertIn(self.context[self.key_order[0]], output)

    def test_case_user_template_no_method(self):
        """User-supplied template, no method, explicit k."""
        key_order = ["sol1", "sol2", "label"]

        def few_shot_template(context, exemplars=None):
            few_shot_prompt = "User template:\n"
            for exemplar in exemplars:
                lines = [key + ": " + str(exemplar[key]) for key in key_order]
                few_shot_prompt += "\n".join(lines) + "\n"
            query_lines = [key + ": " + str(context[key]) for key in key_order[:-1]]
            few_shot_prompt += "\n".join(query_lines)
            few_shot_prompt += "\n" + key_order[-1] + ": " + "\n"
            return few_shot_prompt

        prompt_fn = icl.ExemplarSelector.get_few_shot_template(
            self.exemplars,
            few_shot_template=few_shot_template,
            method_params={"k": 3},
        )
        output = prompt_fn(self.context)
        self.assertIsInstance(output, str)
        for key in key_order[:2]:
            self.assertIn(self.context[key], output)
        # With k=3 and no method, exactly the first three exemplars appear.
        for idx in range(3):
            self.assertIn(self.exemplars[idx][key_order[1]], output)
        self.assertNotIn(self.exemplars[3][key_order[1]], output)

    def test_case_user_method_no_k(self):
        """User-supplied method (no k) combined with the default template."""

        def user_method(context):
            return self.exemplars[3:5]

        # The default template requires key_order.
        prompt_fn = icl.ExemplarSelector.get_few_shot_template(
            self.exemplars,
            method=user_method,
            template_params={"key_order": self.key_order},
        )
        output = prompt_fn(self.context)
        self.assertIsInstance(output, str)
        for key in self.key_order[:3]:
            self.assertIn(self.context[key], output)
        # Only the exemplars at indices 3 and 4 may appear.
        probe = self.key_order[2]
        for idx in (3, 4):
            self.assertIn(self.exemplars[idx][probe], output)
        for idx in (5, 2):
            self.assertNotIn(self.exemplars[idx][probe], output)

    def test_invalid_method(self):
        """An unregistered method name is rejected with ValueError."""
        with self.assertRaises(ValueError):
            icl.ExemplarSelector.get_few_shot_template(self.exemplars, method="nonexistent")


if __name__ == '__main__':
    unittest.main()