diff --git a/docs/docs/combination.rst b/docs/docs/combination.rst index 0cd4d8a..bfd6425 100644 --- a/docs/docs/combination.rst +++ b/docs/docs/combination.rst @@ -20,7 +20,7 @@ The prediction for each pose is generated by the same single-pose model (:math:` .. math:: - \hat{y}_i = f( \mathrm{X}_i, \theta ) + \hat{y}_i = f( \text{X}_i, \theta ) and the final prediction for this compound is found by applying the combination function (:math:`h`) to this set of individual predictions: @@ -32,13 +32,13 @@ We then calculate the loss of our prediction compared to a target value .. math:: - \mathrm{loss} = L ( \hat{y}(\theta), y ) + \text{loss} = L ( \hat{y}(\theta), y ) and backprop is performed by calcuation the gradient of that loss wrt the model parameters: .. math:: - \frac{\partial \mathrm{loss}}{\partial \theta} = \frac{\partial L}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial \theta} + \frac{\partial \text{loss}}{\partial \theta} = \frac{\partial L}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial \theta} The :math:`\frac{\partial L}{\partial \hat{y}}` term can be calculated automatically using the ``pytorch.autograd`` capabilities. However, because we've decoupled the single-pose model predictions from the overall multi-pose prediction, we must manually account for the relation between the :math:`\frac{\partial \hat{y}}{\partial \theta}` term and the individual gradients that we calculated during the forward pass (:math:`\frac{\partial \hat{y}_i}{\partial \theta}`). @@ -50,3 +50,123 @@ Arbitrarily, this will be some function (:math:`g`) that depends on the individu g( \hat{y}_1, ..., \hat{y}_n, \frac{\partial \hat{y}_1}{\partial \theta}, ..., \frac{\partial \hat{y}_n}{\partial \theta} ) In practice, this function :math:`g` will need to be analytically determined and manually implemented within the ``Combination`` block (see :ref:`the guide ` for more practical information). + +.. _implemented-combs: + +Math for Implemented Combinations +---------------------------------- + +Below, we detail the math required for appropriately combining gradients. +This math is used in the ``backward`` pass in the various ``Combination`` classes. + +.. _imp-comb-loss-fn: + +Loss Functions +^^^^^^^^^^^^^^ + +We anticipate these ``Combination`` methods being used with a linear combination of two types of loss functions: + + * Loss based on the final combined prediction (ie :math:`L = f(\Delta \text{G} (\theta))`) + + * Loss based on a linear combination of the per-pose predictions (ie :math:`L = f(\Delta \text{G}_1 (\theta), \Delta \text{G}_2 (\theta), ...)`) + +Ultimately for backprop we need to return the gradients of the loss wrt each model parameter. +The gradients for each of these types of losses is given below. + +Combined Prediction +""""""""""""""""""" + +.. math:: + :label: comb-grad + + \frac{\partial L}{\partial \theta} = + \frac{\partial L}{\partial \Delta \text{G}} + \frac{\partial \Delta \text{G}}{\partial \theta} + +The :math:`\frac{\partial L}{\partial \Delta \text{G}}` part of this equation will be a scalar that is calculated automatically by ``pytorch`` and fed to our ``Combination`` class. +The :math:`\frac{\partial \Delta \text{G}}{\partial \theta}` parts will be computed internally. + +Per-Pose Prediction +""""""""""""""""""" + +Because we assume this loss is based on a linear combination of the individual :math:`\Delta \text{G}_i` predictions, we can decompose the loss as: + +.. 
math:: + :label: pose-grad + + \frac{\partial L}{\partial \theta} = + \sum_{i=1}^N + \frac{\partial L}{\partial \Delta \text{G}_i} + \frac{\partial \Delta \text{G}_i}{\partial \theta} + +As before, the :math:`\frac{\partial L}{\partial \Delta \text{G}_i}` parts of this equation will be scalars calculated automatically by ``pytorch`` and fed to our ``Combination`` class, and the :math:`\frac{\partial \Delta \text{G}}{\partial \theta}` parts will be computed internally. + +.. _mean-comb-imp: + +Mean Combination +^^^^^^^^^^^^^^^^ + +This is mostly included as an example, but it can be illustrative. + +.. math:: + :label: mean-comb-pred + + \Delta \text{G}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \Delta \text{G}_i (\theta) + +.. math:: + :label: mean-comb-grad + + \frac{\partial \Delta \text{G}(\theta)}{\partial \theta} = \frac{1}{N} \sum_{i=1}^{N} \frac{\partial \Delta \text{G}_i (\theta)}{\partial \theta} + +.. _max-comb-imp: + +Max Combination +^^^^^^^^^^^^^^^ + +This will likely be the more useful of the currently implemented ``Combination`` implementations. +In the below equations, we define the following variables: + + * :math:`n` : A sign multiplier taking the value of :math:`-1` if we are taking the min value (generally the case if the inputs are :math:`\Delta \text{G}` values) or :math:`1` if we are taking the max + * :math:`t` : A scaling value that will bring the final combined value closer to the actual value of the max/min of the input values (see `here `_ for more details). + Setting :math:`t = 1` reduces this operation to the LogSumExp operation + +.. math:: + :label: max-comb-pred + + \Delta \text{G}(\theta) = n \frac{1}{t} \text{ln} \sum_{i=1}^N \text{exp} (n t \Delta \text{G}_i (\theta)) + +We define a a constant :math:`Q` for simplicity as well as for numerical stability: + +.. math:: + :label: max-comb-q + + Q = \text{ln} \sum_{i=1}^N \text{exp} (n t \Delta \text{G}_i (\theta)) + +.. math:: + :label: max-comb-grad-initial + + \frac{\partial \Delta \text{G}(\theta)}{\partial \theta} = + n^2 + \frac{1}{\sum_{i=1}^N \text{exp} (n t \Delta \text{G}_i (\theta))} + \sum_{i=1}^N \left[ + \frac{\partial \Delta \text{G}_i (\theta)}{\partial \theta} \text{exp} (n t \Delta \text{G}_i (\theta)) + \right] + +Substituting in :math:`Q`: + +.. math:: + :label: max-comb-grad-sub + + \frac{\partial \Delta \text{G}(\theta)}{\partial \theta} = + \frac{1}{\text{exp}(Q)} + \sum_{i=1}^N \left[ + \text{exp} \left( n t \Delta \text{G}_i (\theta) \right) \frac{\partial \Delta \text{G}_i (\theta)}{\partial \theta} + \right] + +.. math:: + :label: max-comb-grad-final + + \frac{\partial \Delta \text{G}(\theta)}{\partial \theta} = + \sum_{i=1}^N \left[ + \text{exp} \left( n t \Delta \text{G}_i (\theta) - Q \right) \frac{\partial \Delta \text{G}_i (\theta)}{\partial \theta} + \right] diff --git a/mtenn/combination.py b/mtenn/combination.py index 0d07de7..e2cb8c2 100644 --- a/mtenn/combination.py +++ b/mtenn/combination.py @@ -6,6 +6,8 @@ predictions into a single multi-pose prediction. For more details on the implementation of these classes, see the :ref:`comb-docs-page` docs page and the guide on :ref:`new-combination-guide`. + +All equations referenced here correspond to those in :ref:`implemented-combs`. """ import abc @@ -25,20 +27,37 @@ class Combination(torch.nn.Module, abc.ABC): def forward(self, pred_list, grad_dict, param_names, *model_params): """ This function signature should be the same for any ``Combination`` subclass - implementation. + implementation. 
The return values should be: + + * ``torch.Tensor``: Scalar-value tensor giving the final combined prediction + + * ``torch.Tensor``: Tensor of shape ``(n_predictions,)`` giving the input + per-pose predictions. This is necessary for ``Pytorch`` to track the + gradients of these predictions in the case of eg a cross-entropy loss on the + per-pose predictions Parameters ---------- pred_list: List[torch.Tensor] - List of :math:`\mathrm{\Delta G}` predictions to be combined + List of :math:`\mathrm{\Delta G}` predictions to be combined, shape of + ``(n_predictions,)`` grad_dict: dict[str, List[torch.Tensor]] - Dict mapping from parameter name to list of gradients + Dict mapping from parameter name to list of gradients. Should contain + ``n_model_parameters`` entries, with each entry mapping to a list of + ``n_predictions`` tensors. Each of these tensors is a ``detach`` ed gradient + so the shape of each tensor will depend on the model parameter it + corresponds to, but the shapes of each tensor in any given entry should be + identical param_names: List[str] - List of parameter names - model_params: torch.Tensor + List of parameter names. Should contain ``n_model_parameters`` entries, + corresponding 1:1 with the keys in ``grad_dict`` + model_params: List[torch.Tensor] Actual parameters that we'll return the gradients for. Each param - should be passed individually (ie not as a list) for the backward pass to - work right. + should be passed directly for the backward pass to + work right. These tensors should correspond 1:1 with and should be in the + same order as the entries in ``param_names`` (ie the ``i`` th entry in + ``param_names`` should be the name of the ``i`` th model parameter in + ``model_params``) """ raise NotImplementedError("Must implement the `forward` method.") @@ -101,32 +120,12 @@ def join_grad_dict(grad_dict_keys, grad_dict_tensors): class MeanCombination(Combination): """ - Combine a list of predictions by taking the mean. - - .. math:: - - \Delta G = \\frac{1}{N} \sum_{n=1}^{N} \Delta G_n + Combine a list of predictions by taking the mean. See the docs for + :py:class:`MeanCombinationFunc ` for more + details. """ def forward(self, pred_list, grad_dict, param_names, *model_params): - """ - Wrapper around the :py:class:`MeanCombinationFunc - ` class's ``apply`` method. - - Parameters - ---------- - pred_list: List[torch.Tensor] - List of :math:`\mathrm{\Delta G}` predictions to be combined by their mean - grad_dict: dict[str, List[torch.Tensor]] - Dict mapping from parameter name to list of gradients - param_names: List[str] - List of parameter names - model_params: torch.Tensor - Actual parameters that we'll return the gradients for. Each param - should be passed individually (ie not as a list) for the backward pass to - work right. - """ - return MeanCombinationFunc.apply( pred_list, grad_dict, param_names, *model_params ) @@ -137,7 +136,13 @@ class MeanCombinationFunc(torch.autograd.Function): Custom autograd function that will handle the gradient math for us for combining :math:`\mathrm{\Delta G}` predictions to their mean. - :meta public: + .. math:: + + \Delta \\text{G}(\\theta) = \\frac{1}{N} + \\sum_{i=1}^{N} \\Delta \\text{G}_i (\\theta) + + See :ref:`mean-comb-imp` for more details on the math. 
+ """ @staticmethod @@ -148,26 +153,39 @@ def forward(pred_list, grad_dict, param_names, *model_params): Parameters ---------- pred_list: List[torch.Tensor] - List of :math:`\mathrm{\Delta G}` predictions to be combined by their mean + List of :math:`\mathrm{\Delta G}` predictions to be combined, shape of + ``(n_predictions,)`` grad_dict: dict[str, List[torch.Tensor]] - Dict mapping from parameter name to list of gradients + Dict mapping from parameter name to list of gradients. Should contain + ``n_model_parameters`` entries, with each entry mapping to a list of + ``n_predictions`` tensors. Each of these tensors is a ``detach`` ed gradient + so the shape of each tensor will depend on the model parameter it + corresponds to, but the shapes of each tensor in any given entry should be + identical param_names: List[str] - List of parameter names - model_params: torch.Tensor + List of parameter names. Should contain ``n_model_parameters`` entries, + corresponding 1:1 with the keys in ``grad_dict`` + model_params: List[torch.Tensor] Actual parameters that we'll return the gradients for. Each param - should be passed individually (ie not as a list) for the backward pass to - work right. + should be passed directly for the backward pass to + work right. These tensors should correspond 1:1 with and should be in the + same order as the entries in ``param_names`` (ie the ``i`` th entry in + ``param_names`` should be the name of the ``i`` th model parameter in + ``model_params``) Returns ------- torch.Tensor - Mean of input :math:`\mathrm{\Delta G}` predictions. + Scalar-value tensor giving the mean of the input :math:`\mathrm{\Delta G}` + predictions + torch.Tensor + Tensor of shape ``(n_predictions,)`` giving the input per-pose predictions """ # Return mean of all preds all_preds = torch.stack(pred_list).flatten() final_pred = all_preds.mean(axis=None).detach() - return final_pred + return final_pred, all_preds @staticmethod def setup_context(ctx, inputs, output): @@ -204,7 +222,7 @@ def setup_context(ctx, inputs, output): ) @staticmethod - def backward(ctx, grad_output): + def backward(ctx, comb_grad, pose_grads): """ Compute and return gradients for each parameter. 
@@ -212,27 +230,50 @@ def backward(ctx, grad_output): ---------- ctx Pytorch context manager - grad_output : torch.Tensor - Gradient of the loss wrt the prediction (from ``forward``) + comb_grad : torch.Tensor + Scalar-value tensor giving the + :math:`\\frac{\\partial L}{\\partial \\Delta \\text{G}}` term from + :eq:`comb-grad` + pose_grads : torch.Tensor + Tensor of shape ``(n_predictions,)``, giving the + :math:`\\frac{\\partial L}{\\partial \\Delta \\text{G}_i}` terms from + :eq:`pose-grad` """ # Unpack saved tensors preds, *other_tensors = ctx.saved_tensors - # Split up other_tensors + # First section of these tensors are the flattened lists of gradients from each + # individual pose or each model parameter grad_dict_tensors = other_tensors[: len(ctx.grad_dict_keys)] + # Reconstruct dict mapping from model parameter name to list of gradient tensors + # The ith entry in each list gives the gradient of the ith pose prediction wrt + # that model parameter grad_dict = Combination.join_grad_dict(ctx.grad_dict_keys, grad_dict_tensors) # Calculate final gradients for each parameter final_grads = {} for n, grad_list in grad_dict.items(): - final_grads[n] = torch.stack(grad_list, axis=-1).mean(axis=-1) - - # Adjust gradients by grad_output - for grad in final_grads.values(): - grad *= grad_output - - # Pull out return vals + # Compute the gradient contributions from any combined prediction loss, + # according to eqns (1), (4) + cur_final_grad = comb_grad * torch.stack(grad_list, axis=-1).mean(axis=-1) + + # Make sure lengths match up (should always be true but just in case) + if len(pose_grads) != len(grad_list): + raise RuntimeError("Mismatch in gradient lengths.") + + # Compute the gradient contributions from any per-pose prediction loss, + # according to eqn (2) + for pose_grad, param_grad in zip(pose_grads, grad_list): + cur_final_grad += pose_grad * param_grad + + # Store total gradient for each parameter + final_grads[n] = cur_final_grad.clone() + + # Return gradients for each of the model parameters that were passed in. Also + # need to return values for the other values that were passed to forward + # (pred_list, grad_dict, param_names), but these don't get gradients so we just + # return None return_vals = [None] * 3 + [final_grads[n] for n in ctx.param_names] return tuple(return_vals) @@ -240,54 +281,40 @@ def backward(ctx, grad_output): class MaxCombination(Combination): """ Approximate max/min of the predictions using the LogSumExp function for smoothness. - - .. math:: - - \Delta G = \\frac{-1}{t} \mathrm{ln} \sum_{n=1}^N \mathrm{exp} (-t \Delta G_n) + See the docs for :py:class:`MaxCombinationFunc + ` for more details. """ - def __init__(self, neg=True, scale=1000.0): + def __init__(self, negate_preds=True, pred_scale=1000.0): """ Parameters ---------- - neg : bool, default=True + negate_preds : bool, default=True Negate the predictions before calculating the LSE, effectively finding the min. Preds are negated again before being returned - scale : float, default=1000.0 + pred_scale : float, default=1000.0 Fixed positive value to scale predictions by before taking the LSE. 
This tightens the bounds of the LSE approximation """ super(MaxCombination, self).__init__() - self.neg = neg - self.scale = scale + self.negate_preds = negate_preds + self.pred_scale = pred_scale def __repr__(self): - return f"MaxCombination(neg={self.neg}, scale={self.scale})" + return f"MaxCombination(negate_preds={self.negate_preds}, pred_scale={self.pred_scale})" def __str__(self): return repr(self) def forward(self, pred_list, grad_dict, param_names, *model_params): - """ - Wrapper around the :py:class:`MaxCombinationFunc - ` class's ``apply`` method. - - Parameters - ---------- - pred_list: List[torch.Tensor] - List of :math:`\mathrm{\Delta G}` predictions to find the max/min of - grad_dict: dict[str, List[torch.Tensor]] - Dict mapping from parameter name to list of gradients - param_names: List[str] - List of parameter names - model_params: torch.Tensor - Actual parameters that we'll return the gradients for. Each param - should be passed individually (ie not as a list) for the backward pass to - work right. - """ return MaxCombinationFunc.apply( - self.neg, self.scale, pred_list, grad_dict, param_names, *model_params + self.negate_preds, + self.pred_scale, + pred_list, + grad_dict, + param_names, + *model_params, ) @@ -296,41 +323,99 @@ class MaxCombinationFunc(torch.autograd.Function): Custom autograd function that will handle the gradient math for us for taking the max/min of the :math:`\mathrm{\Delta G}` predictions. - :meta public: + For the ``forward`` pass, the final :math:`\mathrm{\Delta G}` prediction is + calculated according to the following: + + .. math:: + + n = \\begin{cases} + -1 & \\text{negate_preds} \\\\ + \\phantom{-}1 & \\text{not negate_preds} + \\end{cases} + + .. math:: + + t &= \\text{pred_scale} + + \Delta G &= n \\frac{1}{t} \mathrm{ln} \sum_{n=1}^N \mathrm{exp} (n t \Delta G_n) + + The logic and math behind this scaling approach are detailed `here + `_. + + See :ref:`max-comb-imp` for more details on the math. """ @staticmethod - def forward(neg, scale, pred_list, grad_dict, param_names, *model_params): + def forward( + negate_preds, pred_scale, pred_list, grad_dict, param_names, *model_params + ): """ Find the max/min of all input :math:`\mathrm{\Delta G}` predictions. Parameters ---------- - neg: bool + negate_preds: bool Negate the predictions before calculating the LSE, effectively finding the min. Preds are negated again before being returned - scale: float + pred_scale: float Fixed positive value to scale predictions by before taking the LSE. This tightens the bounds of the LSE approximation pred_list: List[torch.Tensor] - List of :math:`\mathrm{\Delta G}` predictions to be combined using LSE + List of :math:`\mathrm{\Delta G}` predictions to be combined, shape of + ``(n_predictions,)`` grad_dict: dict[str, List[torch.Tensor]] - Dict mapping from parameter name to list of gradients + Dict mapping from parameter name to list of gradients. Should contain + ``n_model_parameters`` entries, with each entry mapping to a list of + ``n_predictions`` tensors. Each of these tensors is a ``detach`` ed gradient + so the shape of each tensor will depend on the model parameter it + corresponds to, but the shapes of each tensor in any given entry should be + identical param_names: List[str] - List of parameter names - model_params: torch.Tensor + List of parameter names. Should contain ``n_model_parameters`` entries, + corresponding 1:1 with the keys in ``grad_dict`` + model_params: List[torch.Tensor] Actual parameters that we'll return the gradients for. 
Each param - should be passed individually for the backward pass to work right. + should be passed directly for the backward pass to + work right. These tensors should correspond 1:1 with and should be in the + same order as the entries in ``param_names`` (ie the ``i`` th entry in + ``param_names`` should be the name of the ``i`` th model parameter in + ``model_params``) + + Returns + ------- + torch.Tensor + Scalar-value tensor giving the max/min of the input + :math:`\mathrm{\Delta G}` predictions + torch.Tensor + Tensor of shape ``(n_predictions,)`` giving the input per-pose predictions """ - neg = (-1) ** neg - # Calculate once for reuse later + # The value of negate_preds tells us if we are finding the max or min. If True, + # we are finding the min and need to flip the sign of each individual + # prediction, as well as the final combined prediction (this is the value n + # described in the class docstring and associated implementation math section) + negative_multiplier = -1 if negate_preds else 1 + + # Combine all torch tensors so we don't need to keep doing it at each step all_preds = torch.stack(pred_list).flatten() - adj_preds = neg * scale * all_preds.detach() + + # We use adj_preds here to store the adjusted per-pose prediction values. These + # values have been negated (if we are finding the min), and multiplied by our + # scale value, if given + # These values correspond to the values inside the exponential in eqn (5) (and + # subsequent equations) + adj_preds = negative_multiplier * pred_scale * all_preds.detach() + + # Although defining this intermediate value isn't as helpful/necessary in the + # forward pass, we do so anyway for consistency with the backward pass, where + # it will be necessary for numerical stability + # This corresponds to eqn (6) Q = torch.logsumexp(adj_preds, dim=0) - # Calculate the actual prediction - final_pred = (neg * Q / scale).detach() - return final_pred + # Perform the inverse adjustments we applied to the per-pose predictions, giving + # us (approximately) the original value of the max/min per-pose prediction + final_pred = (negative_multiplier * Q / pred_scale).detach() + + return final_pred, all_preds @staticmethod def setup_context(ctx, inputs, output): @@ -344,22 +429,33 @@ def setup_context(ctx, inputs, output): inputs : List List containing all the parameters that will get passed to ``forward`` output : torch.Tensor - Value returned from ``forward`` - """ - neg, scale, pred_list, grad_dict, param_names, *model_params = inputs + Values returned from ``forward`` + """ + # Unpack the inputs + ( + negate_preds, + pred_scale, + pred_list, + grad_dict, + param_names, + *model_params, + ) = inputs + # Break the grad dict up into lists of keys and corresponding lists of gradients grad_dict_keys, grad_dict_tensors = Combination.split_grad_dict(grad_dict) # Save non-Tensors for backward - ctx.neg = neg - ctx.scale = scale + ctx.negate_preds = negate_preds + ctx.pred_scale = pred_scale ctx.grad_dict_keys = grad_dict_keys ctx.param_names = param_names # Save Tensors for backward # Saving: - # * Predictions (1 tensor) - # * Grad tensors (N params * M poses tensors) + # * Predictions (1 tensor of shape (n_predictions,)) + # * Grad tensors (N params * M poses tensors, where all gradients corresponding + # to a given model parameter are adjacent, ie first M tensors are the + # per-pose gradients for the first model parameter, etc) # * Model param tensors (N params tensors) ctx.save_for_backward( torch.stack(pred_list).flatten(), @@ -368,7 +464,7 @@ 
def setup_context(ctx, inputs, output): ) @staticmethod - def backward(ctx, grad_output): + def backward(ctx, comb_grad, pose_grads): """ Compute and return gradients for each parameter. @@ -376,28 +472,47 @@ def backward(ctx, grad_output): ---------- ctx Pytorch context manager - grad_output : torch.Tensor - Gradient of the loss wrt the prediction (from ``forward``) + comb_grad : torch.Tensor + Scalar-value tensor giving the + :math:`\\frac{\\partial L}{\\partial \\Delta \\text{G}}` term from + :eq:`comb-grad` + pose_grads : torch.Tensor + Tensor of shape ``(n_predictions,)``, giving the + :math:`\\frac{\\partial L}{\\partial \\Delta \\text{G}_i}` terms from + :eq:`pose-grad` """ # Unpack saved tensors preds, *other_tensors = ctx.saved_tensors - # Split up other_tensors + # First section of these tensors are the flattened lists of gradients from each + # individual pose or each model parameter grad_dict_tensors = other_tensors[: len(ctx.grad_dict_keys)] + # Reconstruct dict mapping from model parameter name to list of gradient tensors + # The ith entry in each list gives the gradient of the ith pose prediction wrt + # that model parameter grad_dict = Combination.join_grad_dict(ctx.grad_dict_keys, grad_dict_tensors) - # Begin calculations - neg = (-1) ** ctx.neg + # Set negation multiplier for finding max/min (see docstring and associated + # implementation math section for more details) + negative_multiplier = -1 if ctx.negate_preds else 1 + + # We use adj_preds here to store the adjusted per-pose prediction values. These + # values have been negated (if we are finding the min), and multiplied by our + # scale value, if given + # These values correspond to the values inside the exponential in eqn (5) (and + # subsequent equations) + adj_preds = negative_multiplier * ctx.pred_scale * preds.detach() - # Calculate once for reuse later - adj_preds = neg * ctx.scale * preds.detach() + # Calculate our normalizing constant (eqn (6)) Q = torch.logsumexp(adj_preds, dim=0) # Calculate final gradients for each parameter final_grads = {} for n, grad_list in grad_dict.items(): - final_grads[n] = ( + # Compute the gradient contributions from any combined prediction loss, + # according to eqns (1), (9) + cur_final_grad = comb_grad * ( torch.stack( [ grad * (pred - Q).exp() @@ -409,192 +524,21 @@ def backward(ctx, grad_output): .sum(axis=-1) ) - # Adjust gradients by grad_output - for grad in final_grads.values(): - grad *= grad_output - - # Pull out return vals - return_vals = [None] * 5 + [final_grads[n] for n in ctx.param_names] - return tuple(return_vals) - - -class BoltzmannCombination(Combination): - """ - Combine a list of :math:`\mathrm{\Delta G}` predictions according to their - Boltzmann weight. Treat energy in implicit kT units. - - .. math:: - - \Delta G &= \sum_{n=1}^{N} w_n \Delta G_n - - w_n &= \mathrm{exp} \\left[ - -\Delta G_n - \mathrm{ln} \\sum_{i=1}^N \\mathrm{exp} (-\Delta G_i ) \\right] - """ - - def forward(self, pred_list, grad_dict, param_names, *model_params): - """ - Wrapper around the :py:class:`BoltzmannCombinationFunc - ` class's ``apply`` method. 
+ # Make sure lengths match up (should always be true but just in case) + if len(pose_grads) != len(grad_list): + raise RuntimeError("Mismatch in gradient lengths.") - Parameters - ---------- - pred_list: List[torch.Tensor] - List of :math:`\mathrm{\Delta G}` predictions to combine - grad_dict: dict[str, List[torch.Tensor]] - Dict mapping from parameter name to list of gradients - param_names: List[str] - List of parameter names - model_params: torch.Tensor - Actual parameters that we'll return the gradients for. Each param - should be passed individually (ie not as a list) for the backward pass to - work right. - """ - return BoltzmannCombinationFunc.apply( - pred_list, grad_dict, param_names, *model_params - ) + # Compute the gradient contributions from any per-pose prediction loss, + # according to eqn (2) + for pose_grad, param_grad in zip(pose_grads, grad_list): + cur_final_grad += pose_grad * param_grad + # Store total gradient for each parameter + final_grads[n] = cur_final_grad.clone() -class BoltzmannCombinationFunc(torch.autograd.Function): - """ - Custom autograd function that will handle the gradient math for us for combining - :math:`\mathrm{\Delta G}` predictions by Boltzmann weighting. - - :meta public: - """ - - @staticmethod - def forward(pred_list, grad_dict, param_names, *model_params): - """ - Combine input :math:`\mathrm{\Delta G}` predictions. - - Parameters - ---------- - pred_list: List[torch.Tensor] - List of :math:`\mathrm{\Delta G}` predictions to be combined using LSE - grad_dict: dict[str, List[torch.Tensor]] - Dict mapping from parameter name to list of gradients - param_names: List[str] - List of parameter names - model_params: torch.Tensor - Actual parameters that we'll return the gradients for. Each param - should be passed individually for the backward pass to work right. - """ - # Save for later so we don't have to keep redoing this - adj_preds = -torch.stack(pred_list).flatten().detach() - - # First calculate the normalization factor - Q = torch.logsumexp(adj_preds, dim=0) - - # Calculate w - w = (adj_preds - Q).exp() - - # Calculate final pred - final_pred = torch.dot(w, -adj_preds) - - return final_pred - - @staticmethod - def setup_context(ctx, inputs, output): - """ - Store data for backward pass. - - Parameters - ---------- - ctx - Pytorch context manager - inputs : List - List containing all the parameters that will get passed to ``forward`` - output : torch.Tensor - Value returned from ``forward`` - """ - pred_list, grad_dict, param_names, *model_params = inputs - - grad_dict_keys, grad_dict_tensors = Combination.split_grad_dict(grad_dict) - - # Save non-Tensors for backward - ctx.grad_dict_keys = grad_dict_keys - ctx.param_names = param_names - - # Save Tensors for backward - # Saving: - # * Predictions (1 tensor) - # * Grad tensors (N params * M poses tensors) - # * Model param tensors (N params tensors) - ctx.save_for_backward( - torch.stack(pred_list).flatten(), - *grad_dict_tensors, - *model_params, - ) - - @staticmethod - def backward(ctx, grad_output): - """ - Compute and return gradients for each parameter. 
- - Parameters - ---------- - ctx - Pytorch context manager - grad_output : torch.Tensor - Gradient of the loss wrt the prediction (from ``forward``) - """ - # Unpack saved tensors - preds, *other_tensors = ctx.saved_tensors - - # Split up other_tensors - grad_dict_tensors = other_tensors[: len(ctx.grad_dict_keys)] - - grad_dict = Combination.join_grad_dict(ctx.grad_dict_keys, grad_dict_tensors) - - # Begin calculations - # Save for later so we don't have to keep redoing this - adj_preds = -preds.detach() - - # First calculate the normalization factor - Q = torch.logsumexp(adj_preds, dim=0) - - # Calculate w - w = (adj_preds - Q).exp() - - # Calculate dQ/d_theta - dQ = { - n: -torch.stack( - [(pred - Q).exp() * grad for pred, grad in zip(adj_preds, grad_list)], - axis=-1, - ).sum(axis=-1) - for n, grad_list in grad_dict.items() - } - - # Calculate dw/d_theta - dw = { - n: [ - (pred - Q).exp() * (-grad - dQ[n]) - for pred, grad in zip(adj_preds, grad_list) - ] - for n, grad_list in grad_dict.items() - } - - # Calculate final grads - final_grads = {} - for n, grad_list in grad_dict.items(): - final_grads[n] = ( - torch.stack( - [ - w_grad * -pred + w_val * grad - for w_grad, pred, w_val, grad in zip( - dw[n], adj_preds, w, grad_list - ) - ], - axis=-1, - ) - .detach() - .sum(axis=-1) - ) - - # Adjust gradients by grad_output - for grad in final_grads.values(): - grad *= grad_output - - # Pull out return vals - return_vals = [None] * 3 + [final_grads[n] for n in ctx.param_names] + # Return gradients for each of the model parameters that were passed in. Also + # need to return values for the other values that were passed to forward + # (negate_preds, pred_scale, pred_list, grad_dict, param_names), but these + # don't get gradients so we just return None + return_vals = [None] * 5 + [final_grads[n] for n in ctx.param_names] return tuple(return_vals) diff --git a/mtenn/config.py b/mtenn/config.py index c88eff2..a8d6eef 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -18,7 +18,9 @@ from pydantic import BaseModel, Field, root_validator import random from typing import Callable, ClassVar -import mtenn +import mtenn.combination +import mtenn.readout +import mtenn.model import numpy as np import torch @@ -126,14 +128,10 @@ class CombinationConfig(StringEnum): * mean: :py:class:`MeanCombination ` * max: :py:class:`MaxCombination ` - - * boltzmann: - :py:class:`BoltzmannCombination ` """ mean = "mean" max = "max" - boltzmann = "boltzmann" class ModelConfigBase(BaseModel): @@ -269,10 +267,8 @@ def build(self) -> mtenn.model.Model: mtenn_combination = mtenn.combination.MeanCombination() case CombinationConfig.max: mtenn_combination = mtenn.combination.MaxCombination( - neg=self.max_comb_neg, scale=self.max_comb_scale + negate_preds=self.max_comb_neg, pred_scale=self.max_comb_scale ) - case CombinationConfig.boltzmann: - mtenn_combination = mtenn.combination.BoltzmannCombination() case None: mtenn_combination = None diff --git a/mtenn/model.py b/mtenn/model.py index f5a81ba..51f99d0 100644 --- a/mtenn/model.py +++ b/mtenn/model.py @@ -260,12 +260,14 @@ def forward(self, input_list): # Separate out param names and params param_names, model_params = zip(*self.named_parameters()) - comb_pred = self.combination(pred_list, grad_dict, param_names, *model_params) + comb_pred, comb_pred_list = self.combination( + pred_list, grad_dict, param_names, *model_params + ) if self.comb_readout: - return self.comb_readout(comb_pred), pred_list + return self.comb_readout(comb_pred), comb_pred_list else: - return 
comb_pred, pred_list + return comb_pred, comb_pred_list class LigandOnlyModel(Model): diff --git a/mtenn/tests/test_combination.py b/mtenn/tests/test_combination.py index 8d7a72a..222fe9b 100644 --- a/mtenn/tests/test_combination.py +++ b/mtenn/tests/test_combination.py @@ -3,7 +3,7 @@ import pytest import torch -from mtenn.combination import MeanCombination, MaxCombination, BoltzmannCombination +from mtenn.combination import MeanCombination, MaxCombination from mtenn.conversion_utils.schnet import SchNet @@ -73,7 +73,7 @@ def test_max_combination(models_and_inputs): model_test, grouped=True, strategy="complex", - combination=MaxCombination(neg=False, scale=1.0), + combination=MaxCombination(negate_preds=False, pred_scale=1.0), ) # Test GroupedModel @@ -89,33 +89,3 @@ def test_max_combination(models_and_inputs): for n, p in model_test.named_parameters() ] ) - - -def test_boltzmann_combination(models_and_inputs): - model_test, model_ref, inp_list, target, loss_func = models_and_inputs - - # Ref calc - pred_list = torch.stack([model_ref(X)[0] for X in inp_list]) - w = torch.exp(-pred_list - torch.logsumexp(-pred_list, axis=0)) - pred_ref = torch.dot(w.flatten(), pred_list.flatten()) - loss = loss_func(pred_ref, target) - loss.backward() - - # Finish setting up GroupedModel - model_test = SchNet.get_model( - model_test, grouped=True, strategy="complex", combination=BoltzmannCombination() - ) - - # Test GroupedModel - pred_test, _ = model_test(inp_list) - loss = loss_func(pred_test, target) - loss.backward() - - # Compare - ref_param_dict = dict(model_ref.named_parameters()) - assert all( - [ - np.allclose(p.grad, ref_param_dict[n].grad, atol=5e-7) - for n, p in model_test.named_parameters() - ] - )
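
Addendum (reviewer note, not part of the patch): the gradient identity that the new ``backward`` implementations rely on for the combined-prediction loss term can be checked directly against autograd. The sketch below is a minimal, self-contained illustration assuming only ``torch``; the names ``negate_preds`` and ``pred_scale`` mirror the renamed ``MaxCombination`` arguments above, and the toy ``dG`` values are hypothetical::

    import torch

    negate_preds = True   # n = -1: smooth min, appropriate for delta-G inputs
    pred_scale = 1.0      # t: mtenn's default of 1000.0 tightens the approximation

    n = -1.0 if negate_preds else 1.0
    t = pred_scale

    # Toy per-pose delta-G predictions, standing in for the single-pose model outputs
    dG = torch.tensor([-9.2, -8.7, -10.1], requires_grad=True)

    # Forward pass, eqn (5): dG_comb = (n / t) * ln(sum_i exp(n * t * dG_i))
    Q = torch.logsumexp(n * t * dG, dim=0)  # normalizing constant, eqn (6)
    dG_comb = n * Q / t

    # Let autograd compute d(dG_comb) / d(dG_i)
    dG_comb.backward()

    # Manual weights from eqn (9): exp(n * t * dG_i - Q), i.e. a softmax over the
    # adjusted predictions, which is what MaxCombinationFunc.backward multiplies
    # each saved per-pose parameter gradient by
    manual_weights = torch.exp(n * t * dG.detach() - Q.detach())

    print(torch.allclose(dG.grad, manual_weights))  # expect True

In the actual ``backward`` passes these weights scale the saved per-pose parameter gradients and the result is multiplied by ``comb_grad`` (eqn (1)), while the ``pose_grads`` terms add the per-pose loss contributions according to eqn (2).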