Skip to content

Commit

Permalink
Remove BoltzmannCombination references.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaminow committed Jul 24, 2024
1 parent e74d59c commit 3d0de92
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 37 deletions.
6 changes: 0 additions & 6 deletions mtenn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,10 @@ class CombinationConfig(StringEnum):
* mean: :py:class:`MeanCombination <mtenn.combination.MeanCombination>`
* max: :py:class:`MaxCombination <mtenn.combination.MaxCombination>`
* boltzmann:
:py:class:`BoltzmannCombination <mtenn.combination.BoltzmannCombination>`
"""

mean = "mean"
max = "max"
boltzmann = "boltzmann"


class ModelConfigBase(BaseModel):
Expand Down Expand Up @@ -273,8 +269,6 @@ def build(self) -> mtenn.model.Model:
mtenn_combination = mtenn.combination.MaxCombination(
negate_preds=self.max_comb_neg, pred_scale=self.max_comb_scale
)
case CombinationConfig.boltzmann:
mtenn_combination = mtenn.combination.BoltzmannCombination()
case None:
mtenn_combination = None

Expand Down
32 changes: 1 addition & 31 deletions mtenn/tests/test_combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import torch

from mtenn.combination import MeanCombination, MaxCombination, BoltzmannCombination
from mtenn.combination import MeanCombination, MaxCombination
from mtenn.conversion_utils.schnet import SchNet


Expand Down Expand Up @@ -89,33 +89,3 @@ def test_max_combination(models_and_inputs):
for n, p in model_test.named_parameters()
]
)


def test_boltzmann_combination(models_and_inputs):
model_test, model_ref, inp_list, target, loss_func = models_and_inputs

# Ref calc
pred_list = torch.stack([model_ref(X)[0] for X in inp_list])
w = torch.exp(-pred_list - torch.logsumexp(-pred_list, axis=0))
pred_ref = torch.dot(w.flatten(), pred_list.flatten())
loss = loss_func(pred_ref, target)
loss.backward()

# Finish setting up GroupedModel
model_test = SchNet.get_model(
model_test, grouped=True, strategy="complex", combination=BoltzmannCombination()
)

# Test GroupedModel
pred_test, _ = model_test(inp_list)
loss = loss_func(pred_test, target)
loss.backward()

# Compare
ref_param_dict = dict(model_ref.named_parameters())
assert all(
[
np.allclose(p.grad, ref_param_dict[n].grad, atol=5e-7)
for n, p in model_test.named_parameters()
]
)

0 comments on commit 3d0de92

Please sign in to comment.