Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update branch with dev changes #15

Open
wants to merge 20 commits into
base: konsti_monkey_experiments
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/workflows/black.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
name: Black

on: [push, pull_request]

jobs:
black:
runs-on: ubuntu-18.04
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- uses: psf/black@master
24 changes: 24 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: Test

on:
push:
branches: [ master ]
pull_request:
branches: [ master ]

jobs:
test:
runs-on: ubuntu-18.04
steps:
- uses: actions/checkout@v2
- uses: azure/docker-login@v1
with:
login-server: index.docker.io
username: ${{ secrets.DOCKERIO_USERNAME }}
password: ${{ secrets.DOCKERIO_PASSWORD }}
- name: Add SSH key to enable cloning of private repos
run: mkdir .ssh && echo "${{ secrets.SSH_KEY }}" >> ".ssh/id_rsa"
- name: Build Docker image
run: docker build -t mei .
- name: Run tests
run: docker run --entrypoint pytest mei
17 changes: 0 additions & 17 deletions .travis.yml

This file was deleted.

154 changes: 114 additions & 40 deletions README.md

Large diffs are not rendered by default.

Binary file removed id_rsa.enc
Binary file not shown.
10 changes: 5 additions & 5 deletions mei/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ class Member(mixins.TrainedEnsembleModelTemplateMixin.Member, dj.Part):
"""Member table template."""


class CSRFV1SelectorTemplate(mixins.CSRFV1SelectorTemplateMixin, dj.Computed):
"""CSRF V1 selector table template.
class CSRFV1ObjectiveTemplate(mixins.CSRFV1ObjectiveTemplateMixin, dj.Computed):
"""CSRF V1 objective table template.

To create a functional "CSRFV1Selector" table, create a new class that inherits from this template and decorate it
To create a functional "CSRFV1Objective" table, create a new class that inherits from this template and decorate it
with your preferred Datajoint schema. By default, the created table will point to the "Dataset" table in the
Datajoint schema called "nnfabrik.main". This behavior can be changed by overwriting the class attribute called
"dataset_table".
Expand All @@ -52,8 +52,8 @@ class MEITemplate(mixins.MEITemplateMixin, dj.Computed):
"""MEI table template.

To create a functional "MEI" table, create a new class that inherits from this template and decorate it with your
preferred Datajoint schema. Next assign your trained model (or trained ensemble model) and your selector table to
the class variables called "trained_model_table" and "selector_table". By default, the created table will point to
preferred Datajoint schema. Next assign your trained model (or trained ensemble model) and your objective table to
the class variables called "trained_model_table" and "objective_table". By default, the created table will point to
the "MEIMethod" table in the Datajoint schema called "nnfabrik.main". This behavior can be changed by overwriting
the class attribute called "method_table".
"""
Expand Down
2 changes: 1 addition & 1 deletion mei/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _load_ensemble_model(
)


