Skip to content

Commit

Permalink
fix test locally
Browse files Browse the repository at this point in the history
  • Loading branch information
TheEimer committed Oct 2, 2024
1 parent 1ed7ca5 commit ddcfeba
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 30 deletions.
3 changes: 2 additions & 1 deletion mighty/mighty_utils/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def make_carl_env(
) -> Tuple[type[CARLVectorEnvSimulator], Callable, int]:
"""Make carl environment."""
import carl
from carl import envs
from carl.context.sampler import ContextSampler

env_kwargs = OmegaConf.to_container(cfg.env_kwargs, resolve=True)
Expand All @@ -105,7 +106,7 @@ def make_carl_env(
if "evaluation_context_sample_seed" not in env_kwargs: # type: ignore
env_kwargs["evaluation_context_sample_seed"] = 1 # type: ignore

env_class = getattr(carl.envs, cfg.env)
env_class = getattr(envs, cfg.env)

if len(env_kwargs["context_feature_args"].keys()) > 0: # type: ignore
context_distributions = []
Expand Down
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,18 @@ dependencies = [
"hydra-colorlog~=1.2",
"hydra-submitit-launcher~=1.2",
"pandas",
"scipy~=1.8",
"scipy==1.12",
"rich~=12.4",
"wandb~=0.12",
"torch",
"dill",
"imageio",
"jax==0.4.20",
"jaxlib==0.4.20",
"evosax==0.1.6"
]

[project.optional-dependencies]
dev = ["ruff", "mypy", "automl-sphinx-theme==0.2.0", "build", "pytest", "pytest-cov"]
carl = ["carl_bench==1.1.0"]
carl = ["carl_bench==1.1.0", "brax==0.9.3", "protobuf>=3.17.3", "mujoco==3.0.1"]
dacbench = ["dacbench>=0.3.0", "torchvision", "ioh"]
pufferlib = ["pufferlib==1.0.0"]

Expand Down
42 changes: 17 additions & 25 deletions test/test_env_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class TestEnvCreation:
)
dacbench_config_benchmark = OmegaConf.create(
{
"env": "SigmoidBenchmark",
"env": "FunctionApproximationBenchmark",
"env_kwargs": {
"benchmark": True,
"dimension": 1,
Expand All @@ -59,12 +59,10 @@ class TestEnvCreation:
)
dacbench_config = OmegaConf.create(
{
"env": "SigmoidBenchmark",
"env": "FunctionApproximationBenchmark",
"env_kwargs": {
"instance_set_path": "../instance_sets/sigmoid/sigmoid_1D3M_train.csv",
"test_set_path": "../instance_sets/sigmoid/sigmoid_1D3M_train.csv",
"action_space_args": [3],
"action_values": [3],
"instance_set_path": "sigmoid_2D3M_train.csv",
"test_set_path": "sigmoid_2D3M_train.csv",
},
"env_wrappers": ["mighty.utils.wrappers.MultiDiscreteActionWrapper"],
"num_envs": 1,
Expand Down Expand Up @@ -210,22 +208,6 @@ def test_make_dacbench_env(self):
== self.dacbench_config.env_kwargs.instance_set_path
), "Environment should have correct instance set."
assert eval_env().envs[0].test, "Eval environment should be in test mode."
assert (
env.envs[0].config.action_space_args
== self.dacbench_config.env_kwargs.action_space_args
), "Environment should have correct action space args."
assert (
eval_env().envs[0].config.action_space_args
== self.dacbench_config.env_kwargs.action_space_args
), "Eval environment should have correct action space args."
assert (
env.envs[0].config.action_values
== self.dacbench_config.env_kwargs.action_values
), "Environment should have correct action values."
assert (
eval_env().envs[0].config.action_values
== self.dacbench_config.env_kwargs.action_values
), "Eval environment should have correct action values."

def test_make_dacbench_benchmark_mode(self):
"""Test env creation with make_dacbench_env in benchmark mode."""
Expand Down Expand Up @@ -270,9 +252,19 @@ def test_make_dacbench_benchmark_mode(self):
for k in env.envs[0].config.keys():
if k == "observation_space_args":
continue
assert (
env.envs[0].config[k] == benchmark_env.config[k]
), f"Environment should have correct config, mismatch at {k}: {env.envs[0].config[k]} != {benchmark_env.config[k]}"
elif k == "instance_set" or k == "test_set":
for i in range(len(env.envs[0].config[k])):
assert (
env.envs[0].config[k][i].functions[0].a == benchmark_env.config[k][i].functions[0].a
), f"Environment should have matching instances, mismatch for function parameter a at instance {i}: {env.envs[0].config[k][i].functions[0].a} != {benchmark_env.config[k][i].functions[0].a}"
assert (
env.envs[0].config[k][i].functions[0].b == benchmark_env.config[k][i].functions[0].b
), f"Environment should have matching instances, mismatch for function parameter b at instance {i}: {env.envs[0].config[k][i].functions[0].b} != {benchmark_env.config[k][i].functions[0].b}"
assert (
env.envs[0].config[k][i].omit_instance_type == benchmark_env.config[k][i].omit_instance_type
), f"Environment should have matching instances, mismatch for omit_instance_type at instance {i}: {env.envs[0].config[k][i].omit_instance_type} != {benchmark_env.config[k][i].omit_instance_type}"
else:
assert (env.envs[0].config[k] == benchmark_env.config[k]), f"Environment should have correct config, mismatch at {k}: {env.envs[0].config[k]} != {benchmark_env.config[k]}"

def test_make_carl_env(self):
"""Test env creation with make_carl_env."""
Expand Down

0 comments on commit ddcfeba

Please sign in to comment.