From ddcfeba842afaf052a6b1b7493a7a902a34dcf47 Mon Sep 17 00:00:00 2001 From: Theresa Eimer Date: Wed, 2 Oct 2024 11:31:37 +0200 Subject: [PATCH] fix test locally --- mighty/mighty_utils/envs.py | 3 ++- pyproject.toml | 6 ++---- test/test_env_creation.py | 42 +++++++++++++++---------------------- 3 files changed, 21 insertions(+), 30 deletions(-) diff --git a/mighty/mighty_utils/envs.py b/mighty/mighty_utils/envs.py index 61e6105..e3fff3b 100644 --- a/mighty/mighty_utils/envs.py +++ b/mighty/mighty_utils/envs.py @@ -90,6 +90,7 @@ def make_carl_env( ) -> Tuple[type[CARLVectorEnvSimulator], Callable, int]: """Make carl environment.""" import carl + from carl import envs from carl.context.sampler import ContextSampler env_kwargs = OmegaConf.to_container(cfg.env_kwargs, resolve=True) @@ -105,7 +106,7 @@ def make_carl_env( if "evaluation_context_sample_seed" not in env_kwargs: # type: ignore env_kwargs["evaluation_context_sample_seed"] = 1 # type: ignore - env_class = getattr(carl.envs, cfg.env) + env_class = getattr(envs, cfg.env) if len(env_kwargs["context_feature_args"].keys()) > 0: # type: ignore context_distributions = [] diff --git a/pyproject.toml b/pyproject.toml index 0407e97..86464dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,20 +35,18 @@ dependencies = [ "hydra-colorlog~=1.2", "hydra-submitit-launcher~=1.2", "pandas", - "scipy~=1.8", + "scipy==1.12", "rich~=12.4", "wandb~=0.12", "torch", "dill", "imageio", - "jax==0.4.20", - "jaxlib==0.4.20", "evosax==0.1.6" ] [project.optional-dependencies] dev = ["ruff", "mypy", "automl-sphinx-theme==0.2.0", "build", "pytest", "pytest-cov"] -carl = ["carl_bench==1.1.0"] +carl = ["carl_bench==1.1.0", "brax==0.9.3", "protobuf>=3.17.3", "mujoco==3.0.1"] dacbench = ["dacbench>=0.3.0", "torchvision", "ioh"] pufferlib = ["pufferlib==1.0.0"] diff --git a/test/test_env_creation.py b/test/test_env_creation.py index 433bf3b..cf15dfe 100644 --- a/test/test_env_creation.py +++ b/test/test_env_creation.py @@ -46,7 +46,7 @@ class TestEnvCreation: ) dacbench_config_benchmark = OmegaConf.create( { - "env": "SigmoidBenchmark", + "env": "FunctionApproximationBenchmark", "env_kwargs": { "benchmark": True, "dimension": 1, @@ -59,12 +59,10 @@ class TestEnvCreation: ) dacbench_config = OmegaConf.create( { - "env": "SigmoidBenchmark", + "env": "FunctionApproximationBenchmark", "env_kwargs": { - "instance_set_path": "../instance_sets/sigmoid/sigmoid_1D3M_train.csv", - "test_set_path": "../instance_sets/sigmoid/sigmoid_1D3M_train.csv", - "action_space_args": [3], - "action_values": [3], + "instance_set_path": "sigmoid_2D3M_train.csv", + "test_set_path": "sigmoid_2D3M_train.csv", }, "env_wrappers": ["mighty.utils.wrappers.MultiDiscreteActionWrapper"], "num_envs": 1, @@ -210,22 +208,6 @@ def test_make_dacbench_env(self): == self.dacbench_config.env_kwargs.instance_set_path ), "Environment should have correct instance set." assert eval_env().envs[0].test, "Eval environment should be in test mode." - assert ( - env.envs[0].config.action_space_args - == self.dacbench_config.env_kwargs.action_space_args - ), "Environment should have correct action space args." - assert ( - eval_env().envs[0].config.action_space_args - == self.dacbench_config.env_kwargs.action_space_args - ), "Eval environment should have correct action space args." - assert ( - env.envs[0].config.action_values - == self.dacbench_config.env_kwargs.action_values - ), "Environment should have correct action values." - assert ( - eval_env().envs[0].config.action_values - == self.dacbench_config.env_kwargs.action_values - ), "Eval environment should have correct action values." def test_make_dacbench_benchmark_mode(self): """Test env creation with make_dacbench_env in benchmark mode.""" @@ -270,9 +252,19 @@ def test_make_dacbench_benchmark_mode(self): for k in env.envs[0].config.keys(): if k == "observation_space_args": continue - assert ( - env.envs[0].config[k] == benchmark_env.config[k] - ), f"Environment should have correct config, mismatch at {k}: {env.envs[0].config[k]} != {benchmark_env.config[k]}" + elif k == "instance_set" or k == "test_set": + for i in range(len(env.envs[0].config[k])): + assert ( + env.envs[0].config[k][i].functions[0].a == benchmark_env.config[k][i].functions[0].a + ), f"Environment should have matching instances, mismatch for function parameter a at instance {i}: {env.envs[0].config[k][i].functions[0].a} != {benchmark_env.config[k][i].functions[0].a}" + assert ( + env.envs[0].config[k][i].functions[0].b == benchmark_env.config[k][i].functions[0].b + ), f"Environment should have matching instances, mismatch for function parameter b at instance {i}: {env.envs[0].config[k][i].functions[0].b} != {benchmark_env.config[k][i].functions[0].b}" + assert ( + env.envs[0].config[k][i].omit_instance_type == benchmark_env.config[k][i].omit_instance_type + ), f"Environment should have matching instances, mismatch for omit_instance_type at instance {i}: {env.envs[0].config[k][i].omit_instance_type} != {benchmark_env.config[k][i].omit_instance_type}" + else: + assert (env.envs[0].config[k] == benchmark_env.config[k]), f"Environment should have correct config, mismatch at {k}: {env.envs[0].config[k]} != {benchmark_env.config[k]}" def test_make_carl_env(self): """Test env creation with make_carl_env."""