
Support higher dimension outputs in TorchMD-Net #198

Open
RaulPPelaez opened this issue Jun 30, 2023 · 14 comments

@RaulPPelaez
Collaborator

Right now the TorchMD_Net module is tied to the idea that only one thing (and maybe its derivative) is returned.

With TensorNet it might be interesting to have more outputs: say, a scalar (energies), minus its derivative (forces), and some tensor feature.

I would like to discuss how to include this functionality while leaving the current TorchMD_Net interface as unmodified as possible.

A way to do this would be (a sketch follows the list):

  • Change the return type of TorchMD_Net to a Union[Tensor, List[Tensor]]; derivative, if true, only differentiates the first tensor.
    ) -> Tuple[Tensor, Optional[Tensor]]:

    The consequence of this is a possible TorchScript nightmare.
  • Change the output models to take in a List[Tensor] for v. For instance here:
    def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
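
For concreteness, a minimal sketch of what the first option could look like; the List return, differentiating only the first entry, and the names here are the proposal, not the current API:

    from typing import List, Optional, Tuple

    import torch
    from torch import Tensor

    class TorchMD_Net(torch.nn.Module):
        # Hypothetical signature: outputs[0] is the scalar (e.g. energy) and is
        # the only tensor the model differentiates when derivative=True.
        def forward(
            self, z: Tensor, pos: Tensor, batch: Optional[Tensor] = None
        ) -> Tuple[List[Tensor], Optional[Tensor]]:
            ...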

cc @guillemsimeon

@peastman
Collaborator

This feature is needed for some things I want to work on. It's also closely related to #26.

Here's my particular situation. I want the model to predict charges for every atom. Those charges will get factored into computing a Coulomb energy term. This requires both multiple outputs and multiple loss terms. Training will likely happen in multiple stages. For example, a possible protocol would be

Stage 1. Use fixed charges in computing the Coulomb energy. There are two loss terms: one that matches the output energy to the value in the dataset, and one that matches the atomic charges to values found in the dataset. That means the two output heads are being trained independently, though of course much of the model is shared between them.

Stage 2. Use the predicted charges for computing the Coulomb energy. The same two loss terms are used. The predicted charges now affect both terms.

Stage 3. Drop the loss term for the charges, and fine tune the model based only on energy.
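
A minimal sketch of how the per-stage weighting could look; the names here (staged_loss, STAGE_WEIGHTS, and the prediction/reference arguments) are placeholders, not TorchMD-Net API, and whether the Coulomb term uses fixed or predicted charges (Stage 1 vs. 2) would be decided inside the model:

    import torch.nn.functional as F

    # (energy weight, charge weight) per stage; Stage 3 drops the charge term.
    STAGE_WEIGHTS = {1: (1.0, 1.0), 2: (1.0, 1.0), 3: (1.0, 0.0)}

    def staged_loss(stage, energy_pred, energy_ref, charge_pred, charge_ref):
        w_energy, w_charge = STAGE_WEIGHTS[stage]
        loss = w_energy * F.mse_loss(energy_pred, energy_ref)
        if w_charge > 0:
            loss = loss + w_charge * F.mse_loss(charge_pred, charge_ref)
        return loss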

Once the model is trained, we will often only be interested in the energy output, which is used for running simulations. But we might sometimes be interested in the predicted charges too. For example, they could be used for parametrizing a molecule for a conventional MD simulation.

How could this be implemented?

@giadefa
Contributor

giadefa commented Sep 1, 2023 via email

@peastman
Collaborator

peastman commented Sep 1, 2023

This is more complex to do but you could recompute the dataset every x epochs updating the priors.

The dataset isn't changing.

If I understand correctly the head for the charges would not change but the rest of the network would still learn.

It would still keep learning. The difference is that the charges would be optimized solely based on how they affect the energy accuracy, not on how well they match charges listed in the dataset.

@RaulPPelaez
Collaborator Author

The current TorchMD_Net module does not allow for something like this, but you could maybe get away with it by leveraging output modules. The TorchMD_Net module is mainly a combination of a representation model (equivariant transformer, TensorNet, ...), which outputs a tensor of shape (Natoms, hidden_channels), and an OutputModel (which we could call a "head"), which in the case of Scalar takes this to (Nbatch, 1).
See e.g. TensorNet:

self.linear = nn.Linear(3 * hidden_channels, hidden_channels, dtype=dtype)

x = self.act(self.linear((x)))
return x, None, z, pos, batch

Then TorchMD_Net sends that to an OutputModel:

x, v, z, pos, batch = self.representation_model(z, pos, batch, q=q, s=s)

Here x is (Natoms, hidden_channels)
# apply the output network
x = self.output_model.pre_reduce(x, v, z, pos, batch)

Here x is (Natoms, 1)
# aggregate atoms
x = self.output_model.reduce(x, batch)

Here x is (Nbatch, 1)
The current Scalar output model simply reduces from (Natoms, hidden_channels) to (Nbatch, 1). This is done in two steps: first pre_reduce goes from (Natoms, hidden_channels) to (Natoms, 1) with an MLP, and then reduce goes from (Natoms, 1) to (Nbatch, 1) with a scatter:
class OutputModel(nn.Module, metaclass=ABCMeta):
    def __init__(self, allow_prior_model, reduce_op):
        super(OutputModel, self).__init__()
        self.allow_prior_model = allow_prior_model
        self.reduce_op = reduce_op

    def reset_parameters(self):
        pass

    @abstractmethod
    def pre_reduce(self, x, v, z, pos, batch):
        return

    def reduce(self, x, batch):
        return scatter(x, batch, dim=0, reduce=self.reduce_op)

    def post_reduce(self, x):
        return x


class Scalar(OutputModel):
    def __init__(
        self,
        hidden_channels,
        activation="silu",
        allow_prior_model=True,
        reduce_op="sum",
        dtype=torch.float,
    ):
        super(Scalar, self).__init__(
            allow_prior_model=allow_prior_model, reduce_op=reduce_op
        )
        act_class = act_class_mapping[activation]
        self.output_network = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
            act_class(),
            nn.Linear(hidden_channels // 2, 1, dtype=dtype),
        )
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.output_network[0].weight)
        self.output_network[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.output_network[2].weight)
        self.output_network[2].bias.data.fill_(0)

    def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
        return self.output_network(x)

You could write a new output module that just makes TorchMD_Net return (Natoms, 2), interpreting the first value as energy and the second as charge:

class TwoScalar(OutputModel):
    def __init__(
        self,
        hidden_channels,
        activation="silu",
        allow_prior_model=True,
        reduce_op="sum",
        dtype=torch.float
    ):
        super(TwoScalar, self).__init__(
            allow_prior_model=allow_prior_model, reduce_op=reduce_op
        )
        act_class = act_class_mapping[activation]
        self.output_network = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
            act_class(),
            nn.Linear(hidden_channels // 2, 2, dtype=dtype),
        )

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.output_network[0].weight)
        self.output_network[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.output_network[2].weight)
        self.output_network[2].bias.data.fill_(0)
    def reduce(self, x, batch):
        return x
    def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):
        return self.output_network(x)

The limitation here would be that the Lightning module (LNNP) used in the Trainer does not let you customize the loss function, so you lose the ability to use the sum of all energies in the module for the loss. That means you would have to either modify LNNP or write your own custom training loop.

You could instead have two output modules, one that outputs batch energy (just current Scalar) and another one that outputs atom charge.

You would need to be careful then to share the weights of the representation model between the different stages of training in your protocol. You would also need to make the Datamodule aware of this by sometimes sending batch energy and sometimes atom charge in what is called "y".

OTOH, unless I am missing something, this kind of thing would be simpler to implement if we write a new module that can output a list of tensors, and we allow for customization of the loss function in the Trainer.

BTW, I agree with Gianni: Stages 1-2 train the base model and the heads, but then in Stage 3, if you freeze the charge head and keep training only the base, you invalidate the charge head. But maybe I am not getting the protocol.

@peastman
Collaborator

peastman commented Sep 4, 2023

this kind of thing would be simpler to implement if we write a new module that can output a list of tensors, and we allow for customization of the loss function in the Trainer.

That sounds to me like the right solution.

if you freeze the charge head and keep training only the base you invalidate the charge head

It wouldn't freeze anything, just change the loss function.

@peastman
Collaborator

I really want to get started on this, creating a model that predicts partial charges and computes a Coulomb energy based on them. I could probably hack something in as described above, but I think it would be easier to just implement the proper architecture to do it cleanly. Here is a proposal for how it could work.

We allow a model to have multiple output modules. Each one gets its own list of priors. That means these lines turn into a loop over output modules:

# apply the output network
x = self.output_model.pre_reduce(x, v, z, pos, batch)
# scale by data standard deviation
if self.std is not None:
    x = x * self.std
# apply atom-wise prior model
if self.prior_model is not None:
    for prior in self.prior_model:
        x = prior.pre_reduce(x, z, pos, batch, extra_args)
# aggregate atoms
x = self.output_model.reduce(x, batch)
# shift by data mean
if self.mean is not None:
    x = x + self.mean
# apply output model after reduction
y = self.output_model.post_reduce(x)
# apply molecular-wise prior model
if self.prior_model is not None:
    for prior in self.prior_model:
        y = prior.post_reduce(y, z, pos, batch, extra_args)

Each output module eventually outputs a scalar, the same as now. All of them get summed together. Effectively each one computes a contribution to the energy, and the total energy is their sum.

In addition, we allow output modules and priors to define their own loss functions. Each pre_reduce() and post_reduce() method can return a scalar representing a contribution to the loss. The optimization is done on the sum of all losses, the one computed in LNNP._compute_losses() and any others computed along the way by output modules or priors.
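
A rough sketch of that loop; the two-value return from post_reduce and the _apply_output_models helper name are hypothetical, illustrating the proposal rather than the current OutputModel API:

    def _apply_output_models(self, x, v, z, pos, batch):
        # Every output module contributes an energy term, and may also
        # report an extra loss contribution (hypothetical extension).
        total_y = None
        extra_losses = []
        for output_model in self.output_models:
            h = output_model.pre_reduce(x, v, z, pos, batch)
            y_head = output_model.reduce(h, batch)
            y_head, head_loss = output_model.post_reduce(y_head)
            total_y = y_head if total_y is None else total_y + y_head
            if head_loss is not None:
                extra_losses.append(head_loss)
        return total_y, extra_losses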

@RaulPPelaez
Collaborator Author

Hi Peter, thanks for the input. This has been in my head for some time but I have not come up with a clean solution. If I understand your approach, allowing the output models to compute parts of the loss would require the output models to be connected to the Datamodule in order to have access to the reference values. This feels awkward if one is using the model for inference. I am not sure it is a good idea to let the model know about the training process like that.

I worked on this but eventually dropped it because I ran into several walls:
1- If the model can output several quantities, the user has to decide which of them need the derivative. I am sure someone will eventually want to go beyond the current usage of "derivative" as "minus derivative of y with respect to positions".
2- There can be per-atom and per-molecule quantities, and they could be scalar or of any other dimension. If the model returns two scalars, how does the output model know which one to interpret as energy and which as charge?
3- The only way I can think of to compute the losses for such a model in LNNP is by making the names of the return values agree between the model output and the Datamodule. If the model returns something called "q" but the Datamodule does not provide it, you cannot compute the loss. But then you might need to compute the loss in a special way for each one, which eventually leads to connecting the Datamodule and the OutputModel. This gets messy in the code fast, at least in my attempts.

We need to come up with some restrictions. For instance, simply allowing a set of predefined optional outputs (charge, spin, ...) would ease things a lot.

For the time being I can add support for another per-atom scalar, "q", for which the derivative is not needed. Maybe this is enough for your current use case?

@peastman
Collaborator

I think those issues are largely avoided by distinguishing between the output module's main output (y corresponding to energy) and any extra outputs. Every output module produces an energy contribution. LNNP._compute_losses() only looks at y and ignores all the others. The model isn't expected to calculate derivatives for any other outputs, though of course you can do it yourself at inference time by calling backward() on another output.

The Coulomb output head is still basically a way of computing an energy. It just does it in a more complicated way.

Regarding other loss terms, what about creating an abstract Loss class for computing loss terms, and allowing the model to register custom Loss subclasses?
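
Something along these lines, perhaps; the Loss/compute names are illustrative, not an existing TorchMD-Net class:

    from abc import ABC, abstractmethod

    import torch
    import torch.nn.functional as F

    class Loss(ABC):
        @abstractmethod
        def compute(self, outputs: dict, batch) -> torch.Tensor:
            """Return one scalar loss term from the model outputs and a data batch."""

    # Example subclass: an L2 term comparing predicted charges to the dataset.
    class ChargeL2(Loss):
        def compute(self, outputs, batch):
            return F.mse_loss(outputs["charge"], batch.charge)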

@RaulPPelaez
Collaborator Author

Every output module produces an energy contribution. LNNP._compute_losses() only looks at y and ignores all the others.

But you need to take every output into account for the loss, right? Say, for instance, I want a head that predicts some kind of charge besides energy/forces.

Regarding other loss terms, what about creating an abstract Loss class for computing loss terms, and allowing the model to register custom Loss subclasses?

I find it reassuring that you are coming up with the same ideas I did hehe

I like this idea of having energy and force always present and considering the rest as "extra". I think I can cook something up. Let me open a PR with a draft and we can move the discussion there.

Roughly it could go like this:

    import torch
    import torch.nn as nn

    class BaseHead(nn.Module):
        def atomwise(self, atom_features, results):
            return atom_features, results

        def moleculewise(self, molecule_features, results):
            return molecule_features, results

    # Define a Head class that modifies the "charge" entry in the dictionary.
    class ChargeHead(BaseHead):
        def atomwise(self, atom_features, results):
            # Check if "charge" is in results; if not, initialize it
            if "charge" not in results:
                results["charge"] = torch.zeros(atom_features.shape[0])  # or some other initialization
            # Modify the "charge" entry based on the features
            results["charge"] += self.some_nn(atom_features)
            return atom_features, results

    class TorchMD_Net(nn.Module):
        def __init__(self, head_list, **kwargs):
            super(TorchMD_Net, self).__init__()
            ...
            self.head_list = nn.ModuleList(head_list)

        def forward(self, z, pos, batch, **kwargs):
            results = {}  # Initialize an empty results dictionary
            atom_features = self.representation_model(z, pos, batch, **kwargs)

            # Pass the features through each head in the list
            for head in self.head_list:
                atom_features, results = head.atomwise(atom_features, results)
            molecule_features = self.reduce(atom_features)
            for head in self.head_list:
                molecule_features, results = head.moleculewise(molecule_features, results)
            return results

Each head can create new outputs or add to existing ones (for instance, priors add to the energy).
The TorchMD_Net model returns a dict with arbitrary entries; at least "energy" (or "y", as it has been called thus far) would be present.

Then the loss computation expects the dataloader to provide the same elements as the ones returned by the model:

    def _compute_losses(self, outputs, batch, loss_fn, stage):
        """
        Compute the losses for each model output.

        Args:
            outputs: Dictionary of model outputs.
            batch: Batch of data.
            loss_fn: Loss function to compute.
            stage: Current training stage.

        Returns:
            losses: Dictionary of computed losses for each model output.
        """
        losses = {}
        loss_name = loss_fn.__name__
        for key in outputs:
            if key in batch:
                loss = loss_fn(outputs[key], getattr(batch, key))
                loss = self._update_loss_with_ema(stage, key, loss_name, loss)
                losses[key] = loss
            else:
                raise ValueError(f"Reference values for '{key}' are missing in the batch")

        return losses

I guess an exception would be when the user has provided a zero weight for that value. The user would provide a set of weights, similarly to how "y_weight" and "neg_dy_weight" are used now.
Currently the total loss is computed as:

            total_loss = (
                step_losses["y"] * self.hparams.y_weight
                + step_losses["neg_dy"] * self.hparams.neg_dy_weight
            )

We could follow that line and just keep doing a weighted sum or give some more freedom to the user via some Loss class abstraction.
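
For instance, keeping the weighted-sum approach but generalizing it to arbitrary outputs might look like this; "loss_weights" is an assumed hparam mapping output names to weights, analogous to the current y_weight/neg_dy_weight:

    total_loss = sum(
        step_losses[key] * self.hparams.loss_weights.get(key, 0.0)
        for key in step_losses
    )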

@peastman
Collaborator

Yes, I think that's the idea. A few clarifications.

Are atomwise() and moleculewise() the same as pre_reduce() and post_reduce()? That could be confusing: a sample could represent a single molecule or many molecules. And the input positions won't necessarily be atoms, for example in a coarse grained model. Maybe per_point() and per_sample(), or something like that?

We need a bit more flexibility in computing the loss than what you show. You might want different loss functions for different outputs. For example, if you predict dipole moments the loss might be based on a dot product, not per-element differences. Some outputs may not directly appear in the loss function at all, and there may not be a corresponding element in the dataset. For example, if you predict charges and use them to compute an energy, you wouldn't necessarily have an explicit loss term for charges. But you still want them as an output, because the user might want to use the model to predict charges.
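
As an illustration of an output-specific loss, a dipole term could penalize magnitude and orientation separately; this exact form is made up for the example, not an established TorchMD-Net loss:

    import torch
    import torch.nn.functional as F

    def dipole_loss(pred: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        # Compare magnitudes with an L2 term and directions via cosine similarity.
        magnitude = F.mse_loss(pred.norm(dim=-1), ref.norm(dim=-1))
        orientation = (1.0 - F.cosine_similarity(pred, ref, dim=-1)).mean()
        return magnitude + orientation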

@RaulPPelaez
Collaborator Author

the input positions won't necessarily be atoms, for example in a coarse grained model. Maybe per_point() and per_sample()

Love it! "pre_reduce" and "post_reduce" do not feel right to me :P

For example, if you predict charges and use them to compute an energy, you wouldn't necessarily have an explicit loss term for charges

I get it now, makes sense.

@peastman
Collaborator

peastman commented Feb 7, 2024

I'm finally getting back to this. I started trying to implement the idea I suggested in #239 (comment), but I quickly reached two conclusions. First, there are some big questions that really need to be answered first. Second, a minimal change to the interface isn't going to get us where we want to be. Here is a concrete proposal for a new design. It makes significant changes to both the code and the configuration file format. In the long term, I think the flexibility it gains will be worth the effort. It will also hopefully allow some significant simplifications to the code.

I propose that every model can have

  • Multiple output values, such as energy and atomic charges, with different outputs possibly having different shapes.
  • Multiple output heads. Each head produces a single output value. Multiple heads may produce the same output, in which case they are added together.
  • Multiple loss terms, which all get added together (after multiplying by weights).

In the simple case of a model with a single output head, the configuration might look like

output_head: scalar
  name: energy

A more complicated model might have multiple outputs:

output_head:
  - scalar
    name: energy
  - atom_scalar
    name: charge

In this example, scalar outputs a scalar for each sample, while atom_scalar outputs a scalar for each atom of each sample.

I believe this design could possibly eliminate the need for priors. They would just be output heads.

output_head:
  - scalar
    name: energy
  - ZBL
    name: energy
    cutoff_distance: 4.0
    max_num_neighbors: 50

Because the two output heads both produce the same output (energy), they get added together. Note that priors currently use an API that in principle is more general than this: they take in the current value of y and output a new value of y. But in practice, all of them just add a new term to the current value.

Loss terms would be defined in a similar way. A simple case might compute the L2 loss, comparing the "energy" output to the "y" field of the dataset.

loss: L2
  output: energy
  dataset_field: y

A slightly more complicated one might include loss terms for both energy and force, and perhaps other values as well.

loss:
  - L2
    output: energy
    dataset_field: y
    weight: 1.0
  - L2
    output_neg_gradient: energy
    dataset_field: force
    weight: 0.1
  - L2
    output: charge
    dataset_field: charge
    weight: 0.01

Outputs and dataset fields can have arbitrary names. Each L2 loss term is calculated by comparing an output (or its gradient) and a field of the dataset.

In terms of the implementation, these are some of the main changes.

  • The output_model field of TorchMD_Net would change to output_models, a list.
  • The OutputModel class would gain a name field with the name of the value it outputs.
  • We would create a new Loss class and appropriate subclasses, at least L1 and L2 and possibly others.
  • The LNNP class would gain a field with the list of Loss objects.
  • The existing priors would be converted to OutputModels and we would get rid of all the other code related to priors.
  • TorchMD_Net.forward() would return a dict mapping output names to values, instead of a tuple of fixed values as it currently does (see the sketch after this list).
  • We could maintain compatibility with old config files by having a routine that automatically converts an old style configuration to an equivalent one in the new style.
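
A sketch of what that dict-returning forward() could look like; the output_models attribute, the name field, and the summing of same-named heads follow the proposal above and are not the current API:

    from typing import Dict, Optional

    import torch.nn as nn
    from torch import Tensor

    class TorchMD_Net(nn.Module):
        def forward(self, z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) -> Dict[str, Tensor]:
            x = self.representation_model(z, pos, batch)
            results: Dict[str, Tensor] = {}
            for head in self.output_models:
                y = head(x, batch)
                if head.name in results:
                    results[head.name] = results[head.name] + y  # same-named heads are summed
                else:
                    results[head.name] = y
            return results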

Notice that the losses are a property of the LightningModule, which is used only in training. They don't have any parameters that are saved in checkpoints. That means you can train for a while, modify your losses, and continue training from the last checkpoint.

@RaulPPelaez
Collaborator Author

I believe this is the natural evolution of this project and we should go for it.
I have some concerns about your design:

Dict order in input:

output_head:
  - scalar
    name: energy
  - ZBL
    name: energy
    cutoff_distance: 4.0
    max_num_neighbors: 50

The default YAML reader does not guarantee a particular order AFAIK, neither when reading nor when writing. Perhaps there should be another field in each output, like "after: [name of previous layer]". So:

output_head:
  - output1
    type: scalar
    field_name: energy
  - prior1
    type: ZBL
    field_name: energy
    after: output1
    cutoff_distance: 4.0
    max_num_neighbors: 50

Computing loss using the gradients

If I got it right, when a loss is computed as the gradient of the energy, it is implicitly with respect to the positions and it is always the negative gradient. I feel it would be more generic and expressive if the force were also an output model. For instance:

output_head:
  - scalar
    name: energy
  - neg_gradient
    name: force
    on: energy
    respect_to: pos

This would also have the benefit of generating a model that outputs energy and forces, instead of the user having to call energy.backward() during inference.
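
A sketch of such a head; the NegGradientHead name, the dict-based interface, and the field names mirror the YAML above and are hypothetical:

    import torch

    class NegGradientHead(torch.nn.Module):
        def __init__(self, on: str = "energy", name: str = "force"):
            super().__init__()
            self.on = on
            self.name = name

        def forward(self, results: dict, pos: torch.Tensor) -> dict:
            # pos must have requires_grad=True for this to work.
            grad = torch.autograd.grad(
                results[self.on].sum(), pos, create_graph=self.training
            )[0]
            results[self.name] = -grad
            return results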

TorchScript

I hit a wall when trying to implement something like the design you propose because jit.script did not like TorchMD_Net returning a Dict[str, Tensor]. Perhaps the situation is different now, but we have to be careful with TorchScript when dealing with this kind of dynamic code.

Backwards compatibility

With some magic, and maybe parameter-file versioning, we can translate from the current to the new interface when reading parameters, but I do not think we can ensure checkpoints are compatible. Also, TorchScript models saved with the current code would probably not be loadable with the new model.

Perhaps something can be done about it (there are already some name-changing shenanigans in load_model regarding priors), but it feels like quite an undertaking that would also become a burden on the codebase.
Additionally, it feels wrong for the documentation to explain that there are "two possible formats for the parameter files", which amounts to having two names for the same functionality.
If we agree on such a different API, IMO we might as well call it TorchMD-Net v2.0 and deprecate the current one now. Then the documentation can carry a note like "use v1.x if you have an old checkpoint".

I feel it would be easier to backport bugfixes to previous versions of TorchMD-Net than to support two wildly different APIs at the same time.

@peastman
Collaborator

peastman commented Feb 7, 2024

Perhaps there should be another field in each output like "after: [name of previous layer]".

In most cases, I don't think the order matters. In that example, there are two outputs that produce energies and get added together; the result doesn't depend on which order they're computed in. You're right that it can matter in some cases, like if you predict charges and use them to compute a Coulomb energy. Maybe a depends field to specify a dependency on another output? Something like this:

output_head:
  - atom_scalar
    name: charge
  - coulomb
    name: energy
    depends: charge

and it is always the negative gradient

Both positive and negative gradients should be supported, just with a different property name (e.g. gradient or neg_gradient).

jit.script did not like TorchMD_Net returning a Dict[str, Tensor].

I think it works OK now. AIMNet2 returns a dict, and I was able to wrap it for OpenMM-ML without any problems. openmm/openmm-ml#64

I do not think we can ensure checkpoints are compatible.

It might not be an issue. I think a checkpoint just stores the values of all model parameters. Nothing in this proposal should change the set of parameters in a model, just the code that constructs the model in the first place.
