Skip to content

Commit

Permalink
Add tests for bntt1d with 3d data
Browse files Browse the repository at this point in the history
  • Loading branch information
sahmed authored and sahmed committed Sep 30, 2023
1 parent c877b07 commit be78319
Showing 1 changed file with 30 additions and 14 deletions.
44 changes: 30 additions & 14 deletions tests/test_snntorch/test_bntt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,23 @@


@pytest.fixture(scope="module")
def input1d_():
# 2 time_steps, 2, batch size, 2 features
def input2d_():
# 2 time_steps, 2, batch size, 4 features
return torch.rand(2, 2, 4)


@pytest.fixture(scope="module")
def input3d_():
# 2 time_steps, 2, batch size, 4 features, 5 sequence length
return torch.rand(2, 2, 4, 5)


@pytest.fixture(scope="module")
def input4d_():
# 2 time_steps, 2, batch size, 4 features, height 2, width 2
return torch.rand(2, 2, 4, 2, 2)


@pytest.fixture(scope="module")
def batchnormtt1d_instance():
return snn.BatchNormTT1d(4, 2)
Expand All @@ -35,21 +47,25 @@ def test_batchnormtt1d_init(
assert module.affine
assert module.bias is None

def test_batchnormtt1d_output(
def test_batchnormtt1d_with_2d_input(
self,
batchnormtt1d_instance,
input1d_
input2d_
):
for step, batchnormtt1d_module in enumerate(batchnormtt1d_instance):
out = batchnormtt1d_module(input1d_[step])
out = batchnormtt1d_module(input2d_[step])

assert out.shape == input1d_[step].shape
assert out.shape == input2d_[step].shape

def test_batchnormtt1d_with_3d_input(
self,
batchnormtt1d_instance,
input3d_
):
for step, batchnormtt1d_module in enumerate(batchnormtt1d_instance):
out = batchnormtt1d_module(input3d_[step])

@pytest.fixture(scope="module")
def input2d_():
# 2 time_steps, 2, batch size, 2 features
return torch.rand(2, 2, 4, 2, 2)
assert out.shape == input3d_[step].shape


@pytest.fixture(scope="module")
Expand All @@ -74,12 +90,12 @@ def test_batchnormtt2d_init(
assert module.affine
assert module.bias is None

def test_batchnormtt1d_output(
def test_batchnormtt2d_with_4d_input(
self,
batchnormtt2d_instance,
input2d_
input4d_
):
for step, batchnormtt2d_module in enumerate(batchnormtt2d_instance):
out = batchnormtt2d_module(input2d_[step])
out = batchnormtt2d_module(input4d_[step])

assert out.shape == input2d_[step].shape
assert out.shape == input4d_[step].shape

0 comments on commit be78319

Please sign in to comment.