Skip to content

Commit

Permalink
Add test to check output shapes of batch normalization operations
Browse files Browse the repository at this point in the history
  • Loading branch information
sahmed authored and sahmed committed Sep 30, 2023
1 parent 515d355 commit c877b07
Showing 1 changed file with 62 additions and 19 deletions.
81 changes: 62 additions & 19 deletions tests/test_snntorch/test_bntt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,79 @@
import torch


# @pytest.fixture(scope="module")
# @pytest.mark.parametrize("value", params=[1, 2, 3])
# def time_steps_(value):
# return value
#
#
# @pytest.fixture(scope="module")
# def input_():
# # 2 time_steps, 2 batch_size, 2 features
# return torch.rand(2, 2, 2)
#
#
@pytest.fixture(scope="module")
def input1d_():
    """Random input for BatchNormTT1d: 2 time steps, batch size 2, 4 features."""
    time_steps, batch_size, num_features = 2, 2, 4
    return torch.rand(time_steps, batch_size, num_features)


@pytest.fixture(scope="module")
def batchnormtt1d_instance():
    # Module-scoped BatchNormTT1d with 4 features across 2 time steps,
    # sized to match the (2, 2, 4) tensor from the input1d_ fixture.
    # (Removed stale pre-diff line that referenced an undefined `time_steps_`.)
    return snn.BatchNormTT1d(4, 2)


class TestBatchNormTT1d:
    """Construction and forward-shape tests for snn.BatchNormTT1d.

    Note: stale old-side diff lines (duplicate parametrize/signature using
    `input_features`) were removed; only the new-side code is kept.
    """

    @pytest.mark.parametrize("time_steps, num_features", ([1, 1], [3, 2], [6, 3]))
    def test_batchnormtt1d_init(
        self,
        time_steps,
        num_features
    ):
        # BatchNormTT1d should hold one module per time step, each created
        # with the default BatchNorm hyperparameters and no bias term.
        batchnormtt1d_instance = snn.BatchNormTT1d(num_features, time_steps)

        assert len(batchnormtt1d_instance) == time_steps
        for module in batchnormtt1d_instance:
            assert module.num_features == num_features
            assert module.eps == 1e-5
            assert module.momentum == 0.1
            assert module.affine
            assert module.bias is None

    def test_batchnormtt1d_output(
        self,
        batchnormtt1d_instance,
        input1d_
    ):
        # Each per-step module must preserve the shape of its time slice.
        for step, batchnormtt1d_module in enumerate(batchnormtt1d_instance):
            out = batchnormtt1d_module(input1d_[step])

            assert out.shape == input1d_[step].shape


@pytest.fixture(scope="module")
def input2d_():
    """Random input for BatchNormTT2d: 2 time steps, batch 2, 4 channels, 2x2 spatial."""
    shape = (2, 2, 4, 2, 2)
    return torch.rand(*shape)


@pytest.fixture(scope="module")
def batchnormtt2d_instance():
    """Module-scoped BatchNormTT2d: 4 channels over 2 time steps (matches input2d_)."""
    num_features, time_steps = 4, 2
    return snn.BatchNormTT2d(num_features, time_steps)


class TestBatchNormTT2d:
    """Construction and forward-shape tests for snn.BatchNormTT2d."""

    @pytest.mark.parametrize("time_steps, num_features", ([1, 1], [3, 2], [6, 3]))
    def test_batchnormtt2d_init(
        self,
        time_steps,
        num_features
    ):
        # BatchNormTT2d should hold one module per time step, each created
        # with the default BatchNorm hyperparameters and no bias term.
        batchnormtt2d_instance = snn.BatchNormTT2d(num_features, time_steps)

        assert len(batchnormtt2d_instance) == time_steps
        for module in batchnormtt2d_instance:
            assert module.num_features == num_features
            assert module.eps == 1e-5
            assert module.momentum == 0.1
            assert module.affine
            assert module.bias is None

    # Renamed from test_batchnormtt1d_output (copy-paste defect): this test
    # exercises the 2d variant, not the 1d one.
    def test_batchnormtt2d_output(
        self,
        batchnormtt2d_instance,
        input2d_
    ):
        # Each per-step module must preserve the shape of its time slice.
        for step, batchnormtt2d_module in enumerate(batchnormtt2d_instance):
            out = batchnormtt2d_module(input2d_[step])

            assert out.shape == input2d_[step].shape

0 comments on commit c877b07

Please sign in to comment.