diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 0000000..2dfcd2d --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,11 @@ +name: Black + +on: [push, pull_request] + +jobs: + black: + runs-on: ubuntu-18.04 + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + - uses: psf/black@master diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..1b88b02 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,24 @@ +name: Test + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + test: + runs-on: ubuntu-18.04 + steps: + - uses: actions/checkout@v2 + - uses: azure/docker-login@v1 + with: + login-server: index.docker.io + username: ${{ secrets.DOCKERIO_USERNAME }} + password: ${{ secrets.DOCKERIO_PASSWORD }} + - name: Add SSH key to enable cloning of private repos + run: mkdir .ssh && echo "${{ secrets.SSH_KEY }}" >> ".ssh/id_rsa" + - name: Build Docker image + run: docker build -t mei . + - name: Run tests + run: docker run --entrypoint pytest mei diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 483e53d..0000000 --- a/.travis.yml +++ /dev/null @@ -1,17 +0,0 @@ -language: python - -services: - - docker - -before_install: - - mkdir .ssh - - openssl aes-256-cbc -K $encrypted_6e245bcd2662_key -iv $encrypted_6e245bcd2662_iv -in id_rsa.enc -out .ssh/id_rsa -d - - echo "$DOCKER_PASSWORD" | docker login -u "$DOCKER_USERNAME" --password-stdin - - docker build -t mei . 
- -script: - - docker run --entrypoint pytest mei - -notifications: - slack: - secure: LDnpN7Nmvj8iqwD0SvqvR+ooyklWsScsVp3FrlnTugeq/gcPdkn4clFWuX9sVVoXkHAgP2oxWfCURv4F3XoijiOcmLQfYyX2E3JPUjGelUDqzNvYw5bRcQp1g7hogQIwsMXlruRc1Jk4NhFOfVpk5+hfo8vqcbMRXVuddvF3SNqLl8NC1otmsW1y+ow64BiFE/waHIvRM9sXBjs3siY17oTGcY069rWNwjMr0Zt3e+mLgzkh7VwoxXaBBaU0Indsa5wIXtUd4YXnFCOUclxuzVb1E5Msdqtt3eUxR/QUm+zDwrK4nZY+faeTvxZps2EHVgUm9bwfKPMUKuSxgTdGgPfLX4Ay5u245qhdmvO4ixrWveXo2ux4OWy+qRLnTpUjvt6Gg/KTw4HuAuYLw1qXZ5Wgz5dUyQcL1dTBESXDbgmsj+PmpCa+v5O+hlx0CuDWrtztMZ9kJfUMOFMqROQYg6HpCOrxZIQXQ8xsES5SluqLyxxPjXDdh8WMCTqZBw4qOy59iPua8ZKp+yLispU5gj2XDqD14OlVwLT4CG2J7QL9+e8KekUI4ti3AT8TJwbgpNn5V/LAuH2oDE0N0bZDvQmR4ZEKxU9hAgA3BRZ1EZ9Ona0a04Z/TDX9gwkYeufHD2jWatgaGQ5uPpLX/nHcQZglgreuZQY62uJBra5i818= diff --git a/README.md b/README.md index 6f3b2f5..01aa644 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,16 @@ # MEI -Generate most exciting images (MEIs). +![Test](https://github.com/cblessing24/mei/workflows/Test/badge.svg) +![Black](https://github.com/cblessing24/mei/workflows/Black/badge.svg) + +Generate most exciting inputs (MEIs). ## Installation ## Usage -This section describes the general usage of the MEI framework. Due to the fact that this framework uses [DataJoint](https://github.com/datajoint/datajoint-python/) and [NNFabrik](https://github.com/sinzlab/nnfabrik) general familiarity with these two packages is assumed. +This section describes the general usage of the MEI framework. Due to the fact that this framework uses +[DataJoint](https://github.com/datajoint/datajoint-python/) and [NNFabrik](https://github.com/sinzlab/nnfabrik) general +familiarity with these two packages is assumed. ### 1. Table Setup @@ -13,7 +18,10 @@ This section describes how the tables used in the MEI generation process have to #### 1.1 Trained Ensemble Model Table -This table contains ensembles of previously trained models. 
During the MEI generation process, all models in an ensemble will be given the same input and their output will be averaged. The framework provides a template class that you can use to create a trained ensemble model table by declaring a new class that inherits from the template. Afterwards you have to link up your class with your NNFabrik-style dataset and trained model tables via class attributes. +This table contains ensembles of previously trained models. During the MEI generation process, all models in an ensemble +will be given the same input and their output will be averaged. The framework provides a template class that you can use +to create a trained ensemble model table by declaring a new class that inherits from the template. Afterwards you have +to link up your class with your NNFabrik-style dataset and trained model tables via class attributes. ##### Example @@ -29,23 +37,28 @@ class TrainedEnsembleModel(TrainedEnsembleModelTemplate): -#### 1.2 Selector Table +#### 1.2 Objective Table This table has two jobs: -1. Contain information that can be used to map the ID of a real neuron to the index of the corresponding unit in the model's output -2. Provide a method that can be used to constrain a model to a single output unit +1. Provide a method that can be used to get the to-be-optimized objective. +2. Contain the information that is needed to come up with the aforementioned objective. -Note that you will have to implement your own selector table because the exact implementation is heavily dependent on the structure of the data you want to use and the architecture of your models. +Note that you will have to implement your own objective table because the exact implementation is heavily dependent on +the structure of the data you want to use and the architecture of your models. 
##### Example +The objective table implemented below contains information that can be used to map the ID of a real neuron (`neuron_id`) +to the index of its corresponding output unit (`output_unit`) in the model's output. The `get_objective` +method uses this information to constrain a given model to a single output unit and therefore to a single real neuron. + ```python from mei.modules import ConstrainedOutputModel @schema -class Selector(dj.Computed): +class Objective(dj.Computed): definition = """ -> self.dataset_table neuron_id: int @@ -58,16 +71,20 @@ class Selector(dj.Computed): def make(self, key): """Fills the table.""" - def get_output_selected_model(self, model, key): + def get_objective(self, model, key): output_unit = (self & key).fetch1("output_unit") return ConstrainedOutputModel(model, output_unit) ``` -Your implementation must provide a method called `get_output_selected_model` that has a PyTorch module (`model`) and a dictionary (`key`) as its only parameters and that must return a PyTorch module. The returned module must furthermore have its output constrained to a single unit. +Your implementation must provide a method called `get_objective` that has a PyTorch module (`model`) and a +dictionary (`key`) as its only parameters and that must return a PyTorch module representing the objective. The returned +module must itself return +a scalar value when called. #### 1.4 MEI Table -This table contains generated MEIs. You can create your own MEI table by inheriting from the provided template class. Afterwards you have to link up your table with your trained (ensemble) model and selector tables via class attributes. +This table contains generated MEIs. You can create your own MEI table by inheriting from the provided template class. +Afterwards you have to link up your table with your trained (ensemble) model and objective tables via class attributes. 
##### Example @@ -78,7 +95,7 @@ from mei.main import MEITemplate @schema class MEI(MEITemplate): trained_model_table = TrainedEnsembleModel - selector_table = Selector + objective_table = Objective ``` ### 2. Generating MEIs @@ -89,7 +106,12 @@ This section lays out the general steps one would execute when generating MEIs. Note that this step is only required if you are using the trained ensemble model table. -You can create a new ensemble model by calling the `create_ensemble` method of the trained ensemble model table with a DataJoint restriction (`key`). The passed restriction is used to restrict the trained model table and all models still present in the restricted table will be made part of the new ensemble. Note that the provided restriction must be able to restrict the dataset table down to a single entry because creating an ensemble consisting of models trained on different datasets is currently not supported. For your own reference you can also pass a comment when creating a new ensemble. +You can create a new ensemble model by calling the `create_ensemble` method of the trained ensemble model table with a +DataJoint restriction (`key`). The passed restriction is used to restrict the trained model table and all models still +present in the restricted table will be made part of the new ensemble. Note that the provided restriction must be able +to restrict the dataset table down to a single entry because creating an ensemble consisting of models trained on +different datasets is currently not supported. For your own reference you can also pass a comment when creating a new +ensemble. 
##### Example @@ -97,35 +119,51 @@ You can create a new ensemble model by calling the `create_ensemble` method of t TrainedEnsembleModel().create_ensemble(key, comment="My ensemble") ``` -#### 2.2 Populating the Selector Table +#### 2.2 Populating the Objective Table -Before generating MEIs you have to populate the selector table by either calling its `populate` method if your implementation provides it or by manually inserting entries. +Before generating MEIs you have to populate the objective table by either calling its `populate` method if your +implementation provides it or by manually inserting entries. ##### Example ```python -Selector().populate() +Objective().populate() ``` #### 2.3 Configuring the Generation Process -Each MEI is generated according to a user-configurable method. You can specify a new method by adding it to the MEI method table using its `add_method` method (see example below). This method expects the name of a method function (`method_fn`) and method configuration object (`method_config`). +Each MEI is generated according to a user-configurable method. You can specify a new method by adding it to the MEI +method table using its `add_method` method (see example below). This method expects the name of a method function +(`method_fn`) and method configuration object (`method_config`). -The function name needs to be the absolute path to a callable object. A function that can be used to generate MEIs using gradient ascent is provided with the framework and its path is `mei.methods.gradient_ascent`. +The function name needs to be the absolute path to a callable object. A function that can be used to generate MEIs using +gradient ascent is provided with the framework and its path is `mei.methods.gradient_ascent`. You can also implement +your own function and use it with the framework. Further information on how to do that can be found [here](#method). 
-The configuration object will be passed to the function by the framework and should contain information that will be used by the method function to alter its behavior. +The configuration object will be passed to the function by the framework and should contain information that will be +used by the method function to alter its behavior. -In the case of the provided function the configuration object is a dictionary. It contains information about which components to use when generating MEIs and how to configure those components. A component must be a callable object that must return another callable object when called. The configuration dictionary contains the absolute path (`"path"`) to the corresponding component and can additionally contain a set of keyword arguments (`"kwargs"`) that will be passed to the corresponding component when it is initially called. Below is a list of supported components: +In the case of the provided function the configuration object is a dictionary. It contains information about which +components to use when generating MEIs and how to configure those components. A component must be a callable object that +must return another callable object when called. The configuration dictionary contains the absolute path (`"path"`) to +the corresponding component and can additionally contain a set of keyword arguments (`"kwargs"`) that will be passed to +the corresponding component when it is initially called. Below is a list of supported components: -* `"device"`: Required, must be either `"cpu"` or `"cuda"`. The MEI will be generated on the CPU or the GPU depending on this value +* `"device"`: Required, must be either `"cpu"` or `"cuda"`. The MEI will be generated on the CPU or the GPU depending on + this value * `"initial"`: Required component used to generate an initial guess from which the MEI generation process will start -* `"optimizer"`: Required component used to optimize the input to the model and in turn generate the MEI. 
Must be a PyTorch-style optimizer class -* `"stopper"`: Required component used to determine whether or not to stop the MEI generation process in each iteration based on a user-defined condition +* `"optimizer"`: Required component used to optimize the input to the model and in turn generate the MEI. Must be a + PyTorch-style optimizer class +* `"stopper"`: Required component used to determine whether or not to stop the MEI generation process in each iteration + based on a user-defined condition * `"transform"`: Optional component used to transform the input before passing it through the model -* `"regularization"`: Optional component used to compute a regularization term from the (transformed) input that is added to the model's output before taking the optimization step +* `"regularization"`: Optional component used to compute a regularization term from the (transformed) input that is + added to the model's output before taking the optimization step * `"precondition"`: Optional component used to precondition the gradient -* `"postprocessing"`: Optional component that applies an operation to the input after each optimization step. The operation performed by this component does not influence the gradient -* `"objectives"`: Optional component that consists of a list of sub-components. Each sub-component tracks an objective over the duration of the generation process +* `"postprocessing"`: Optional component that applies an operation to the input after each optimization step. The + operation performed by this component does not influence the gradient +* `"objectives"`: Optional component that consists of a list of sub-components. Each sub-component tracks an objective + over the duration of the generation process You can completely omit optional components from the configuration dictionary if you do not want to use them. 
@@ -150,7 +188,8 @@ MEIMethod().add_method(method_fn, method_config, comment="My MEI method") #### 2.4 Specifying a Seed -Next you have to specify a seed to make the MEI generation process random but reproducible by inserting a seed into the MEI seed table. +Next you have to specify a seed to make the MEI generation process random but reproducible by inserting a seed into the +MEI seed table. ##### Example @@ -163,11 +202,16 @@ MEISeed().insert1({"mei_seed": 42}) #### 2.5 Populating the MEI Table -After configuring everything you can generate MEIs by calling the `populate` method of your `MEI` table. The table will insert one row for each generated MEI which itself can be found in the `mei` attribute. Additionally each MEI is associated with a score and an output object which can be found in the `score` and `output` attributes, respectively. +After configuring everything you can generate MEIs by calling the `populate` method of your `MEI` table. The table will +insert one row for each generated MEI which itself can be found in the `mei` attribute. Additionally each MEI is +associated with a score and an output object which can be found in the `score` and `output` attributes, respectively. -The score should express how well the generation process went but what it exactly represents is up to the used method function. In the case of the provided function it represents the final model evaluation. +The score should express how well the generation process went but what it exactly represents is up to the used method +function. In the case of the provided function it represents the final model evaluation. -The output object is an object that is returned by the method function at the end of the generation process. The included function will return a dictionary that contains the values of the objectives that were tracked during the generation process. +The output object is an object that is returned by the method function at the end of the generation process. 
+The included function will return a dictionary that contains the values of the objectives that were tracked during the +generation process. Note that the `mei` and `output` attributes are stored externally. @@ -179,31 +223,61 @@ MEI().populate() ## State -Instances of the `State` class contain information describing a particular state encountered during the optimization process. This information is used by various components in the framework. The attributes of a state instance are: +Instances of the `State` class contain information describing a particular state encountered during the optimization +process. This information is used by various components in the framework. The attributes of a state instance are: * `i_iter`: An integer representing the index of the optimization step this state corresponds to * `evaluation`: A float representing the evaluation of the model in response to the current input -* `reg_term`: A float representing the current regularization term added to the evaluation before the optimization step represented by this state was made. This value will be zero if no transformation is used -* `input_`: A tensor representing the untransformed input to the model. This will be identical to the post-processed input from the last step for all steps except the first one -* `transformed_input`: A tensor representing the transformed input to the model. This will be identical to the untransformed input if no transformation is used -* `post_processed_input`: A tensor representing the post-processed input. This will be identical to the untransformed input if no post-processing is done +* `reg_term`: A float representing the current regularization term added to the evaluation before the optimization step + represented by this state was made. This value will be zero if no transformation is used +* `input_`: A tensor representing the untransformed input to the model. 
This will be identical to the post-processed + input from the last step for all steps except the first one +* `transformed_input`: A tensor representing the transformed input to the model. This will be identical to the + untransformed input if no transformation is used +* `post_processed_input`: A tensor representing the post-processed input. This will be identical to the untransformed + input if no post-processing is done * `grad`: A tensor representing the gradient -* `preconditioned_grad` : A tensor representing the preconditioned gradient. This will be identical to the gradient if no preconditioning is done +* `preconditioned_grad` : A tensor representing the preconditioned gradient. This will be identical to the gradient if + no preconditioning is done * `stopper_output`: An object returned by the stopper component. ## Components This section describes each component type in greater detail and how you can implement your own variants. +### Method + +This component is the point of entry for the whole optimization process. + +It will be called with your NNFabrik-style dataloaders dictionary, your model, your configuration object and an +integer representing the seed. It must return the MEI, a float representing the score the MEI achieved and an output +object. The MEI and the output object must be compatible with PyTorch's `save` function. + +##### Example + +```python +def method(dataloaders, model, config, seed): + """Generates a MEI.""" + return mei, score, output +``` + +After you have implemented your method you can use it by adding it to the MEI method table as described +[here](#23-configuring-the-generation-process). + ### Stopper -This component is used to check whether or not to stop the MEI generation process based on the current state of the generation process. +This component is used to check whether or not to stop the MEI generation process based on the current state of the +generation process. 
-All custom stoppers must implement the `__call__` method and they should inherit from a abstract base class (ABC) called `OptimizationStopper`. The stopper will be called after each optimization step with the current state of the optimization process. It must return `(True, output)` if the optimization process is to be stopped and `(False, output)` otherwise. `output` can be any object or `None` but it can not be omitted. +All custom stoppers must implement the `__call__` method and they should inherit from an abstract base class (ABC) called +`OptimizationStopper`. The stopper will be called after each optimization step with the current state of the +optimization process. It must return `(True, output)` if the optimization process is to be stopped and +`(False, output)` otherwise. `output` can be any object or `None` but it cannot be omitted. ##### Example -Below is the implementation of a custom stopper that stops the MEI generation process once the model's evaluation reaches a user-specified threshold: +Below is the implementation of a custom stopper that stops the MEI generation process once the model's evaluation +reaches a user-specified threshold: ```python """Contents of mymodule.py""" @@ -230,4 +304,4 @@ method_config = { "stopper": {"path": "mymodule.EvaluationThreshold", "kwargs": {"threshold": 2.5}}, ... } -``` \ No newline at end of file +``` diff --git a/id_rsa.enc b/id_rsa.enc deleted file mode 100644 index 747a1ff..0000000 Binary files a/id_rsa.enc and /dev/null differ diff --git a/mei/main.py b/mei/main.py index 719b5f6..bace5a5 100755 --- a/mei/main.py +++ b/mei/main.py @@ -26,10 +26,10 @@ class Member(mixins.TrainedEnsembleModelTemplateMixin.Member, dj.Part): """Member table template.""" -class CSRFV1SelectorTemplate(mixins.CSRFV1SelectorTemplateMixin, dj.Computed): - """CSRF V1 selector table template. 
- To create a functional "CSRFV1Selector" table, create a new class that inherits from this template and decorate it + To create a functional "CSRFV1Objective" table, create a new class that inherits from this template and decorate it with your preferred Datajoint schema. By default, the created table will point to the "Dataset" table in the Datajoint schema called "nnfabrik.main". This behavior can be changed by overwriting the class attribute called "dataset_table". @@ -52,8 +52,8 @@ class MEITemplate(mixins.MEITemplateMixin, dj.Computed): """MEI table template. To create a functional "MEI" table, create a new class that inherits from this template and decorate it with your - preferred Datajoint schema. Next assign your trained model (or trained ensemble model) and your selector table to - the class variables called "trained_model_table" and "selector_table". By default, the created table will point to + preferred Datajoint schema. Next assign your trained model (or trained ensemble model) and your objective table to + the class variables called "trained_model_table" and "objective_table". By default, the created table will point to the "MEIMethod" table in the Datajoint schema called "nnfabrik.main". This behavior can be changed by overwriting the class attribute called "method_table". """ diff --git a/mei/mixins.py b/mei/mixins.py index e2a4111..52e4317 100755 --- a/mei/mixins.py +++ b/mei/mixins.py @@ -116,7 +116,7 @@ def _load_ensemble_model( ) -class CSRFV1SelectorTemplateMixin: +class CSRFV1ObjectiveTemplateMixin: definition = """ # contains information that can be used to map a neuron's id to its corresponding integer position in the output of # the model. 
diff --git a/mei/modules.py b/mei/modules.py index eb31d8d..45991da 100644 --- a/mei/modules.py +++ b/mei/modules.py @@ -5,6 +5,7 @@ import torch from torch import Tensor from torch.nn import Module, ModuleList +from collections.abc import Iterable class EnsembleModel(Module): @@ -66,7 +67,7 @@ def __init__( if target_fn is None: target_fn = lambda x: x self.model = model - self.constraint = constraint + self.constraint = constraint if (isinstance(constraint, Iterable) or constraint is None) else [constraint] self.forward_kwargs = forward_kwargs if forward_kwargs else dict() self.target_fn = target_fn @@ -84,7 +85,7 @@ def __call__(self, x: Tensor, *args, **kwargs) -> Tensor: output = self.model(x, *args, **self.forward_kwargs, **kwargs) return ( self.target_fn(output) - if (not self.constraint and self.constraint != 0) + if self.constraint is None or len(self.constraint) == 0 else self.target_fn(output[:, self.constraint]) ) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..55ec8d7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +line-length = 120 diff --git a/tests/unit/mixins/test_csrfv1_selector_template_mixin.py b/tests/unit/mixins/test_csrfv1_selector_template_mixin.py index 68bc8cb..87a9905 100644 --- a/tests/unit/mixins/test_csrfv1_selector_template_mixin.py +++ b/tests/unit/mixins/test_csrfv1_selector_template_mixin.py @@ -6,11 +6,11 @@ @pytest.fixture -def selector_template(dataset_table): - selector_template = mixins.CSRFV1SelectorTemplateMixin - selector_template.dataset_table = dataset_table - selector_template.dataset_fn = "dataset_fn" - return selector_template +def objective_template(dataset_table): + objective_template = mixins.CSRFV1ObjectiveTemplateMixin + objective_template.dataset_table = dataset_table + objective_template.dataset_fn = "dataset_fn" + return objective_template @pytest.fixture @@ -18,17 +18,17 @@ def dataset_table(): return MagicMock(name="dataset_table") -def 
test_if_key_source_is_correct(selector_template, dataset_table): +def test_if_key_source_is_correct(objective_template, dataset_table): dataset_table.return_value.__and__.return_value = "key_source" - assert selector_template()._key_source == "key_source" + assert objective_template()._key_source == "key_source" dataset_table.return_value.__and__.assert_called_once_with(dict(dataset_fn="dataset_fn")) class TestMake: @pytest.fixture - def selector_template(self, selector_template, dataset_table, insert): - selector_template.insert = insert - return selector_template + def objective_template(self, objective_template, dataset_table, insert): + objective_template.insert = insert + return objective_template @pytest.fixture def dataset_table(self, dataset_table): @@ -43,26 +43,26 @@ def insert(self): def get_mappings(self): return MagicMock(return_value="mappings") - def test_if_dataset_config_is_correctly_fetched(self, key, selector_template, dataset_table, get_mappings): - selector_template().make(key, get_mappings=get_mappings) + def test_if_dataset_config_is_correctly_fetched(self, key, objective_template, dataset_table, get_mappings): + objective_template().make(key, get_mappings=get_mappings) dataset_table.return_value.__and__.assert_called_once_with(key) dataset_table.return_value.__and__.return_value.fetch1.assert_called_once_with("dataset_config") - def test_if_get_mappings_is_correctly_called(self, key, selector_template, get_mappings): - selector_template().make(key, get_mappings=get_mappings) + def test_if_get_mappings_is_correctly_called(self, key, objective_template, get_mappings): + objective_template().make(key, get_mappings=get_mappings) get_mappings.assert_called_once_with("dataset_config", key) - def test_if_mappings_are_correctly_inserted(self, key, selector_template, insert, get_mappings): - selector_template().make(key, get_mappings=get_mappings) + def test_if_mappings_are_correctly_inserted(self, key, objective_template, insert, get_mappings): + 
objective_template().make(key, get_mappings=get_mappings) insert.assert_called_once_with("mappings") -class TestGetOutputSelectedModel: +class TestGetObjective: @pytest.fixture - def selector_template(self, selector_template, constrained_output_model, magic_and): - selector_template.constrained_output_model = constrained_output_model - selector_template.__and__ = magic_and - return selector_template + def objective_template(self, objective_template, constrained_output_model, magic_and): + objective_template.constrained_output_model = constrained_output_model + objective_template.__and__ = magic_and + return objective_template @pytest.fixture def constrained_output_model(self): @@ -78,19 +78,19 @@ def magic_and(self): def model(self): return MagicMock(name="model") - def test_if_neuron_position_and_session_id_are_correctly_fetched(self, key, model, selector_template, magic_and): - selector_template().get_output_selected_model(model, key) + def test_if_neuron_position_and_session_id_are_correctly_fetched(self, key, model, objective_template, magic_and): + objective_template().get_objective(model, key) magic_and.assert_called_once_with(key) magic_and.return_value.fetch1.assert_called_once_with("neuron_position", "session_id") def test_if_constrained_output_model_is_correctly_initialized( - self, key, model, selector_template, constrained_output_model + self, key, model, objective_template, constrained_output_model ): - selector_template().get_output_selected_model(model, key) + objective_template().get_objective(model, key) constrained_output_model.assert_called_once_with( model, "neuron_pos", forward_kwargs=dict(data_key="session_id") ) - def test_if_output_selected_model_is_correctly_returned(self, key, model, selector_template): - output_selected_model = selector_template().get_output_selected_model(model, key) - assert output_selected_model == "constrained_output_model" + def test_if_objective_is_correctly_returned(self, key, model, objective_template): + 
objective = objective_template().get_objective(model, key) + assert objective == "constrained_output_model" diff --git a/tests/unit/mixins/test_mei_template_mixin.py b/tests/unit/mixins/test_mei_template_mixin.py index 331a222..ee67d5d 100644 --- a/tests/unit/mixins/test_mei_template_mixin.py +++ b/tests/unit/mixins/test_mei_template_mixin.py @@ -37,8 +37,8 @@ def test_if_model_loader_is_correctly_initialized(mei_template, trained_model_ta class TestMake: @pytest.fixture - def mei_template(self, mei_template, selector_table, method_table, seed_table, insert1, save, model_loader_class): - mei_template.selector_table = selector_table + def mei_template(self, mei_template, objective_table, method_table, seed_table, insert1, save, model_loader_class): + mei_template.objective_table = objective_table mei_template.method_table = method_table mei_template.seed_table = seed_table mei_template.insert1 = insert1 @@ -53,10 +53,10 @@ def mei_template(self, mei_template, selector_table, method_table, seed_table, i return mei_template @pytest.fixture - def selector_table(self): - selector_table = MagicMock(name="selector_table") - selector_table.return_value.get_output_selected_model.return_value = "output_selected_model" - return selector_table + def objective_table(self): + objective_table = MagicMock(name="objective_table") + objective_table.return_value.get_objective.return_value = "objective" + return objective_table @pytest.fixture def method_table(self): @@ -78,9 +78,9 @@ def test_if_model_is_correctly_loaded(self, key, mei_template, model_loader): mei_template().make(key) model_loader.load.assert_called_once_with(key=key) - def test_if_correct_model_output_is_selected(self, key, mei_template, selector_table): + def test_if_get_objective_is_correctly_called(self, key, mei_template, objective_table): mei_template().make(key) - selector_table.return_value.get_output_selected_model.assert_called_once_with("model", key) + 
objective_table.return_value.get_objective.assert_called_once_with("model", key) def test_if_seed_is_correctly_fetched(self, key, mei_template, seed_table): mei_template().make(key) @@ -89,9 +89,7 @@ def test_if_seed_is_correctly_fetched(self, key, mei_template, seed_table): def test_if_mei_is_correctly_generated(self, key, mei_template, method_table): mei_template().make(key) - method_table.return_value.generate_mei.assert_called_once_with( - "dataloaders", "output_selected_model", key, "seed" - ) + method_table.return_value.generate_mei.assert_called_once_with("dataloaders", "objective", key, "seed") def test_if_mei_is_correctly_saved(self, key, mei_template, save): mei_template().make(key)