class CSRFV1SelectorTemplateMixin:
class CSRFV1ObjectiveTemplateMixin:
definition = """
# contains information that can be used to map a neuron's id to its corresponding integer position in the output of
# the model.
Expand Down
5 changes: 3 additions & 2 deletions mei/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from torch import Tensor
from torch.nn import Module, ModuleList
from collections.abc import Iterable


class EnsembleModel(Module):
Expand Down Expand Up @@ -66,7 +67,7 @@ def __init__(
if target_fn is None:
target_fn = lambda x: x
self.model = model
self.constraint = constraint
self.constraint = constraint if (isinstance(constraint, Iterable) or constraint is None) else [constraint]
self.forward_kwargs = forward_kwargs if forward_kwargs else dict()
self.target_fn = target_fn

Expand All @@ -84,7 +85,7 @@ def __call__(self, x: Tensor, *args, **kwargs) -> Tensor:
output = self.model(x, *args, **self.forward_kwargs, **kwargs)
return (
self.target_fn(output)
if (not self.constraint and self.constraint != 0)
if self.constraint is None or len(self.constraint) == 0
else self.target_fn(output[:, self.constraint])
)

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[tool.black]
line-length = 120
56 changes: 28 additions & 28 deletions tests/unit/mixins/test_csrfv1_selector_template_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,29 @@


@pytest.fixture
def selector_template(dataset_table):
selector_template = mixins.CSRFV1SelectorTemplateMixin
selector_template.dataset_table = dataset_table
selector_template.dataset_fn = "dataset_fn"
return selector_template
def objective_template(dataset_table):
objective_template = mixins.CSRFV1ObjectiveTemplateMixin
objective_template.dataset_table = dataset_table
objective_template.dataset_fn = "dataset_fn"
return objective_template


@pytest.fixture
def dataset_table():
return MagicMock(name="dataset_table")


def test_if_key_source_is_correct(selector_template, dataset_table):
def test_if_key_source_is_correct(objective_template, dataset_table):
dataset_table.return_value.__and__.return_value = "key_source"
assert selector_template()._key_source == "key_source"
assert objective_template()._key_source == "key_source"
dataset_table.return_value.__and__.assert_called_once_with(dict(dataset_fn="dataset_fn"))


class TestMake:
@pytest.fixture
def selector_template(self, selector_template, dataset_table, insert):
selector_template.insert = insert
return selector_template
def objective_template(self, objective_template, dataset_table, insert):
objective_template.insert = insert
return objective_template

@pytest.fixture
def dataset_table(self, dataset_table):
Expand All @@ -43,26 +43,26 @@ def insert(self):
def get_mappings(self):
return MagicMock(return_value="mappings")

def test_if_dataset_config_is_correctly_fetched(self, key, selector_template, dataset_table, get_mappings):
selector_template().make(key, get_mappings=get_mappings)
def test_if_dataset_config_is_correctly_fetched(self, key, objective_template, dataset_table, get_mappings):
objective_template().make(key, get_mappings=get_mappings)
dataset_table.return_value.__and__.assert_called_once_with(key)
dataset_table.return_value.__and__.return_value.fetch1.assert_called_once_with("dataset_config")

def test_if_get_mappings_is_correctly_called(self, key, selector_template, get_mappings):
selector_template().make(key, get_mappings=get_mappings)
def test_if_get_mappings_is_correctly_called(self, key, objective_template, get_mappings):
objective_template().make(key, get_mappings=get_mappings)
get_mappings.assert_called_once_with("dataset_config", key)

def test_if_mappings_are_correctly_inserted(self, key, selector_template, insert, get_mappings):
selector_template().make(key, get_mappings=get_mappings)
def test_if_mappings_are_correctly_inserted(self, key, objective_template, insert, get_mappings):
objective_template().make(key, get_mappings=get_mappings)
insert.assert_called_once_with("mappings")


class TestGetOutputSelectedModel:
class TestGetObjective:
@pytest.fixture
def selector_template(self, selector_template, constrained_output_model, magic_and):
selector_template.constrained_output_model = constrained_output_model
selector_template.__and__ = magic_and
return selector_template
def objective_template(self, objective_template, constrained_output_model, magic_and):
objective_template.constrained_output_model = constrained_output_model
objective_template.__and__ = magic_and
return objective_template

@pytest.fixture
def constrained_output_model(self):
Expand All @@ -78,19 +78,19 @@ def magic_and(self):
def model(self):
return MagicMock(name="model")

def test_if_neuron_position_and_session_id_are_correctly_fetched(self, key, model, selector_template, magic_and):
selector_template().get_output_selected_model(model, key)
def test_if_neuron_position_and_session_id_are_correctly_fetched(self, key, model, objective_template, magic_and):
objective_template().get_objective(model, key)
magic_and.assert_called_once_with(key)
magic_and.return_value.fetch1.assert_called_once_with("neuron_position", "session_id")

def test_if_constrained_output_model_is_correctly_initialized(
self, key, model, selector_template, constrained_output_model
self, key, model, objective_template, constrained_output_model
):
selector_template().get_output_selected_model(model, key)
objective_template().get_objective(model, key)
constrained_output_model.assert_called_once_with(
model, "neuron_pos", forward_kwargs=dict(data_key="session_id")
)

def test_if_output_selected_model_is_correctly_returned(self, key, model, selector_template):
output_selected_model = selector_template().get_output_selected_model(model, key)
assert output_selected_model == "constrained_output_model"
def test_if_objective_is_correctly_returned(self, key, model, objective_template):
objective = objective_template().get_objective(model, key)
assert objective == "constrained_output_model"
20 changes: 9 additions & 11 deletions tests/unit/mixins/test_mei_template_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def test_if_model_loader_is_correctly_initialized(mei_template, trained_model_ta

class TestMake:
@pytest.fixture
def mei_template(self, mei_template, selector_table, method_table, seed_table, insert1, save, model_loader_class):
mei_template.selector_table = selector_table
def mei_template(self, mei_template, objective_table, method_table, seed_table, insert1, save, model_loader_class):
mei_template.objective_table = objective_table
mei_template.method_table = method_table
mei_template.seed_table = seed_table
mei_template.insert1 = insert1
Expand All @@ -53,10 +53,10 @@ def mei_template(self, mei_template, selector_table, method_table, seed_table, i
return mei_template

@pytest.fixture
def selector_table(self):
selector_table = MagicMock(name="selector_table")
selector_table.return_value.get_output_selected_model.return_value = "output_selected_model"
return selector_table
def objective_table(self):
objective_table = MagicMock(name="objective_table")
objective_table.return_value.get_objective.return_value = "objective"
return objective_table

@pytest.fixture
def method_table(self):
Expand All @@ -78,9 +78,9 @@ def test_if_model_is_correctly_loaded(self, key, mei_template, model_loader):
mei_template().make(key)
model_loader.load.assert_called_once_with(key=key)

def test_if_correct_model_output_is_selected(self, key, mei_template, selector_table):
def test_if_get_objective_is_correctly_called(self, key, mei_template, objective_table):
mei_template().make(key)
selector_table.return_value.get_output_selected_model.assert_called_once_with("model", key)
objective_table.return_value.get_objective.assert_called_once_with("model", key)

def test_if_seed_is_correctly_fetched(self, key, mei_template, seed_table):
mei_template().make(key)
Expand All @@ -89,9 +89,7 @@ def test_if_seed_is_correctly_fetched(self, key, mei_template, seed_table):

def test_if_mei_is_correctly_generated(self, key, mei_template, method_table):
mei_template().make(key)
method_table.return_value.generate_mei.assert_called_once_with(
"dataloaders", "output_selected_model", key, "seed"
)
method_table.return_value.generate_mei.assert_called_once_with("dataloaders", "objective", key, "seed")

def test_if_mei_is_correctly_saved(self, key, mei_template, save):
mei_template().make(key)
Expand Down