Skip to content

Commit

Permalink
Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Oct 23, 2024
1 parent 90e7ec3 commit 643d6aa
Show file tree
Hide file tree
Showing 4 changed files with 0 additions and 32 deletions.
6 changes: 0 additions & 6 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
import contextlib
from typing import List

import packaging
import pytest
import torch
import torch_geometric.nn
import torchrl

from benchmarl.hydra_config import load_model_config_from_hydra
from benchmarl.models import GnnConfig, model_config_registry
Expand Down Expand Up @@ -183,8 +181,6 @@ def test_models_forward_shape(
share_params=share_params,
n_agents=n_agents,
)
if packaging.version.parse(torchrl.__version__).local is None and config.is_rnn:
pytest.skip("rnn model needs torchrl from github")

if centralised:
config.is_critic = True
Expand Down Expand Up @@ -275,8 +271,6 @@ def test_share_params_between_models(
config = model_config_registry[model_name].get_from_yaml()
if centralised:
config.is_critic = True
if packaging.version.parse(torchrl.__version__).local is None and config.is_rnn:
pytest.skip("rnn model needs torchrl from github")
model = config.get_model(
input_spec=input_spec,
output_spec=output_spec,
Expand Down
10 changes: 0 additions & 10 deletions test/test_pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
#


import packaging
import pytest
import torchrl
from benchmarl.algorithms import (
algorithm_config_registry,
IddpgConfig,
Expand Down Expand Up @@ -111,10 +109,6 @@ def test_gnn(
"algo_config", [IddpgConfig, MappoConfig, QmixConfig, MasacConfig]
)
@pytest.mark.parametrize("task", [PettingZooTask.SIMPLE_TAG])
@pytest.mark.skipif(
packaging.version.parse(torchrl.__version__).local is None,
reason="gru model needs torchrl from github",
)
def test_gru(
self,
algo_config: AlgorithmConfig,
Expand All @@ -141,10 +135,6 @@ def test_gru(
"algo_config", [MaddpgConfig, IppoConfig, QmixConfig, IsacConfig]
)
@pytest.mark.parametrize("task", [PettingZooTask.SIMPLE_TAG])
@pytest.mark.skipif(
packaging.version.parse(torchrl.__version__).local is None,
reason="lstm model needs torchrl from github",
)
def test_lstm(
self,
algo_config: AlgorithmConfig,
Expand Down
6 changes: 0 additions & 6 deletions test/test_smacv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
# LICENSE file in the root directory of this source tree.
#

import packaging
import pytest
import torchrl

from benchmarl.algorithms import algorithm_config_registry, MappoConfig, QmixConfig
from benchmarl.algorithms.common import AlgorithmConfig
Expand Down Expand Up @@ -81,10 +79,6 @@ def test_gnn(

@pytest.mark.parametrize("algo_config", [QmixConfig])
@pytest.mark.parametrize("task", [Smacv2Task.PROTOSS_5_VS_5])
@pytest.mark.skipif(
packaging.version.parse(torchrl.__version__).local is None,
reason="gru model needs torchrl from github",
)
def test_gru(
self,
algo_config,
Expand Down
10 changes: 0 additions & 10 deletions test/test_vmas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
# LICENSE file in the root directory of this source tree.
#

import packaging
import pytest
import torchrl
from benchmarl.algorithms import (
algorithm_config_registry,
IddpgConfig,
Expand Down Expand Up @@ -119,10 +117,6 @@ def test_gnn(
"algo_config", [MaddpgConfig, IppoConfig, QmixConfig, MasacConfig]
)
@pytest.mark.parametrize("task", [VmasTask.NAVIGATION])
@pytest.mark.skipif(
packaging.version.parse(torchrl.__version__).local is None,
reason="gru model needs torchrl from github",
)
def test_gru(
self,
algo_config: AlgorithmConfig,
Expand Down Expand Up @@ -150,10 +144,6 @@ def test_gru(
"algo_config", [IddpgConfig, MappoConfig, QmixConfig, IsacConfig]
)
@pytest.mark.parametrize("task", [VmasTask.NAVIGATION])
@pytest.mark.skipif(
packaging.version.parse(torchrl.__version__).local is None,
reason="lstm model needs torchrl from github",
)
def test_lstm(
self,
algo_config: AlgorithmConfig,
Expand Down

0 comments on commit 643d6aa

Please sign in to comment.