diff --git a/demo/cmd_getting_started/vw_intro.ipynb b/demo/cmd_getting_started/vw_intro.ipynb new file mode 100644 index 00000000000..724697a475c --- /dev/null +++ b/demo/cmd_getting_started/vw_intro.ipynb @@ -0,0 +1,328 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Helpers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def read(path):\n", + " with open(path) as f:\n", + " print(\"\".join(f.readlines()))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Regression" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's generate data of the following form:
\n", + "Every example has single namespace 'f' with single feature 'x' in it
\n", + "Target function is $$\\hat{y} = 2x + 1$$\n", + "And we are learning weights $w$, $b$ for $$y=wx+b$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "with open(\"regression1.txt\", \"w\") as f:\n", + " for i in range(1000):\n", + " x = np.random.rand()\n", + " y = 2 * x + 1\n", + " f.write(f\"{y} |f x:{x}\\n\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simplest execution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!vw -d regression1.txt" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Output more artifacts\n", + "-p - predictions
\n", + "--invert_hash - model in readable format
\n", + "-f - model in binary format (consumable by vw)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!vw -d regression1.txt -p regression1.pred --invert_hash regression1.model.txt -f regression1.model.bin" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can look at weights and see the $w$ and $b$ got close to expected 2 and 1 values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "read(\"regression1.model.txt\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Do only predictions, no learning" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!vw -d regression1.txt -t" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!vw -d regression1.txt -t --learning_rate 10" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!vw -d regression1.txt -t -i regression1.model.bin" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Interactions" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's generate another dataset of the following form:
\n", + "Every example has single namespace 'f' with single feature 'x' in it
\n", + "Target function is $$\\hat{y} = x^2 + 1$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "with open(\"regression2.txt\", \"w\") as f:\n", + " for i in range(1000):\n", + " x = np.random.rand() * 4\n", + " y = x**2 + 1\n", + " f.write(f\"{y} |f x:{x}\\n\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see loss being far from zero if we stil try to learn $$y=wx+b$$ " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!vw -d regression2.txt --invert_hash regression2.model.txt" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So let's try to learn $$y=w_1 x^2 + w_2 x + b$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!vw -d regression2.txt --invert_hash regression2.model.txt --interactions ff" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "read(\"regression2.model.txt\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Contextual bandits" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "env = {\"Tom\": {\"sports\": 0, \"politics\": 1}, \"Anna\": {\"sports\": 1, \"politics\": 0}}\n", + "\n", + "users = [\"Tom\", \"Anna\"]\n", + "content = [\"sports\", \"politics\"]\n", + "\n", + "with open(\"cb.txt\", \"w\") as f:\n", + " for i in range(1000):\n", + " user = users[np.random.randint(0, 2)]\n", + " chosen = np.random.randint(0, 2)\n", + " reward = env[user][content[chosen]]\n", + "\n", + " f.write(f\"shared |u {user}\\n\")\n", + " f.write(f\"0:{-reward}:0.5 |a {content[chosen]}\\n\")\n", + " f.write(f\"|a {content[(chosen + 1) % 2]}\\n\\n\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's try to learn to predict reward in the following form: $$r = w_1 I(user\\ is\\ Tom) + w_2 I(user\\ is\\ Anna) + w_3 I(topic\\ is\\ sports) + w_4 I(topic\\ is\\ politics) + b$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!vw --cb_explore_adf -d cb.txt --invert_hash cb.model.txt" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that average reward is still around 0.5 which is the same as we would get answering randomly. This is expected since personalization is not captured in this form.\n", + "Let's add interaction between 'u' and 'a' namespaces and try to learn function of the following form:\n", + "$$\\begin{aligned}r = w_1 I(user\\ is\\ Tom) I(topic\\ is\\ sports) + w_2 I(user\\ is\\ Tom) I(topic\\ is\\ politics) +\\\\+ w_3 I(user\\ is\\ Anna) I(topic\\ is\\ sports) + w_4 I(user\\ is\\ Anna) I(topic\\ is\\ politics) +\\\\+ w_5 I(user\\ is\\ Tom) + w_6 I(user\\ is\\ Anna) +\\\\+ w_7 I(topic\\ is\\ sports) + w_8 I(topic\\ is\\ politics) + b\\end{aligned}$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!vw --cb_explore_adf -d cb.txt --invert_hash cb.model.txt --interactions ua" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "read(\"cb.model.txt\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/tests/e2e_v2/assert_job.py b/python/tests/e2e_v2/assert_job.py new file mode 100644 index 00000000000..4acc1389ec3 --- /dev/null +++ b/python/tests/e2e_v2/assert_job.py @@ -0,0 +1,121 @@ +import numpy as np +import os +from numpy.testing import assert_allclose, assert_almost_equal +from vw_executor.vw import ExecutionStatus +import vowpalwabbit as vw +from test_helper import get_function_object, datagen_driver + + +def remove_non_digits(string): + return "".join(char for char in string if char.isdigit() or char == ".") + + +def get_from_kwargs(kwargs, key, default=None): + if key in kwargs: + return kwargs[key] + else: + return default + + +def majority_close(arr1, arr2, rtol, atol, threshold): + # Check if the majority of elements are close + close_count = np.count_nonzero(np.isclose(arr1, arr2, rtol=rtol, atol=atol)) + return close_count >= len(arr1) * threshold + + +def assert_weight(job, **kwargs): + atol = get_from_kwargs(kwargs, "atol", 10e-8) + rtol = get_from_kwargs(kwargs, "rtol", 10e-5) + expected_weights = kwargs["expected_weights"] + assert job.status == ExecutionStatus.Success, f"{job.opts} job should be successful" + data = job.outputs["--invert_hash"] + with open(data[0], "r") as f: + data = f.readlines() + data = [i.strip() for i in data] + weights = job[0].model9("--invert_hash").weights.to_dict()["weight"] + for x in expected_weights: + assert_allclose( + [weights[x]], [expected_weights[x]], atol=atol, rtol=rtol + ), f"weights {x} should be {expected_weights[x]}" + + +def assert_prediction(job, **kwargs): + assert job.status == ExecutionStatus.Success, "job should be successful" + atol = kwargs.get("atol", 10e-8) + rtol = kwargs.get("rtol", 10e-5) + threshold = kwargs.get("threshold", 0.9) + expected_value = kwargs["expected_value"] + predictions = job.outputs["-p"] + with open(predictions[0], "r") as f: + prediction = [i.strip() for i in f.readlines()] + prediction = [i for i in prediction if i != ""] + if ":" in prediction[0]: + prediction = [[j.split(":")[1] for j in i.split(",")] for i in prediction] + elif "," in prediction[0]: + prediction = [[j for j in i.split(",")] for i in prediction] + if type(prediction[0]) == list: + prediction = [[float(remove_non_digits(j)) for j in i] for i in prediction] + else: + prediction = [float(remove_non_digits(i)) for i in prediction] + assert majority_close( + prediction, + [expected_value] * len(prediction), + rtol=rtol, + atol=atol, + threshold=threshold, + ), f"predicted value should be {expected_value}, \n actual values are {prediction}" + + +def assert_loss(job, **kwargs): + assert job.status == ExecutionStatus.Success, "job should be successful" + assert type(job[0].loss) == float, "loss should be an float" + decimal = kwargs.get("decimal", 2) + assert_almost_equal(job[0].loss, kwargs["expected_loss"], decimal=decimal) + + +def assert_loss_below(job, **kwargs): + assert job.status == ExecutionStatus.Success, "job should be successful" + assert type(job[0].loss) == float, "loss should be an float" + assert ( + job[0].loss <= kwargs["expected_loss"] + ), f"loss should be below {kwargs['expected_loss']}" + + +def assert_prediction_with_generated_data(job, **kwargs): + assert job.status == ExecutionStatus.Success, "job should be successful" + expected_class = [] + trained_model = vw.Workspace(f"-i {job[0].model9('-f').path} --quiet") + predictions = [] + folder_path = os.path.dirname(os.path.realpath(__file__)) + subdirectories = [ + os.path.join(folder_path, name) + for name in os.listdir(folder_path) + if os.path.isdir(os.path.join(folder_path, name)) + ] + for subdir in subdirectories: + try: + subdir_name = subdir.replace("\\", "/").split("/")[-1] + data_func_obj = get_function_object( + f"{subdir_name}.data_generation", kwargs["data_func"]["name"] + ) + if data_func_obj: + break + except: + pass + script_directory = os.path.dirname(os.path.realpath(__file__)) + dataFile = datagen_driver( + os.path.join(script_directory, subdir_name), + data_func_obj, + **kwargs["data_func"]["params"], + ) + with open(dataFile, "r") as f: + for line in f.readlines(): + expected_class.append(line.split("|")[0].strip()) + predicted_class = trained_model.predict(line.strip()) + predictions.append(predicted_class) + accuracy = sum( + [1 if int(yp) == int(ye) else 0 for yp, ye in zip(predictions, expected_class)] + ) / len(expected_class) + assert ( + accuracy >= kwargs["accuracy_threshold"] + ), f"Accuracy is {accuracy} and Threshold is {kwargs['accuracy_threshold']}" diff --git a/python/tests/e2e_v2/cb/data_generation.py b/python/tests/e2e_v2/cb/data_generation.py new file mode 100644 index 00000000000..d7c9bbb702f --- /dev/null +++ b/python/tests/e2e_v2/cb/data_generation.py @@ -0,0 +1,67 @@ +import random +import os +from test_helper import get_function_object + +script_directory = os.path.dirname(os.path.realpath(__file__)) + + +def random_number_items(items): + num_items_to_select = random.randint(1, len(items)) + return random.sample(items, num_items_to_select) + + +def generate_cb_data( + f, + num_examples, + num_features, + num_actions, + reward_function, + logging_policy, + context_name=["1"], + seed=random.randint(0, 100), +): + random.seed(seed) + + reward_function_obj = get_function_object( + "cb.reward_functions", reward_function["name"] + ) + logging_policy_obj = get_function_object( + "cb.logging_policies", logging_policy["name"] + ) + features = [f"feature{index}" for index in range(1, num_features + 1)] + for _ in range(num_examples): + no_context = len(context_name) + if no_context > 1: + context = random.randint(1, no_context) + else: + context = 1 + + def return_cost_probability(chosen_action, context=1): + cost = -reward_function_obj( + chosen_action, context, **reward_function["params"] + ) + if "params" not in logging_policy: + logging_policy["params"] = {} + logging_policy["params"]["chosen_action"] = chosen_action + logging_policy["params"]["num_actions"] = num_actions + probability = logging_policy_obj(**logging_policy["params"]) + return cost, probability + + chosen_action = random.randint(1, num_actions) + if no_context > 1: + f.write(f"shared | User s_{context_name[context-1]}\n") + for action in range(1, num_actions + 1): + cost, probability = return_cost_probability(action, context) + if action == chosen_action: + f.write( + f'{action}:{cost}:{probability} | {" ".join(random_number_items(features))}\n' + ) + else: + f.write(f'| {" ".join(random_number_items(features))}\n') + + else: + cost, probability = return_cost_probability(chosen_action) + f.write( + f'{chosen_action}:{cost}:{probability} | {" ".join(random_number_items(features))}\n' + ) + f.write("\n") diff --git a/python/tests/e2e_v2/cb/logging_policies.py b/python/tests/e2e_v2/cb/logging_policies.py new file mode 100644 index 00000000000..7f6f951fffa --- /dev/null +++ b/python/tests/e2e_v2/cb/logging_policies.py @@ -0,0 +1,6 @@ +def constant_probability(chosen_action): + return 1 + + +def even_probability(chosen_action, num_actions): + return round(1 / num_actions, 2) diff --git a/python/tests/e2e_v2/cb/reward_functions.py b/python/tests/e2e_v2/cb/reward_functions.py new file mode 100644 index 00000000000..a13a699d0a4 --- /dev/null +++ b/python/tests/e2e_v2/cb/reward_functions.py @@ -0,0 +1,18 @@ +def fixed_reward(chosen_action, context): + return 1 + + +def constant_reward(chosen_action, context, reward): + return reward[chosen_action - 1] + + +def fixed_reward_two_action(chosen_action, context): + if context == 1 and chosen_action == 2: + return 1 + elif context == 2 and chosen_action == 2: + return 0 + elif context == 1 and chosen_action == 1: + return 0 + elif context == 2 and chosen_action == 1: + return 1 + return 1 diff --git a/python/tests/e2e_v2/cb_cont/data_generation.py b/python/tests/e2e_v2/cb_cont/data_generation.py new file mode 100644 index 00000000000..df2fa0de536 --- /dev/null +++ b/python/tests/e2e_v2/cb_cont/data_generation.py @@ -0,0 +1,62 @@ +import random +import os +from test_helper import get_function_object + +script_directory = os.path.dirname(os.path.realpath(__file__)) + + +def random_number_items(items): + num_items_to_select = random.randint(1, len(items)) + return random.sample(items, num_items_to_select) + + +def generate_cb_data( + f, + num_examples, + num_features, + action_range, + reward_function, + logging_policy, + context_name=["1"], + seed=random.randint(0, 100), +): + random.seed(seed) + num_actions = int(abs(action_range[1] - action_range[0])) + + reward_function_obj = get_function_object( + "cb_cont.reward_functions", reward_function["name"] + ) + logging_policy_obj = get_function_object( + "cb_cont.logging_policies", logging_policy["name"] + ) + features = [f"feature{index}" for index in range(1, num_features + 1)] + + for _ in range(num_examples): + no_context = len(context_name) + if no_context > 1: + context = random.randint(1, no_context) + else: + context = 1 + + def return_cost_probability(chosen_action, context): + cost = -reward_function_obj( + chosen_action, context, **reward_function["params"] + ) + if "params" not in logging_policy: + logging_policy["params"] = {} + logging_policy["params"]["chosen_action"] = chosen_action + logging_policy["params"]["num_actions"] = num_actions + probability = logging_policy_obj(**logging_policy["params"]) + return cost, probability + + chosen_action = round(random.uniform(0, num_actions), 2) + cost, probability = return_cost_probability(chosen_action, context) + if no_context == 1: + f.write( + f'ca {chosen_action}:{cost}:{probability} | {" ".join(random_number_items(features))}\n' + ) + else: + f.write( + f'ca {chosen_action}:{cost}:{probability} | {"s_" + context_name[context-1]} {" ".join(random_number_items(features))}\n' + ) + f.write("\n") diff --git a/python/tests/e2e_v2/cb_cont/logging_policies.py b/python/tests/e2e_v2/cb_cont/logging_policies.py new file mode 100644 index 00000000000..7f6f951fffa --- /dev/null +++ b/python/tests/e2e_v2/cb_cont/logging_policies.py @@ -0,0 +1,6 @@ +def constant_probability(chosen_action): + return 1 + + +def even_probability(chosen_action, num_actions): + return round(1 / num_actions, 2) diff --git a/python/tests/e2e_v2/cb_cont/reward_functions.py b/python/tests/e2e_v2/cb_cont/reward_functions.py new file mode 100644 index 00000000000..c865e666084 --- /dev/null +++ b/python/tests/e2e_v2/cb_cont/reward_functions.py @@ -0,0 +1,18 @@ +def fixed_reward(chosen_action, context): + return 1 + + +def piecewise_constant(chosen_action, context, reward): + return reward[int(chosen_action) - 1] + + +def fixed_reward_two_action(chosen_action, context): + if context == 1 and chosen_action >= 2: + return 1 + elif context == 2 and chosen_action < 2 and chosen_action >= 1: + return 0 + elif context == 1 and chosen_action < 1 and chosen_action >= 1: + return 0 + elif context == 2 and chosen_action < 1: + return 1 + return 1 diff --git a/python/tests/e2e_v2/classification/classification_functions.py b/python/tests/e2e_v2/classification/classification_functions.py new file mode 100644 index 00000000000..67aeb8a77d2 --- /dev/null +++ b/python/tests/e2e_v2/classification/classification_functions.py @@ -0,0 +1,19 @@ +def binary_classification_one_feature(input_vector): + if input_vector[0] > 0.5: + return 2 + return 1 + + +def multi_classification_two_features(input_vector): + # Define the number of divisions for each feature + divisions = 5 + + # Calculate the division size for each feature + division_size = 1 / divisions + + # Calculate the class index based on the input vector's position in the feature space + class_idx = int(input_vector[0] // division_size) * divisions + int( + input_vector[1] // division_size + ) + + return class_idx + 1 diff --git a/python/tests/e2e_v2/classification/data_generation.py b/python/tests/e2e_v2/classification/data_generation.py new file mode 100644 index 00000000000..c65e49b3350 --- /dev/null +++ b/python/tests/e2e_v2/classification/data_generation.py @@ -0,0 +1,29 @@ +import os, random +from test_helper import get_function_object + + +script_directory = os.path.dirname(os.path.realpath(__file__)) +random.seed(10) + + +def generate_classification_data( + f, + num_example, + num_features, + classify_func, + seed=random.randint(0, 100), + bounds=None, +): + random.seed(seed) + classify_func_obj = get_function_object( + "classification.classification_functions", classify_func["name"] + ) + if not bounds: + bounds = [[0, 1] for _ in range(num_features)] + for _ in range(num_example): + x = [ + random.uniform(bounds[index][0], bounds[index][1]) + for index in range(num_features) + ] + y = classify_func_obj(x, **classify_func["params"]) + f.write(f"{y} |f {' '.join([f'x{i}:{x[i]}' for i in range(num_features)])}\n") diff --git a/python/tests/e2e_v2/conftest.py b/python/tests/e2e_v2/conftest.py new file mode 100644 index 00000000000..5297e54cec2 --- /dev/null +++ b/python/tests/e2e_v2/conftest.py @@ -0,0 +1,16 @@ +# conftest.py +def pytest_addoption(parser): + parser.addoption( + "--store_output", + action="store", + default=False, + help="Store output file for tests.", + ) + + +def pytest_configure(config): + _store_output = config.getoption("--store_output") + # Store the custom_arg_value in a global variable or a custom configuration object. + # For example, you can store it in a global variable like this: + global STORE_OUTPUT + STORE_OUTPUT = _store_output diff --git a/python/tests/e2e_v2/regression/data_generation.py b/python/tests/e2e_v2/regression/data_generation.py new file mode 100644 index 00000000000..cf6e8bca04c --- /dev/null +++ b/python/tests/e2e_v2/regression/data_generation.py @@ -0,0 +1,13 @@ +import random +import os + +script_directory = os.path.dirname(os.path.realpath(__file__)) + + +def constant_function( + f, no_sample, constant, x_lower_bound, x_upper_bound, seed=random.randint(0, 100) +): + random.seed(seed) + for _ in range(no_sample): + x = random.uniform(x_lower_bound, x_upper_bound) + f.write(f"{constant} |f x:{x}\n") diff --git a/python/tests/e2e_v2/slate/action_space.py b/python/tests/e2e_v2/slate/action_space.py new file mode 100644 index 00000000000..bd13b796357 --- /dev/null +++ b/python/tests/e2e_v2/slate/action_space.py @@ -0,0 +1,5 @@ +def new_action_after_threshold(iteration, threshold, before, after): + # before iteration 500, it is sunny and after it is raining + if iteration > threshold: + return after + return before diff --git a/python/tests/e2e_v2/slate/assert_job.py b/python/tests/e2e_v2/slate/assert_job.py new file mode 100644 index 00000000000..aff9061df6e --- /dev/null +++ b/python/tests/e2e_v2/slate/assert_job.py @@ -0,0 +1,42 @@ +from vw_executor.vw import ExecutionStatus +import numpy as np + + +def majority_close(arr1, arr2, rtol, atol, threshold): + # Check if the majority of elements are close + close_count = np.count_nonzero(np.isclose(arr1, arr2, rtol=rtol, atol=atol)) + return close_count >= len(arr1) * threshold + + +def assert_prediction(job, **kwargs): + assert job.status == ExecutionStatus.Success, "job should be successful" + atol = kwargs.get("atol", 10e-8) + rtol = kwargs.get("rtol", 10e-5) + threshold = kwargs.get("threshold", 0.9) + expected_value = kwargs["expected_value"] + predictions = job.outputs["-p"] + res = [] + with open(predictions[0], "r") as f: + exampleRes = [] + while True: + line = f.readline() + if not line: + break + if line.count(":") == 0: + res.append(exampleRes) + exampleRes = [] + continue + slotRes = [0] * line.count(":") + slot = line.split(",") + for i in range(len(slot)): + actionInd = int(slot[i].split(":")[0]) + slotRes[i] = float(slot[actionInd].split(":")[1]) + exampleRes.append(slotRes) + + assert majority_close( + res, + [expected_value] * len(res), + rtol=rtol, + atol=atol, + threshold=threshold, + ), f"predicted value should be {expected_value}, \n actual values are {res}" diff --git a/python/tests/e2e_v2/slate/data_generation.py b/python/tests/e2e_v2/slate/data_generation.py new file mode 100644 index 00000000000..d308b245d3d --- /dev/null +++ b/python/tests/e2e_v2/slate/data_generation.py @@ -0,0 +1,69 @@ +import random +import os +from test_helper import get_function_object + +script_directory = os.path.dirname(os.path.realpath(__file__)) + + +def generate_slate_data( + f, + num_examples, + reward_function, + logging_policy, + action_space, + context_name=["1"], + seed=random.randint(0, 100), +): + random.seed(seed) + action_space_obj = get_function_object("slate.action_space", action_space["name"]) + + reward_function_obj = get_function_object( + "slate.reward_functions", reward_function["name"] + ) + logging_policy_obj = get_function_object( + "slate.logging_policies", logging_policy["name"] + ) + + def return_cost_probability(chosen_action, chosen_slot, context): + cost = -reward_function_obj( + chosen_action, context, chosen_slot, **reward_function["params"] + ) + logging_policy["params"]["num_action"] = num_actions[chosen_slot - 1] + logging_policy["params"]["chosen_action"] = chosen_action + probability = logging_policy_obj(**logging_policy["params"]) + return cost, probability + + for i in range(num_examples): + action_space["params"]["iteration"] = i + action_spaces = action_space_obj(**action_space["params"]) + reward_function["params"]["iteration"] = i + num_slots = len(action_spaces) + num_actions = [len(slot) for slot in action_spaces] + slot_name = [f"slot_{index}" for index in range(1, num_slots + 1)] + chosen_actions = [] + num_context = len(context_name) + if num_context > 1: + context = random.randint(1, num_context) + else: + context = 1 + for s in range(num_slots): + chosen_actions.append(random.randint(1, num_actions[s])) + chosen_actions_cost_prob = [ + return_cost_probability(action, slot + 1, context) + for slot, action in enumerate(chosen_actions) + ] + total_cost = sum([cost for cost, _ in chosen_actions_cost_prob]) + + f.write(f"slates shared {total_cost} |User {context_name[context-1]}\n") + # write actions + for ind, slot in enumerate(action_spaces): + for a in slot: + f.write( + f"slates action {ind} |Action {a}\n", + ) + + for s in range(num_slots): + f.write( + f"slates slot {chosen_actions[s]}:{chosen_actions_cost_prob[s][1]} |Slot {slot_name[s]}\n" + ) + f.write("\n") diff --git a/python/tests/e2e_v2/slate/logging_policies.py b/python/tests/e2e_v2/slate/logging_policies.py new file mode 100644 index 00000000000..4222a514b1f --- /dev/null +++ b/python/tests/e2e_v2/slate/logging_policies.py @@ -0,0 +1,2 @@ +def even_probability(chosen_action, num_action): + return round(1 / num_action, 2) diff --git a/python/tests/e2e_v2/slate/reward_functions.py b/python/tests/e2e_v2/slate/reward_functions.py new file mode 100644 index 00000000000..2dc79a5100f --- /dev/null +++ b/python/tests/e2e_v2/slate/reward_functions.py @@ -0,0 +1,10 @@ +def fixed_reward(chosen_action, context, slot, reward): + return reward[slot - 1][chosen_action - 1] + + +def reverse_reward_after_threshold( + chosen_action, context, slot, reward, iteration, threshold +): + if iteration > threshold: + reward = [i[::-1] for i in reward] + return reward[slot - 1][chosen_action - 1] diff --git a/python/tests/e2e_v2/test_configs/cb.json b/python/tests/e2e_v2/test_configs/cb.json new file mode 100644 index 00000000000..15fa7584fe2 --- /dev/null +++ b/python/tests/e2e_v2/test_configs/cb.json @@ -0,0 +1,270 @@ +[ + { + "test_name": "cb_two_action", + "data_func": { + "name": "generate_cb_data", + "params": { + "num_examples": 100, + "num_features": 1, + "num_actions": 2, + "seed": 1, + "reward_function": { + "name": "constant_reward", + "params": { + "reward": [ + 1, + 0 + ] + } + }, + "logging_policy": { + "name": "even_probability", + "params": {} + }, + "context_name": [ + "1", + "2" + ] + } + }, + "assert_functions": [ + { + "name": "assert_loss", + "params": { + "expected_loss": -1, + "decimal": 1 + } + }, + { + "name": "assert_prediction", + "params": { + "expected_value": [ + 1, + 0 + ], + "threshold": 0.8 + } + } + ], + "grids": { + "cb": { + "#base": [ + "--cb_explore 2" + ] + }, + "epsilon": { + "--epsilon": [ + 0.1, + 0.2, + 0.3 + ] + }, + "first": { + "--first": [ + 1, + 2 + ] + }, + "bag": { + "--bag": [ + 5, + 6, + 7 + ] + }, + "cover": { + "--cover": [ + 1, + 2, + 3 + ] + }, + "squarecb": { + "--squarecb": [ + "--gamma_scale 1000", + "--gamma_scale 10000" + ] + }, + "synthcover": { + "--synthcover": [ + "" + ] + }, + "regcb": { + "--regcb": [ + "" + ] + }, + "softmax": { + "--softmax": [ + "" + ] + } + }, + "grids_expression": "cb * (epsilon + first + bag + cover + squarecb + synthcover + regcb + softmax)", + "output": [ + "--readable_model", + "-p" + ] + }, + { + "test_name": "cb_one_action", + "data_func": { + "name": "generate_cb_data", + "params": { + "num_examples": 100, + "num_features": 1, + "num_actions": 1, + "seed": 1, + "reward_function": { + "name": "fixed_reward", + "params": {} + }, + "logging_policy": { + "name": "even_probability" + } + } + }, + "assert_functions": [ + { + "name": "assert_loss", + "params": { + "expected_loss": -1, + "decimal": 1 + } + }, + { + "name": "assert_prediction", + "params": { + "expected_value": 0, + "threshold": 0.1 + } + } + ], + "grids": { + "g0": { + "#base": [ + "--cb 1 --preserve_performance_counters --save_resume" + ] + }, + "g1": { + "--cb_type": [ + "ips", + "mtr", + "dr", + "dm" + ] + } + }, + "grids_expression": "g0 * g1", + "output": [ + "--readable_model", + "-p" + ] + }, + { + "test_name": "cb_two_action_diff_context", + "data_func": { + "name": "generate_cb_data", + "params": { + "num_examples": 100, + "num_features": 2, + "num_actions": 2, + "seed": 1, + "reward_function": { + "name": "fixed_reward_two_action", + "params": {} + }, + "logging_policy": { + "name": "even_probability", + "params": {} + }, + "context_name": [ + "1", + "2" + ] + } + }, + "assert_functions": [ + { + "name": "assert_loss", + "params": { + "expected_loss": -0.5, + "decimal": 1 + } + }, + { + "name": "assert_prediction", + "params": { + "expected_value": [ + 0.975, + 0.025 + ], + "threshold": 0.1, + "atol": 0.1, + "rtol": 0.1 + } + } + ], + "grids": { + "cb": { + "#base": [ + "--cb_explore_adf" + ] + }, + "epsilon": { + "--epsilon": [ + 0.1, + 0.2, + 0.3 + ] + }, + "first": { + "--first": [ + 1, + 2 + ] + }, + "bag": { + "--bag": [ + 5, + 6, + 7 + ] + }, + "cover": { + "--cover": [ + 1, + 2, + 3 + ] + }, + "squarecb": { + "--squarecb": [ + "--gamma_scale 1000", + "--gamma_scale 10000" + ] + }, + "synthcover": { + "--synthcover": [ + "" + ] + }, + "regcb": { + "--regcb": [ + "" + ] + }, + "softmax": { + "--softmax": [ + "" + ] + } + }, + "grids_expression": "cb * (epsilon + first + bag + cover + squarecb + synthcover + regcb + softmax)", + "output": [ + "--readable_model", + "-p" + ] + } +] \ No newline at end of file diff --git a/python/tests/e2e_v2/test_configs/cb_cont.json b/python/tests/e2e_v2/test_configs/cb_cont.json new file mode 100644 index 00000000000..1a0d6e49066 --- /dev/null +++ b/python/tests/e2e_v2/test_configs/cb_cont.json @@ -0,0 +1,195 @@ +[ + { + "test_name": "cb_two_action", + "data_func": { + "name": "generate_cb_data", + "params": { + "num_examples": 100, + "num_features": 1, + "seed": 1, + "action_range": [ + 0, + 2 + ], + "reward_function": { + "name": "piecewise_constant", + "params": { + "reward": [ + 1, + 0 + ] + } + }, + "logging_policy": { + "name": "even_probability", + "params": {} + } + } + }, + "assert_functions": [ + { + "name": "assert_loss", + "params": { + "expected_loss": -1, + "decimal": 1 + } + }, + { + "name": "assert_prediction", + "params": { + "expected_value": [ + 1, + 0 + ], + "threshold": 0.8 + } + } + ], + "grids": { + "cb": { + "#base": [ + "--cats 2 --min_value 0 --max_value 2 --bandwidth 1" + ] + }, + "epsilon": { + "--epsilon": [ + 0.1, + 0.2, + 0.3 + ] + } + }, + "grids_expression": "cb * (epsilon)", + "output": [ + "--readable_model", + "-p" + ] + }, + { + "test_name": "cb_two_action_diff_context", + "data_func": { + "name": "generate_cb_data", + "params": { + "num_examples": 100, + "num_features": 2, + "seed": 1, + "action_range": [ + 0, + 2 + ], + "reward_function": { + "name": "fixed_reward_two_action", + "params": {} + }, + "logging_policy": { + "name": "even_probability", + "params": {} + }, + "context_name": [ + "1", + "2" + ] + } + }, + "assert_functions": [ + { + "name": "assert_loss", + "params": { + "expected_loss": -0.8, + "decimal": 1 + } + }, + { + "name": "assert_prediction", + "params": { + "expected_value": [ + 0.975, + 0.025 + ], + "threshold": 0.1, + "atol": 0.1, + "rtol": 0.1 + } + } + ], + "grids": { + "cb": { + "#base": [ + "--cats 2 --min_value 0 --max_value 2 --bandwidth 1" + ] + }, + "epsilon": { + "--epsilon": [ + 0.1, + 0.2, + 0.3 + ] + } + }, + "grids_expression": "cb * (epsilon)", + "output": [ + "--readable_model", + "-p" + ] + }, + { + "test_name": "cb_one_action", + "data_func": { + "name": "generate_cb_data", + "params": { + "num_examples": 10, + "num_features": 1, + "seed": 1, + "action_range": [ + 0, + 1 + ], + "reward_function": { + "name": "fixed_reward", + "params": {} + }, + "logging_policy": { + "name": "even_probability" + } + } + }, + "assert_functions": [ + { + "name": "assert_loss", + "params": { + "expected_loss": -1 + } + }, + { + "name": "assert_prediction", + "params": { + "expected_value": [ + 0, + 1 + ], + "threshold": 0.1 + } + } + ], + "grids": { + "g0": { + "#base": [ + "--cats 2 --min_value 0 --max_value 1 --bandwidth 1" + ] + }, + "g1": { + "--cb_type": [ + "ips", + "mtr", + "dr", + "dm" + ] + } + }, + "grids_expression": "g0 * g1", + "output": [ + "--readable_model", + "-p" + ] + } +] \ No newline at end of file diff --git a/python/tests/e2e_v2/test_configs/classification.json b/python/tests/e2e_v2/test_configs/classification.json new file mode 100644 index 00000000000..738e1981eda --- /dev/null +++ b/python/tests/e2e_v2/test_configs/classification.json @@ -0,0 +1,102 @@ +[ + { + "test_name": "binary_class", + "data_func": { + "name": "generate_classification_data", + "params": { + "num_example": 2000, + "num_features": 1, + "seed": 1, + "classify_func": { + "name": "binary_classification_one_feature", + "params": {} + }, + "bounds": [ + [ + 0, + 1 + ] + ] + } + }, + "assert_functions": [ + { + "name": "assert_prediction_with_generated_data", + "params": { + "data_func": { + "name": "generate_classification_data", + "params": { + "num_example": 100, + "num_features": 1, + "classify_func": { + "name": "binary_classification_one_feature", + "params": {} + }, + "seed": 1 + } + }, + "accuracy_threshold": 0.9 + } + } + ], + "grids": { + "g0": { + "#base": [ + "--oaa 3" + ] + } + }, + "grids_expression": "g0", + "output": [ + "--readable_model", + "-p" + ] + }, + { + "test_name": "multiclass_two_features", + "data_func": { + "name": "generate_classification_data", + "params": { + "num_example": 100000, + "num_features": 2, + "classify_func": { + "name": "multi_classification_two_features", + "params": {} + }, + "seed": 1 + } + }, + "assert_functions": [ + { + "name": "assert_prediction_with_generated_data", + "params": { + "data_func": { + "name": "generate_classification_data", + "params": { + "num_example": 500, + "num_features": 2, + "seed": 1, + "classify_func": { + "name": "multi_classification_two_features", + "params": {} + } + } + }, + "accuracy_threshold": 0.5 + } + } + ], + "grids": { + "g0": { + "#base": [ + "--oaa 25" + ] + } + }, + "grids_expression": "g0", + "output": [ + "--readable_model", + "-p" + ] + } +] \ No newline at end of file diff --git a/python/tests/e2e_v2/test_configs/regression.json b/python/tests/e2e_v2/test_configs/regression.json new file mode 100644 index 00000000000..ae4a5bcc973 --- /dev/null +++ b/python/tests/e2e_v2/test_configs/regression.json @@ -0,0 +1,107 @@ +[ + { + "data_func": { + "name": "constant_function", + "params": { + "no_sample": 2000, + "constant": 5, + "x_lower_bound": 1, + "x_upper_bound": 100, + "seed" : 1 + } + }, + "assert_functions": [ + { + "name": "assert_prediction", + "params": { + "expected_value": [ + 5 + ], + "threshold": 0.8 + } + }, + { + "name": "assert_weight", + "params": { + "expected_weights": { + "f^x": 0, + "Constant": 5 + }, + "atol": 0.1, + "rtol": 1 + } + } + ], + "grids": { + "g0": { + "#base": [ + "-P 50000 --preserve_performance_counters --save_resume" + ] + }, + "g1": { + "--learning_rate": [ + null, + 0.1, + 0.01, + 0.001 + ], + "--decay_learning_rate": [ + null, + 1.1, + 1, + 0.9 + ], + "--power_t": [ + null, + 0.5, + 0.6, + 0.4 + ] + }, + "g2": { + "#reg": [ + "--freegrad", + "--conjugate_gradient", + "--bfgs --passes 1 --cache" + ] + }, + "g3": { + "#reg": [ + "--ftrl", + "--coin", + "--pistol" + ], + "--ftrl_alpha": [ + null, + 0.1 + ], + "--ftrl_beta": [ + null, + 0.1 + ] + }, + "g4": { + "--loss_function": [ + null, + "poisson", + "quantile" + ] + }, + "g5": { + "--loss_function": [ + "expectile" + ], + "--expectile_q": [ + 0.25, + 0.5 + ] + } + }, + "grids_expression": "g0 * (g1 + g2 + g3) * (g5 + g4)", + "output": [ + "--readable_model", + "--invert_hash", + "-p" + ] + } +] \ No newline at end of file diff --git a/python/tests/e2e_v2/test_configs/slate.json b/python/tests/e2e_v2/test_configs/slate.json new file mode 100644 index 00000000000..670b2992f40 --- /dev/null +++ b/python/tests/e2e_v2/test_configs/slate.json @@ -0,0 +1,145 @@ +[ + { + "test_name": "slates", + "data_func": { + "name": "generate_slate_data", + "params": { + "num_examples": 1000, + "seed" : 1, + "reward_function": { + "name": "reverse_reward_after_threshold", + "params": { + "reward": [ + [ + 1, + 0 + ], + [ + 0, + 1 + ] + ], + "threshold": 500 + } + }, + "logging_policy": { + "name": "even_probability", + "params": {} + }, + "action_space": { + "name": "new_action_after_threshold", + "params": { + "threshold": 500, + "before": [ + [ + "longshirt", + "tshirt" + ], + [ + "shorts", + "jeans" + ] + ], + "after": [ + [ + "rainshirt", + "buttonupshirt" + ], + [ + "formalpants", + "rainpants" + ] + ] + } + } + } + }, + "assert_functions": [ + { + "name": "assert_loss", + "params": { + "expected_loss": -1.9, + "decimal": 0.1 + } + }, + { + "name": "assert_prediction", + "params": { + "expected_value": [ + [ + 0.1, + 0.9 + ], + [ + 0.9, + 0.1 + ] + ], + "threshold": 0.8, + "atol": 0.01, + "rtol": 0.01 + } + } + ], + "grids": { + "slate": { + "#base": [ + "--slates" + ] + }, + "epsilon": { + "--epsilon": [ + 0.1, + 0.2, + 0.3 + ] + }, + "first": { + "--first": [ + 1, + 2 + ] + }, + "bag": { + "--bag": [ + 5, + 6, + 7 + ] + }, + "cover": { + "--cover": [ + 1, + 2, + 3 + ] + }, + "squarecb": { + "--squarecb": [ + "--gamma_scale 1000", + "--gamma_scale 10000" + ] + }, + "synthcover": { + "--synthcover": [ + "" + ] + }, + "regcb": { + "--regcb": [ + "" + ] + }, + "softmax": { + "--softmax": [ + "" + ] + } + }, + "grids_expression": "slate * (epsilon + first + bag + cover + squarecb + synthcover + regcb + softmax)", + "output": [ + "--readable_model", + "-p" + ] + } +] \ No newline at end of file diff --git a/python/tests/e2e_v2/test_core.py b/python/tests/e2e_v2/test_core.py new file mode 100644 index 00000000000..26d071de838 --- /dev/null +++ b/python/tests/e2e_v2/test_core.py @@ -0,0 +1,129 @@ +from vw_executor.vw import Vw +from vw_executor.vw_opts import Grid +import pytest +import os +import logging +from test_helper import ( + json_to_dict_list, + evaluate_expression, + copy_file, + custom_sort, + get_function_obj_with_dirs, + datagen_driver, +) +from conftest import STORE_OUTPUT + +CURR_DICT = os.path.dirname(os.path.abspath(__file__)) +TEST_CONFIG_FILES_NAME = os.listdir(os.path.join(CURR_DICT, "test_configs")) +TEST_CONFIG_FILES = [json_to_dict_list(i) for i in TEST_CONFIG_FILES_NAME] +GENERATED_TEST_CASES = [] +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO +) + + +def cleanup_data_file(): + script_directory = os.path.dirname(os.path.realpath(__file__)) + # List all files in the directory + for name in TEST_CONFIG_FILES_NAME: + name = name.split(".")[0] + try: + files = os.listdir(os.path.join(script_directory, name)) + except: + return + # Iterate over the files and remove the ones with .txt extension + for file in files: + if file.endswith(".txt"): + file_path = os.path.join(script_directory + "/" + name, file) + os.remove(file_path) + + +@pytest.fixture +def test_descriptions(request): + resource = request.param + yield resource + cleanup_data_file() + + +def core_test(files, grid, outputs, job_assert, job_assert_args): + vw = Vw(CURR_DICT + "/.vw_cache", reset=True, handler=None) + result = vw.train(files, grid, outputs) + for j in result: + test_name = ( + job_assert.__name__ + + "_" + + "_".join("".join([i for i in str(j.opts) if i != "-"]).split(" ")) + ) + GENERATED_TEST_CASES.append( + [lambda: job_assert(j, **job_assert_args), test_name] + ) + if STORE_OUTPUT: + if not os.path.exists(CURR_DICT + "/output"): + os.mkdir(CURR_DICT + "/output") + if not os.path.exists(CURR_DICT + "/output/" + test_name): + os.mkdir(CURR_DICT + "/output/" + test_name) + fileName = str(list(j.outputs.values())[0][0]).split("/")[-1] + for key, value in list(j.outputs.items()): + copy_file( + value[0], + CURR_DICT + "/output/" + test_name + "/" + f"{key}_" + fileName, + ) + copy_file( + os.path.join(j.cache.path, "cacheNone/" + fileName), + CURR_DICT + "/output/" + test_name + "/" + fileName, + ) + + +def get_options(grids, expression): + final_variables = {} + for key in grids: + final_variables[key] = Grid(grids[key]) + return evaluate_expression(expression, final_variables) + + +@pytest.mark.usefixtures("test_descriptions", TEST_CONFIG_FILES) +def init_all(test_descriptions): + for tIndex, tests in enumerate(test_descriptions): + task_folder = TEST_CONFIG_FILES_NAME[tIndex].split(".")[0] + package_name = [task_folder + ".", ""] + package_name = custom_sort(task_folder, package_name) + package_name.append(".") + if type(tests) is not list: + tests = [tests] + for test_description in tests: + options = get_options( + test_description["grids"], test_description["grids_expression"] + ) + data_func = get_function_obj_with_dirs( + package_name, + "data_generation", + test_description["data_func"]["name"], + ) + scenario_directory = ( + os.path.dirname(os.path.realpath(__file__)) + f"/{task_folder}" + ) + data = datagen_driver( + scenario_directory, data_func, **test_description["data_func"]["params"] + ) + script_directory = os.path.dirname(os.path.realpath(__file__)) + for assert_func in test_description["assert_functions"]: + assert_job = get_function_obj_with_dirs( + package_name, "assert_job", assert_func["name"] + ) + core_test( + os.path.join(script_directory, data), + options, + test_description["output"], + assert_job, + assert_func["params"], + ) + + +try: + init_all(TEST_CONFIG_FILES) + for generated_test_case in GENERATED_TEST_CASES: + test_name = f"test_{generated_test_case[1]}" + generated_test_case[0].__name__ = test_name + globals()[test_name] = generated_test_case[0] +finally: + cleanup_data_file() diff --git a/python/tests/e2e_v2/test_helper.py b/python/tests/e2e_v2/test_helper.py new file mode 100644 index 00000000000..b4fd9bd759c --- /dev/null +++ b/python/tests/e2e_v2/test_helper.py @@ -0,0 +1,140 @@ +import json +import importlib +import os +import itertools +import inspect +import shutil + +# Get the current directory +current_dir = os.path.dirname(os.path.abspath(__file__)) + + +def json_to_dict_list(file): + with open(current_dir + "/test_configs/" + file, "r") as file: + # Load the JSON data + return json.load(file) + + +def evaluate_expression(expression, variables): + # Create a dictionary to hold the variable values + variables_dict = {} + # Populate the variables_dict with the provided variables + for variable_name, variable_value in variables.items(): + variables_dict[variable_name] = variable_value + # Evaluate the expression using eval() + result = eval(expression, variables_dict) + return result + + +def dynamic_function_call(module_name, function_name, *args, **kwargs): + try: + calling_frame = inspect.stack()[1] + calling_module = inspect.getmodule(calling_frame[0]) + calling_package = calling_module.__package__ + module = importlib.import_module(module_name, package=calling_package) + function = getattr(module, function_name) + result = function(*args, **kwargs) + return result + except ImportError: + pass + except AttributeError: + pass + + +def get_function_object(module_name, function_name): + function = None + try: + calling_frame = inspect.stack()[1] + calling_module = inspect.getmodule(calling_frame[0]) + calling_package = calling_module.__package__ + module = importlib.import_module(module_name, package=calling_package) + function = getattr(module, function_name) + return function + except ImportError: + pass + except AttributeError: + pass + + +def generate_string_combinations(*lists): + combinations = list(itertools.product(*lists)) + combinations = ["".join(combination) for combination in combinations] + return combinations + + +def copy_file(source_file, destination_file): + try: + shutil.copy(source_file, destination_file) + print(f"File copied successfully from '{source_file}' to '{destination_file}'.") + except FileNotFoundError: + print(f"Source file '{source_file}' not found.") + except PermissionError: + print( + f"Permission denied. Unable to copy '{source_file}' to '{destination_file}'." + ) + + +def call_function_with_dirs(dirs, module_name, function_name, **kargs): + for dir in dirs: + try: + data = dynamic_function_call( + dir + module_name, + function_name, + **kargs, + ) + if data: + return data + except Exception as error: + if type(error) not in [ModuleNotFoundError]: + raise error + + +def get_function_obj_with_dirs(dirs, module_name, function_name): + obj = None + for dir in dirs: + try: + obj = get_function_object( + dir + module_name, + function_name, + ) + if obj: + return obj + except Exception as error: + if type(error) not in [ModuleNotFoundError]: + raise error + if not obj: + raise ModuleNotFoundError( + f"Module '{module_name}' not found in any of the directories {dirs}." + ) + + +def datagen_driver(script_directory, impl, **kwargs): + names = [] + for i in kwargs.values(): + if type(i) == dict: + names.append(list(i.items())[0][1]) + elif type(i) == list: + pass + else: + names.append(i) + + dataFile = f"{str(impl.__name__)}_{'_'.join([str(i) for i in names])}.txt" + with open(os.path.join(script_directory, dataFile), "w") as f: + impl(f, **kwargs) + return os.path.join(script_directory, dataFile) + + +def calculate_similarity(word, string): + # Calculate the similarity score between the string and the word + score = 0 + for char in word: + if char in string: + score += 1 + return score + + +def custom_sort(word, strings): + # Sort the list of strings based on their similarity to the word + return sorted( + strings, key=lambda string: calculate_similarity(word, string), reverse=True + )