
OCPCalculator fails on equiformer v2 #950

Open
caic99 opened this issue Dec 24, 2024 · 2 comments
Labels
bug Something isn't working

Comments


caic99 commented Dec 24, 2024

Python version

3.11.0

fairchem-core version

1.2.0

pytorch version

2.2.0

cuda version

12.1

Operating system version

Ubuntu 22.04.4 LTS

Minimal example

from fairchem.core import OCPCalculator
from ase.io import read

structure = read("COLL_train.cif")
structure.calc = OCPCalculator(
    checkpoint_path="eqV2_153M_omat_mp_salex.pt",
    cpu=False,
)
f = structure.get_forces()
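
A variant that also sets an explicit seed, which should silence the "No seed has been set" warning visible in the traceback below (this assumes OCPCalculator accepts a seed argument, as the warning text implies; treat it as a sketch, not confirmed API):

from fairchem.core import OCPCalculator
from ase.io import read

structure = read("COLL_train.cif")
structure.calc = OCPCalculator(
    checkpoint_path="eqV2_153M_omat_mp_salex.pt",
    cpu=False,
    seed=0,  # fixed seed for reproducible predictions
)
f = structure.get_forces()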

Current behavior

For all three attached structures, the error is:

RuntimeError: The expanded size of the tensor (10) must match the existing size (11) at non-singleton dimension 0. Target sizes: [10]. Tensor sizes: [11]

with each structure having a different target size; in every case the tensor size equals the target size + 1.
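
One plausible (unconfirmed) reading of that off-by-one: the scatter in rank2.py indexes per-node features with data.batch, so if a structure contains an atom that receives no edges within the model's cutoff, the per-node tensor comes out one row shorter than data.batch. A hypothetical diagnostic using ASE's neighbor list (the cutoff value here is assumed; the checkpoint's actual radius may differ):

# Check whether any atom has no neighbors within the cutoff, which would
# make the per-node tensor one row shorter than data.batch and trigger
# the expand error shown below. Purely a diagnostic sketch, not fairchem API.
import numpy as np
from ase.io import read
from ase.neighborlist import neighbor_list

structure = read("COLL_train.cif")
cutoff = 12.0  # assumed model cutoff; adjust to the checkpoint's actual value
i, j = neighbor_list("ij", structure, cutoff)
isolated = set(range(len(structure))) - set(np.concatenate([i, j]).tolist())
print("atoms with no neighbors within cutoff:", sorted(isolated))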

Traceback

In [1]: from fairchem.core import OCPCalculator
   ...: from ase.io import read, write
   ...:
   ...: structure=read("COLL_train.cif")
   ...: structure.calc=OCPCalculator(
   ...:             checkpoint_path="eqV2_153M_omat_mp_salex.pt",
   ...:             cpu=False
   ...:         )
   ...: f=structure.get_forces()
WARNING:root:No seed has been set in modelcheckpoint or OCPCalculator! Results may not be reproducible on re-run
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 9
      4 structure=read("COLL_train.cif")
      5 structure.calc=OCPCalculator(
      6             checkpoint_path="eqV2_153M_omat_mp_salex.pt",
      7             cpu=False
      8         )
----> 9 f=structure.get_forces()

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/ase/atoms.py:812, in Atoms.get_forces(self, apply_constraint, md)
    810 if self._calc is None:
    811     raise RuntimeError('Atoms object has no calculator.')
--> 812 forces = self._calc.get_forces(self)
    814 if apply_constraint:
    815     # We need a special md flag here because for MD we want
    816     # to skip real constraints but include special "constraints"
    817     # Like Hookean.
    818     for constraint in self.constraints:

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/ase/calculators/abc.py:30, in GetPropertiesMixin.get_forces(self, atoms)
     29 def get_forces(self, atoms=None):
---> 30     return self.get_property('forces', atoms)

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/ase/calculators/calculator.py:538, in BaseCalculator.get_property(self, name, atoms, allow_calculation)
    535     if self.use_cache:
    536         self.atoms = atoms.copy()
--> 538     self.calculate(atoms, [name], system_changes)
    540 if name not in self.results:
    541     # For some reason the calculator was not able to do what we want,
    542     # and that is OK.
    543     raise PropertyNotImplementedError(
    544         '{} not present in this ' 'calculation'.format(name)
    545     )

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/fairchem/core/common/relaxation/ase_utils.py:231, in OCPCalculator.calculate(self, atoms, properties, system_changes)
    228 data_object = self.a2g.convert(atoms)
    229 batch = data_list_collater([data_object], otf_graph=True)
--> 231 predictions = self.trainer.predict(batch, per_image=False, disable_tqdm=True)
    233 for key in predictions:
    234     _pred = predictions[key]

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/fairchem/core/trainers/ocp_trainer.py:452, in OCPTrainer.predict(self, data_loader, per_image, results_file, disable_tqdm)
    444 for _, batch in tqdm(
    445     enumerate(data_loader),
    446     total=len(data_loader),
   (...)
    449     disable=disable_tqdm,
    450 ):
    451     with torch.cuda.amp.autocast(enabled=self.scaler is not None):
--> 452         out = self._forward(batch)
    454     for target_key in self.config["outputs"]:
    455         pred = self._denorm_preds(target_key, out[target_key], batch)

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/fairchem/core/trainers/ocp_trainer.py:245, in OCPTrainer._forward(self, batch)
    244 def _forward(self, batch):
--> 245     out = self.model(batch.to(self.device))
    247     outputs = {}
    248     batch_size = batch.natoms.numel()

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/fairchem/core/models/base.py:317, in HydraModel.forward(self, data)
    313 with torch.autocast(
    314     device_type=self.device, enabled=self.output_heads[k].use_amp
    315 ):
    316     if self.pass_through_head_outputs:
--> 317         out.update(self.output_heads[k](data, emb))
    318     else:
    319         out[k] = self.output_heads[k](data, emb)

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py:333, in Rank2SymmetricTensorHead.forward(self, data, emb)
    330 x_edge = self.xedge_layer_norm(x_edge)
    332 if self.decompose:
--> 333     tensor_0, tensor_2 = self.block(
    334         graph.edge_distance_vec, x_edge, graph.edge_index[1], data
    335     )
    337     if self.block.extensive:  # legacy, may be interesting to try
    338         tensor_0 = tensor_0 / self.avg_num_nodes

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py:206, in Rank2DecompositionEdgeBlock.forward(self, edge_distance_vec, x_edge, edge_index, data)
    203     for module in self.irrep2_MLP:
    204         node_irrep2 = module(node_irrep2)
--> 206 scalar = scatter(
    207     node_scalar.view(-1),
    208     data.batch,
    209     dim=0,
    210     reduce="sum" if self.extensive else "mean",
    211 )
    212 irrep2 = scatter(
    213     node_irrep2.view(-1, 5),
    214     data.batch,
    215     dim=0,
    216     reduce="sum" if self.extensive else "mean",
    217 )
    219 # Note (@abhshkdz): If we have separate normalizers on the isotropic and
    220 # anisotropic components (implemented in the trainer), combining the
    221 # scalar and irrep2 predictions here would lead to the incorrect result.
    222 # Instead, we should combine the predictions after the normalizers.

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/torch_scatter/scatter.py:156, in scatter(src, index, dim, out, dim_size, reduce)
    154     return scatter_mul(src, index, dim, out, dim_size)
    155 elif reduce == 'mean':
--> 156     return scatter_mean(src, index, dim, out, dim_size)
    157 elif reduce == 'min':
    158     return scatter_min(src, index, dim, out, dim_size)[0]

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/torch_scatter/scatter.py:41, in scatter_mean(src, index, dim, out, dim_size)
     38 def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
     39                  out: Optional[torch.Tensor] = None,
     40                  dim_size: Optional[int] = None) -> torch.Tensor:
---> 41     out = scatter_sum(src, index, dim, out, dim_size)
     42     dim_size = out.size(dim)
     44     index_dim = dim

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/torch_scatter/scatter.py:11, in scatter_sum(src, index, dim, out, dim_size)
      8 def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
      9                 out: Optional[torch.Tensor] = None,
     10                 dim_size: Optional[int] = None) -> torch.Tensor:
---> 11     index = broadcast(index, src, dim)
     12     if out is None:
     13         size = list(src.size())

File /mnt/data_nas/public/Miniconda/envs/ase_interface/lib/python3.11/site-packages/torch_scatter/utils.py:12, in broadcast(src, other, dim)
     10 for _ in range(src.dim(), other.dim()):
     11     src = src.unsqueeze(-1)
---> 12 src = src.expand(other.size())
     13 return src

RuntimeError: The expanded size of the tensor (10) must match the existing size (11) at non-singleton dimension 0.  Target sizes: [10].  Tensor sizes: [11]

Expected Behavior

The same code runs without errors on structures from other datasets.

Relevant files to reproduce this bug

Since GitHub does not allow uploading .cif files directly, I've added a .txt suffix.

H_nature_2022.cif.txt
CGM_MLP_NC2023.cif.txt
COLL_train.cif.txt


torch_scatter version: 2.1.2

@caic99 caic99 added the bug Something isn't working label Dec 24, 2024
Collaborator

lbluque commented Dec 24, 2024

Hello @caic99,

You need fairchem-core >= 1.3.0 for OMat24 models. Can you please upgrade and try running your code again?
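
A minimal sanity check that the upgrade actually took effect in the active environment (a sketch using only the standard library; "fairchem-core" is the distribution name as published on PyPI):

from importlib.metadata import version

# OMat24 checkpoints need fairchem-core >= 1.3.0
print(version("fairchem-core"))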

Author

caic99 commented Dec 25, 2024

Hi @lbluque,
I've updated fairchem-core to 1.4.0. Unfortunately, the same error still occurs.
