From 2963a90a004d8c68a3e7fe79b3b2acce1a85243a Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 2 Aug 2024 12:42:53 -0600 Subject: [PATCH 01/10] rank 2 tensor head --- .../models/equiformer_v2/equiformer_v2.py | 4 +- .../prediction_heads/__init__.py | 0 .../equiformer_v2/prediction_heads/rank2.py | 338 ++++++++++++++++++ 3 files changed, 340 insertions(+), 2 deletions(-) create mode 100644 src/fairchem/core/models/equiformer_v2/prediction_heads/__init__.py create mode 100644 src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index e2625eadaf..556814c35e 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -395,7 +395,7 @@ def __init__( requires_grad=False, ) - self.apply(self._init_weights) + self.apply(self.init_weights) self.apply(self._uniform_init_rad_func_linear_weights) def _init_gp_partitions( @@ -625,7 +625,7 @@ def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec): def num_params(self): return sum(p.numel() for p in self.parameters()) - def _init_weights(self, m): + def init_weights(self, m): if isinstance(m, (torch.nn.Linear, SO3_LinearV2)): if m.bias is not None: torch.nn.init.constant_(m.bias, 0) diff --git a/src/fairchem/core/models/equiformer_v2/prediction_heads/__init__.py b/src/fairchem/core/models/equiformer_v2/prediction_heads/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py new file mode 100644 index 0000000000..d85896503b --- /dev/null +++ b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py @@ -0,0 +1,338 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import torch +from e3nn import o3 +from torch import nn +from torch_scatter import scatter + +from fairchem.core.common.registry import registry +from fairchem.core.models.base import BackboneInterface, HeadInterface +from fairchem.core.models.equiformer_v2.layer_norm import get_normalization_layer + + +class Rank2Block(nn.Module): + """ + Output block for predicting rank-2 tensors (stress, dielectric tensor). + Applies outer product between edges and computes node-wise or edge-wise MLP. + + Args: + emb_size (int): Size of edge embedding used to compute outer product + num_layers (int): Number of layers of the MLP + edge_level (bool): Apply MLP to edges' outer product + extensive (bool): Whether to sum or average the outer products + """ + + def __init__( + self, + emb_size: int, + num_layers: int = 2, + edge_level: bool = False, + extensive: bool = False, + ): + super().__init__() + + self.edge_level = edge_level + self.emb_size = emb_size + self.extensive = extensive + self.scalar_nonlinearity = nn.SiLU() + self.r2tensor_MLP = nn.Sequential() + for i in range(num_layers): + if i < num_layers - 1: + self.r2tensor_MLP.append(nn.Linear(emb_size, emb_size)) + self.r2tensor_MLP.append(self.scalar_nonlinearity) + else: + self.r2tensor_MLP.append(nn.Linear(emb_size, 1)) + + def forward(self, edge_distance_vec, x_edge, edge_index, data): + """ + Args: + edge_distance_vec (torch.Tensor): Tensor of shape (..., 3) + x_edge (torch.Tensor): Tensor of shape (..., emb_size) + edge_index (torch.Tensor): Tensor of shape (2, nEdges) + data: LMDBDataset sample + """ + + outer_product_edge = torch.bmm( + edge_distance_vec.unsqueeze(2), edge_distance_vec.unsqueeze(1) + ) + + edge_outer = ( + x_edge[:, :, None] * outer_product_edge.view(-1, 9)[:, None, :] + ) # should end up as 2400 x 128 x 9 + + # edge_outer: (nEdges, emb_size_edge, 9) + if self.edge_level: + # MLP at edge level before pooling. + edge_outer = edge_outer.transpose(1, 2) # (nEdges, 9, emb_size_edge) + edge_outer = self.r2tensor_MLP(edge_outer) # (nEdges, 9, 1) + edge_outer = edge_outer.reshape(-1, 9) # (nEdges, 9) + + node_outer = scatter(edge_outer, edge_index, dim=0, reduce="mean") + else: + # operates at edge level before mixing / MLP => mixing / MLP happens at node level + node_outer = scatter(edge_outer, edge_index, dim=0, reduce="mean") + + node_outer = node_outer.transpose(1, 2) # (natoms, 9, emb_size_edge) + node_outer = self.r2tensor_MLP(node_outer) # (natoms, 9, 1) + node_outer = node_outer.reshape(-1, 9) # (natoms, 9) + + # node_outer: nAtoms, 9 => average across all atoms at the structure level + if self.extensive: + stress = scatter(node_outer, data.batch, dim=0, reduce="sum") + else: + stress = scatter(node_outer, data.batch, dim=0, reduce="mean") + return stress + + +class Rank2DecompositionEdgeBlock(nn.Module): + """ + Output block for predicting rank-2 tensors (stress, dielectric tensor, etc). + Decomposes a rank-2 symmetric tensor into irrep degree 0 and 2. + + Args: + emb_size (int): Size of edge embedding used to compute outer product + num_layers (int): Number of layers of the MLP + edge_level (bool): Apply MLP to edges' outer product + extensive (bool): Whether to sum or average the outer products + """ + + def __init__( + self, + emb_size: int, + num_layers: int = 2, + edge_level: bool = False, + extensive: bool = False, + ): + super().__init__() + self.emb_size = emb_size + self.edge_level = edge_level + self.extensive = extensive + self.scalar_nonlinearity = nn.SiLU() + self.scalar_MLP = nn.Sequential() + self.irrep2_MLP = nn.Sequential() + for i in range(num_layers): + if i < num_layers - 1: + self.scalar_MLP.append(nn.Linear(emb_size, emb_size)) + self.irrep2_MLP.append(nn.Linear(emb_size, emb_size)) + self.scalar_MLP.append(self.scalar_nonlinearity) + self.irrep2_MLP.append(self.scalar_nonlinearity) + else: + self.scalar_MLP.append(nn.Linear(emb_size, 1)) + self.irrep2_MLP.append(nn.Linear(emb_size, 1)) + + # Change of basis obtained by stacking the C-G coefficients + self.change_mat = torch.transpose( + torch.tensor( + [ + [3 ** (-0.5), 0, 0, 0, 3 ** (-0.5), 0, 0, 0, 3 ** (-0.5)], + [0, 0, 0, 0, 0, 2 ** (-0.5), 0, -(2 ** (-0.5)), 0], + [0, 0, -(2 ** (-0.5)), 0, 0, 0, 2 ** (-0.5), 0, 0], + [0, 2 ** (-0.5), 0, -(2 ** (-0.5)), 0, 0, 0, 0, 0], + [0, 0, 0.5**0.5, 0, 0, 0, 0.5**0.5, 0, 0], + [0, 2 ** (-0.5), 0, 2 ** (-0.5), 0, 0, 0, 0, 0], + [ + -(6 ** (-0.5)), + 0, + 0, + 0, + 2 * 6 ** (-0.5), + 0, + 0, + 0, + -(6 ** (-0.5)), + ], + [0, 0, 0, 0, 0, 2 ** (-0.5), 0, 2 ** (-0.5), 0], + [-(2 ** (-0.5)), 0, 0, 0, 0, 0, 0, 0, 2 ** (-0.5)], + ] + ).detach(), + 0, + 1, + ) + + def forward(self, edge_distance_vec, x_edge, edge_index, data): + """ + Args: + edge_distance_vec (torch.Tensor): Tensor of shape (..., 3) + x_edge (torch.Tensor): Tensor of shape (..., emb_size) + edge_index (torch.Tensor): Tensor of shape (2, nEdges) + data: LMDBDataset sample + """ + # Calculate spherical harmonics of degree 2 of the points sampled + sphere_irrep2 = o3.spherical_harmonics( + 2, edge_distance_vec, True + ).detach() # (nEdges, 5) + + if self.edge_level: + # MLP at edge level before pooling. + + # Irrep 0 prediction + edge_scalar = x_edge + edge_scalar = self.scalar_MLP(edge_scalar) + + # Irrep 2 prediction + edge_irrep2 = ( + sphere_irrep2[:, :, None] * x_edge[:, None, :] + ) # (nEdges, 5, emb_size) + edge_irrep2 = self.irrep2_MLP(edge_irrep2) + + node_scalar = scatter(edge_scalar, edge_index, dim=0, reduce="mean") + node_irrep2 = scatter(edge_irrep2, edge_index, dim=0, reduce="mean") + else: + edge_irrep2 = ( + sphere_irrep2[:, :, None] * x_edge[:, None, :] + ) # (nAtoms, 5, emb_size) + + node_scalar = scatter(x_edge, edge_index, dim=0, reduce="mean") + node_irrep2 = scatter(edge_irrep2, edge_index, dim=0, reduce="mean") + + # Irrep 0 prediction + for module in self.scalar_MLP: + node_scalar = module(node_scalar) + + # Irrep 2 prediction + for module in self.irrep2_MLP: + node_irrep2 = module(node_irrep2) + + scalar = scatter( + node_scalar.view(-1), + data.batch, + dim=0, + reduce="sum" if self.extensive else "mean", + ) + irrep2 = scatter( + node_irrep2.view(-1, 5), + data.batch, + dim=0, + reduce="sum" if self.extensive else "mean", + ) + + # Note (@abhshkdz): If we have separate normalizers on the isotropic and + # anisotropic components (implemented in the trainer), combining the + # scalar and irrep2 predictions here would lead to the incorrect result. + # Instead, we should combine the predictions after the normalizers. + + return scalar.reshape(-1), irrep2 + + +@registry.register_model("rank2_symmetric_head") +class Rank2SymmetricTensorHead(nn.Module, HeadInterface): + """A rank 2 symmetric tensor prediction head. + + Attributes: + ouput_name: name of output prediction property (ie, stress) + sphharm_norm: layer normalization for spherical harmonic edge weights + xedge_layer_norm: embedding layer norm + block: rank 2 equivariant symmetric tensor block + """ + + def __init__( + self, + backbone: BackboneInterface, + output_name: str, + decompose: bool = False, + use_source_target_embedding_stress: bool = False, + extensive: bool = False, + ): + """ + Args: + backbone: Backbone model that the head is attached to + decompose: Wether to decompose the rank2 tensor into isotropic and anisotropic components + use_source_target_embedding_stress: Whether to use both source and target atom embeddings + extensive: Whether to do sum-pooling (extensive) vs mean pooling (intensive). + """ + super().__init__() + self.output_name = output_name + self.decompose = decompose + self.sphharm_norm = get_normalization_layer( + backbone.norm_type, + lmax=max(backbone.lmax_list), + num_channels=1, + ) + + if use_source_target_embedding_stress: + stress_sphere_channels = self.sphere_channels * 2 + else: + stress_sphere_channels = self.sphere_channels + + self.xedge_layer_norm = nn.LayerNorm(stress_sphere_channels) + + if decompose: + self.block = Rank2DecompositionEdgeBlock( + emb_size=stress_sphere_channels, + num_layers=2, + edge_level=self.edge_level_mlp_stress, + extensive=extensive, + ) + else: + self.block = Rank2Block( + emb_size=stress_sphere_channels, + num_layers=2, + edge_level=self.edge_level_mlp_stress, + extensive=extensive, + ) + + # initialize weights + self.block.apply(backbone.init_weights) + + def forward( + self, data: dict[str, torch.Tensor] | torch.Tensor, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """ + Args: + data: data batch + emb: dictionary with embedding object and graph data + + Returns: dict of {output property name: predicted value} + """ + node_emb, graph = emb["node_embedding"], emb["graph"] + + sphharm_weights_edge = o3.spherical_harmonics( + torch.arange(0, node_emb.lmax_list[-1] + 1).tolist(), + graph.edge_distance_vec, + False, + ).detach() + + # layer norm because sphharm_weights_edge values become large and causes infs with amp + sphharm_weights_edge = self.stress_sph_norm( + sphharm_weights_edge[:, :, None] + ).squeeze() + + if self.use_source_target_embedding_stress: + x_source = node_emb.expand_edge(graph.edge_index[0]).embedding + x_target = node_emb.expand_edge(graph.edge_index[1]).embedding + x_edge = torch.cat((x_source, x_target), dim=2) + else: + x_edge = node_emb.expand_edge(graph.edge_index[1]).embedding + + x_edge = torch.einsum("abc, ab->ac", x_edge, sphharm_weights_edge) + + # layer norm because x_edge values become large and causes infs with amp + x_edge = self.stress_xedge_layer_norm(x_edge) + + if self.decompose_stress: + tensor_0, tensor_2 = self.stress_block( + graph.edge_distance_vec, x_edge, graph.edge_index[1], data + ) + + if self.extensive: # legacy, may be interesting to try + tensor_0 = tensor_0 / self.avg_num_nodes + tensor_2 = tensor_2 / self.avg_num_nodes + + output = { + f"{self.output_name}_isotropic": tensor_0.unsqueeze(1), + f"{self.output_name}_anisotropic": tensor_2, + } + else: + stress = self.stress_block( + graph.edge_distance_vec, x_edge, graph.edge_index[1], data + ) + output = {self.output_name: stress.reshape((-1, 3))} + + return output From 91d38308a80ac1f3dc7f97645b55e0796d251d08 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 2 Aug 2024 14:01:11 -0600 Subject: [PATCH 02/10] fix rank2 head and add to e2e test --- .../equiformer_v2/prediction_heads/rank2.py | 43 ++++++++++--------- .../test_configs/test_equiformerv2_hydra.yml | 3 ++ 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py index d85896503b..aaa39ef228 100644 --- a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py +++ b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py @@ -85,10 +85,10 @@ def forward(self, edge_distance_vec, x_edge, edge_index, data): # node_outer: nAtoms, 9 => average across all atoms at the structure level if self.extensive: - stress = scatter(node_outer, data.batch, dim=0, reduce="sum") + r2_tensor = scatter(node_outer, data.batch, dim=0, reduce="sum") else: - stress = scatter(node_outer, data.batch, dim=0, reduce="mean") - return stress + r2_tensor = scatter(node_outer, data.batch, dim=0, reduce="mean") + return r2_tensor class Rank2DecompositionEdgeBlock(nn.Module): @@ -237,44 +237,47 @@ def __init__( backbone: BackboneInterface, output_name: str, decompose: bool = False, - use_source_target_embedding_stress: bool = False, + edge_level_mlp: bool = False, + use_source_target_embedding: bool = False, extensive: bool = False, ): """ Args: backbone: Backbone model that the head is attached to decompose: Wether to decompose the rank2 tensor into isotropic and anisotropic components - use_source_target_embedding_stress: Whether to use both source and target atom embeddings + use_source_target_embedding: Whether to use both source and target atom embeddings extensive: Whether to do sum-pooling (extensive) vs mean pooling (intensive). """ super().__init__() self.output_name = output_name self.decompose = decompose + self.use_source_target_embedding = use_source_target_embedding + self.sphharm_norm = get_normalization_layer( backbone.norm_type, lmax=max(backbone.lmax_list), num_channels=1, ) - if use_source_target_embedding_stress: - stress_sphere_channels = self.sphere_channels * 2 + if use_source_target_embedding: + r2_tensor_sphere_channels = backbone.sphere_channels * 2 else: - stress_sphere_channels = self.sphere_channels + r2_tensor_sphere_channels = backbone.sphere_channels - self.xedge_layer_norm = nn.LayerNorm(stress_sphere_channels) + self.xedge_layer_norm = nn.LayerNorm(r2_tensor_sphere_channels) if decompose: self.block = Rank2DecompositionEdgeBlock( - emb_size=stress_sphere_channels, + emb_size=r2_tensor_sphere_channels, num_layers=2, - edge_level=self.edge_level_mlp_stress, + edge_level=edge_level_mlp, extensive=extensive, ) else: self.block = Rank2Block( - emb_size=stress_sphere_channels, + emb_size=r2_tensor_sphere_channels, num_layers=2, - edge_level=self.edge_level_mlp_stress, + edge_level=edge_level_mlp, extensive=extensive, ) @@ -300,11 +303,11 @@ def forward( ).detach() # layer norm because sphharm_weights_edge values become large and causes infs with amp - sphharm_weights_edge = self.stress_sph_norm( + sphharm_weights_edge = self.sphharm_norm( sphharm_weights_edge[:, :, None] ).squeeze() - if self.use_source_target_embedding_stress: + if self.use_source_target_embedding: x_source = node_emb.expand_edge(graph.edge_index[0]).embedding x_target = node_emb.expand_edge(graph.edge_index[1]).embedding x_edge = torch.cat((x_source, x_target), dim=2) @@ -314,10 +317,10 @@ def forward( x_edge = torch.einsum("abc, ab->ac", x_edge, sphharm_weights_edge) # layer norm because x_edge values become large and causes infs with amp - x_edge = self.stress_xedge_layer_norm(x_edge) + x_edge = self.xedge_layer_norm(x_edge) - if self.decompose_stress: - tensor_0, tensor_2 = self.stress_block( + if self.decompose: + tensor_0, tensor_2 = self.block( graph.edge_distance_vec, x_edge, graph.edge_index[1], data ) @@ -330,9 +333,9 @@ def forward( f"{self.output_name}_anisotropic": tensor_2, } else: - stress = self.stress_block( + out_tensor = self.block( graph.edge_distance_vec, x_edge, graph.edge_index[1], data ) - output = {self.output_name: stress.reshape((-1, 3))} + output = {self.output_name: out_tensor.reshape((-1, 3))} return output diff --git a/tests/core/models/test_configs/test_equiformerv2_hydra.yml b/tests/core/models/test_configs/test_equiformerv2_hydra.yml index 4c00fe6a2e..dbe3e4899d 100644 --- a/tests/core/models/test_configs/test_equiformerv2_hydra.yml +++ b/tests/core/models/test_configs/test_equiformerv2_hydra.yml @@ -52,6 +52,9 @@ model: module: equiformer_v2_energy_head forces: module: equiformer_v2_force_head + stress: + module: rank2_symmetric_head + output_name: "stress" dataset: train: From 254ea9c6b642676f4a0385fb3add306eb3a0ead0 Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 8 Aug 2024 12:20:18 -0700 Subject: [PATCH 03/10] small fixes --- src/fairchem/core/models/base.py | 2 +- src/fairchem/core/models/equiformer_v2/equiformer_v2.py | 4 ++-- .../core/models/equiformer_v2/prediction_heads/rank2.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index 8ce8f3fcb1..4936c725fa 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -228,7 +228,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: @registry.register_model("hydra") -class HydraModel(nn.Module, GraphModelMixin): +class HydraModel(nn.Module): def __init__( self, backbone: dict, diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index 37b5c72c07..b1c7214fab 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -836,7 +836,7 @@ def __init__(self, backbone): backbone.use_grid_mlp, backbone.use_sep_s2_act, ) - self.apply(backbone._init_weights) + self.apply(backbone.init_weights) self.apply(backbone._uniform_init_rad_func_linear_weights) def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]): @@ -881,7 +881,7 @@ def __init__(self, backbone): backbone.use_sep_s2_act, alpha_drop=0.0, ) - self.apply(backbone._init_weights) + self.apply(backbone.init_weights) self.apply(backbone._uniform_init_rad_func_linear_weights) def forward(self, data: Batch, emb: dict[str, torch.Tensor]): diff --git a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py index aaa39ef228..2c18e4afea 100644 --- a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py +++ b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py @@ -25,7 +25,7 @@ class Rank2Block(nn.Module): Args: emb_size (int): Size of edge embedding used to compute outer product num_layers (int): Number of layers of the MLP - edge_level (bool): Apply MLP to edges' outer product + edge_level (bool): If true apply MLP at edge level before pooling, otherwise use MLP at nodes after pooling extensive (bool): Whether to sum or average the outer products """ @@ -324,7 +324,7 @@ def forward( graph.edge_distance_vec, x_edge, graph.edge_index[1], data ) - if self.extensive: # legacy, may be interesting to try + if self.block.extensive: # legacy, may be interesting to try tensor_0 = tensor_0 / self.avg_num_nodes tensor_2 = tensor_2 / self.avg_num_nodes From c56aa11ff85fb82417457ba2168cda6283859006 Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 8 Aug 2024 15:44:55 -0700 Subject: [PATCH 04/10] keep hydra graphmixin --- src/fairchem/core/models/base.py | 2 +- .../core/models/equiformer_v2/prediction_heads/rank2.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index 4936c725fa..8ce8f3fcb1 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -228,7 +228,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: @registry.register_model("hydra") -class HydraModel(nn.Module): +class HydraModel(nn.Module, GraphModelMixin): def __init__( self, backbone: dict, diff --git a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py index 2c18e4afea..22fb5e8b24 100644 --- a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py +++ b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py @@ -99,7 +99,7 @@ class Rank2DecompositionEdgeBlock(nn.Module): Args: emb_size (int): Size of edge embedding used to compute outer product num_layers (int): Number of layers of the MLP - edge_level (bool): Apply MLP to edges' outer product + edge_level (bool): If true apply MLP at edge level before pooling, otherwise use MLP at nodes after pooling extensive (bool): Whether to sum or average the outer products """ @@ -244,7 +244,7 @@ def __init__( """ Args: backbone: Backbone model that the head is attached to - decompose: Wether to decompose the rank2 tensor into isotropic and anisotropic components + decompose: Whether to decompose the rank2 tensor into isotropic and anisotropic components use_source_target_embedding: Whether to use both source and target atom embeddings extensive: Whether to do sum-pooling (extensive) vs mean pooling (intensive). """ From d7a4460a76d42ccbe51127cbb84f12620d9bb1ca Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 8 Aug 2024 17:33:19 -0700 Subject: [PATCH 05/10] add rank2 head tests --- .../prediction_heads/__init__.py | 5 ++ .../equiformer_v2/prediction_heads/rank2.py | 1 + tests/core/models/test_rank2_head.py | 67 +++++++++++++++++++ 3 files changed, 73 insertions(+) create mode 100644 tests/core/models/test_rank2_head.py diff --git a/src/fairchem/core/models/equiformer_v2/prediction_heads/__init__.py b/src/fairchem/core/models/equiformer_v2/prediction_heads/__init__.py index e69de29bb2..7542c0d139 100644 --- a/src/fairchem/core/models/equiformer_v2/prediction_heads/__init__.py +++ b/src/fairchem/core/models/equiformer_v2/prediction_heads/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from .rank2 import Rank2SymmetricTensorHead + +__all__ = ["Rank2SymmetricTensorHead"] diff --git a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py index 22fb5e8b24..22941fbb8d 100644 --- a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py +++ b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py @@ -252,6 +252,7 @@ def __init__( self.output_name = output_name self.decompose = decompose self.use_source_target_embedding = use_source_target_embedding + self.avg_num_nodes = backbone.avg_num_nodes self.sphharm_norm = get_normalization_layer( backbone.norm_type, diff --git a/tests/core/models/test_rank2_head.py b/tests/core/models/test_rank2_head.py new file mode 100644 index 0000000000..d198e3403e --- /dev/null +++ b/tests/core/models/test_rank2_head.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from itertools import product + +import pytest +import torch +from ase.build import bulk + +from fairchem.core.common.utils import cg_change_mat, irreps_sum +from fairchem.core.datasets import data_list_collater +from fairchem.core.models.equiformer_v2.equiformer_v2 import EquiformerV2Backbone +from fairchem.core.models.equiformer_v2.prediction_heads import Rank2SymmetricTensorHead +from fairchem.core.preprocessing import AtomsToGraphs + + +def _reshape_tensor(out, batch_size=1): + tensor = torch.zeros((batch_size, irreps_sum(2)), requires_grad=False) + tensor[:, max(0, irreps_sum(1)) : irreps_sum(2)] = out.view(batch_size, -1) + tensor = torch.einsum("ba, cb->ca", cg_change_mat(2), tensor) + return tensor.view(3, 3) + + +@pytest.fixture(scope="session") +def batch(): + a2g = AtomsToGraphs(r_pbc=True) + return data_list_collater([a2g.convert(bulk("ZnFe", "wurtzite", a=2.0))]) + + +@pytest.mark.parametrize( + ("decompose", "edge_level_mlp", "use_source_target_embedding", "extensive"), + list(product((True, False), repeat=4)), +) +def test_rank2_head( + batch, decompose, edge_level_mlp, use_source_target_embedding, extensive +): + backbone = EquiformerV2Backbone( + num_layers=2, + sphere_channels=8, + attn_hidden_channels=8, + num_sphere_samples=8, + edge_channels=8, + ) + head = Rank2SymmetricTensorHead( + backbone=backbone, + output_name="out", + decompose=decompose, + edge_level_mlp=edge_level_mlp, + use_source_target_embedding=use_source_target_embedding, + extensive=extensive, + ) + + r2_out = head(batch, backbone(batch)) + + if decompose is True: + assert "out_isotropic" in r2_out + assert "out_anisotropic" in r2_out + # isotropic must be scalar + assert r2_out["out_isotropic"].shape[1] == 1 + tensor = _reshape_tensor(r2_out["out_isotropic"]) + # anisotropic must be traceless + assert torch.diagonal(tensor).sum().item() == pytest.approx(0.0, abs=1e-8) + else: + assert "out" in r2_out + tensor = r2_out["out"].view(3, 3) + + # all tensors must be symmetric + assert torch.allclose(tensor, tensor.transpose(0, 1)) From 70b9852d8e5367940549cd6fb3740dfcfb0b6857 Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 15 Aug 2024 11:51:33 -0700 Subject: [PATCH 06/10] test fixes --- .../equiformer_v2/prediction_heads/rank2.py | 11 +- tests/core/conftest.py | 53 +++++ tests/core/e2e/conftest.py | 197 ++++++++++++++++++ tests/core/e2e/test_s2ef.py | 192 +---------------- tests/core/e2e/test_s2efs.py | 108 ++++++++++ .../test_configs/test_equiformerv2_hydra.yml | 5 +- tests/core/modules/conftest.py | 44 +--- 7 files changed, 371 insertions(+), 239 deletions(-) create mode 100644 tests/core/e2e/conftest.py create mode 100644 tests/core/e2e/test_s2efs.py diff --git a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py index 22941fbb8d..2ee7be29d0 100644 --- a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py +++ b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py @@ -238,21 +238,26 @@ def __init__( output_name: str, decompose: bool = False, edge_level_mlp: bool = False, + num_mlp_layers: int = 2, use_source_target_embedding: bool = False, extensive: bool = False, + avg_num_nodes: int = 1.0, ): """ Args: backbone: Backbone model that the head is attached to decompose: Whether to decompose the rank2 tensor into isotropic and anisotropic components + edge_level_mlp: If true apply MLP at edge level before pooling, otherwise use MLP at nodes after pooling + num_mlp_layers: number of MLP layers use_source_target_embedding: Whether to use both source and target atom embeddings extensive: Whether to do sum-pooling (extensive) vs mean pooling (intensive). + avg_num_nodes: Used only if extensive to divide prediction by avg num nodes. """ super().__init__() self.output_name = output_name self.decompose = decompose self.use_source_target_embedding = use_source_target_embedding - self.avg_num_nodes = backbone.avg_num_nodes + self.avg_num_nodes = avg_num_nodes self.sphharm_norm = get_normalization_layer( backbone.norm_type, @@ -270,14 +275,14 @@ def __init__( if decompose: self.block = Rank2DecompositionEdgeBlock( emb_size=r2_tensor_sphere_channels, - num_layers=2, + num_layers=num_mlp_layers, edge_level=edge_level_mlp, extensive=extensive, ) else: self.block = Rank2Block( emb_size=r2_tensor_sphere_channels, - num_layers=2, + num_layers=num_mlp_layers, edge_level=edge_level_mlp, extensive=extensive, ) diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 46750f03b1..0e3606b306 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -13,12 +13,19 @@ if TYPE_CHECKING: from pathlib import Path +from itertools import product +from random import choice + import numpy as np import pytest import requests import torch +from pymatgen.core import Structure +from pymatgen.core.periodic_table import Element from syrupy.extensions.amber import AmberSnapshotExtension +from fairchem.core.datasets import AseDBDataset, LMDBDatabase + if TYPE_CHECKING: from syrupy.types import SerializableData @@ -172,3 +179,49 @@ def tutorial_dataset_path(tmp_path_factory) -> Path: tarfile.open(fileobj=response.raw, mode="r|gz").extractall(path=tmpdir) return tmpdir + + +@pytest.fixture(scope="session") +def dummy_element_refs(): + # create some dummy elemental energies from ionic radii (ignore deuterium and tritium included in pmg) + return np.concatenate( + [[0], [e.average_ionic_radius for e in Element if e.name not in ("D", "T")]] + ) + + +@pytest.fixture(scope="session") +def dummy_binary_dataset_path(tmpdir_factory, dummy_element_refs): + # a dummy dataset with binaries with energy that depends on composition only plus noise + all_binaries = list(product(list(Element), repeat=2)) + rng = np.random.default_rng(seed=0) + + tmpdir = tmpdir_factory.mktemp("dataset") + with LMDBDatabase(tmpdir / "dummy.aselmdb") as db: + for _ in range(1000): + elements = choice(all_binaries) + structure = Structure.from_prototype("cscl", species=elements, a=2.0) + energy = ( + sum(e.average_ionic_radius for e in elements) + + 0.05 * rng.random() * dummy_element_refs.mean() + ) + atoms = structure.to_ase_atoms() + db.write( + atoms, + data={ + "energy": energy, + "forces": rng.random((2, 3)), + "stress": rng.random((3, 3)), + }, + ) + + return tmpdir / "dummy.aselmdb" + + +@pytest.fixture(scope="session") +def dummy_binary_dataset(dummy_binary_dataset_path): + return AseDBDataset( + config={ + "src": str(dummy_binary_dataset_path), + "a2g_args": {"r_data_keys": ["energy", "forces", "stress"]}, + } + ) diff --git a/tests/core/e2e/conftest.py b/tests/core/e2e/conftest.py new file mode 100644 index 0000000000..1278579477 --- /dev/null +++ b/tests/core/e2e/conftest.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import collections.abc +import glob +import os +from pathlib import Path + +import pytest +import yaml +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + +from fairchem.core._cli import Runner +from fairchem.core.common.flags import flags +from fairchem.core.common.test_utils import ( + PGConfig, + init_env_rank_and_launch_test, + spawn_multi_process, +) +from fairchem.core.common.utils import build_config + + +@pytest.fixture() +def configs(): + return { + "scn": Path("tests/core/models/test_configs/test_scn.yml"), + "escn": Path("tests/core/models/test_configs/test_escn.yml"), + "escn_hydra": Path("tests/core/models/test_configs/test_escn_hydra.yml"), + "schnet": Path("tests/core/models/test_configs/test_schnet.yml"), + "gemnet_dt": Path("tests/core/models/test_configs/test_gemnet_dt.yml"), + "gemnet_dt_hydra": Path( + "tests/core/models/test_configs/test_gemnet_dt_hydra.yml" + ), + "gemnet_dt_hydra_grad": Path( + "tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml" + ), + "gemnet_oc": Path("tests/core/models/test_configs/test_gemnet_oc.yml"), + "gemnet_oc_hydra": Path( + "tests/core/models/test_configs/test_gemnet_oc_hydra.yml" + ), + "gemnet_oc_hydra_grad": Path( + "tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml" + ), + "dimenet++": Path("tests/core/models/test_configs/test_dpp.yml"), + "dimenet++_hydra": Path("tests/core/models/test_configs/test_dpp_hydra.yml"), + "painn": Path("tests/core/models/test_configs/test_painn.yml"), + "painn_hydra": Path("tests/core/models/test_configs/test_painn_hydra.yml"), + "equiformer_v2": Path("tests/core/models/test_configs/test_equiformerv2.yml"), + "equiformer_v2_hydra": Path( + "tests/core/models/test_configs/test_equiformerv2_hydra.yml" + ), + } + + +@pytest.fixture(scope="session") +def tutorial_train_src(tutorial_dataset_path): + return tutorial_dataset_path / "s2ef/train_100" + + +@pytest.fixture(scope="session") +def tutorial_val_src(tutorial_dataset_path): + return tutorial_dataset_path / "s2ef/val_20" + + +def get_tensorboard_log_files(logdir): + return glob.glob(f"{logdir}/tensorboard/*/events.out*") + + +def get_tensorboard_log_values(logdir): + tf_event_files = get_tensorboard_log_files(logdir) + assert len(tf_event_files) == 1 + tf_event_file = tf_event_files[0] + acc = EventAccumulator(tf_event_file) + acc.Reload() + return acc + + +def oc20_lmdb_train_and_val_from_paths( + train_src, val_src, test_src=None, otf_norms=False +): + datasets = {} + if train_src is not None: + datasets["train"] = { + "src": train_src, + "format": "lmdb", + "key_mapping": {"y": "energy", "force": "forces"}, + } + if otf_norms is True: + datasets["train"].update( + { + "transforms": { + "element_references": { + "fit": { + "targets": ["energy"], + "batch_size": 4, + "num_batches": 10, + "driver": "gelsd", + } + }, + "normalizer": { + "fit": { + "targets": {"energy": None, "forces": {"mean": 0.0}}, + "batch_size": 4, + "num_batches": 10, + } + }, + } + } + ) + else: + datasets["train"].update( + { + "transforms": { + "normalizer": { + "energy": { + "mean": -0.7554450631141663, + "stdev": 2.887317180633545, + }, + "forces": {"mean": 0.0, "stdev": 2.887317180633545}, + } + } + } + ) + if val_src is not None: + datasets["val"] = {"src": val_src, "format": "lmdb"} + if test_src is not None: + datasets["test"] = {"src": test_src, "format": "lmdb"} + return datasets + + +def merge_dictionary(d, u): + for k, v in u.items(): + if isinstance(v, collections.abc.Mapping): + d[k] = merge_dictionary(d.get(k, {}), v) + else: + d[k] = v + return d + + +def _run_main( + rundir, + input_yaml, + update_dict_with=None, + update_run_args_with=None, + save_checkpoint_to=None, + save_predictions_to=None, + world_size=0, +): + config_yaml = Path(rundir) / "train_and_val_on_val.yml" + + with open(input_yaml) as yaml_file: + yaml_config = yaml.safe_load(yaml_file) + if update_dict_with is not None: + yaml_config = merge_dictionary(yaml_config, update_dict_with) + yaml_config["backend"] = "gloo" + with open(str(config_yaml), "w") as yaml_file: + yaml.dump(yaml_config, yaml_file) + run_args = { + "run_dir": rundir, + "logdir": f"{rundir}/logs", + "config_yml": config_yaml, + } + if update_run_args_with is not None: + run_args.update(update_run_args_with) + + # run + parser = flags.get_parser() + args, override_args = parser.parse_known_args( + ["--mode", "train", "--seed", "100", "--config-yml", "config.yml", "--cpu"] + ) + for arg_name, arg_value in run_args.items(): + setattr(args, arg_name, arg_value) + config = build_config(args, override_args) + + if world_size > 0: + pg_config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=1, use_gp=False + ) + spawn_multi_process( + pg_config, + Runner(distributed=True), + init_env_rank_and_launch_test, + config, + ) + else: + Runner()(config) + + if save_checkpoint_to is not None: + checkpoints = glob.glob(f"{rundir}/checkpoints/*/checkpoint.pt") + assert len(checkpoints) == 1 + os.rename(checkpoints[0], save_checkpoint_to) + if save_predictions_to is not None: + predictions_filenames = glob.glob(f"{rundir}/results/*/s2ef_predictions.npz") + assert len(predictions_filenames) == 1 + os.rename(predictions_filenames[0], save_predictions_to) + return get_tensorboard_log_values( + f"{rundir}/logs", + ) diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 6c773d32eb..e398b983af 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -1,6 +1,5 @@ from __future__ import annotations -import collections.abc import glob import os import tempfile @@ -9,200 +8,13 @@ import numpy as np import numpy.testing as npt import pytest -import yaml -from tensorboard.backend.event_processing.event_accumulator import EventAccumulator +from conftest import _run_main, oc20_lmdb_train_and_val_from_paths -from fairchem.core._cli import Runner -from fairchem.core.common.flags import flags -from fairchem.core.common.test_utils import ( - PGConfig, - init_env_rank_and_launch_test, - spawn_multi_process, -) -from fairchem.core.common.utils import build_config, setup_logging +from fairchem.core.common.utils import setup_logging from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes setup_logging() - -@pytest.fixture() -def configs(): - return { - "scn": Path("tests/core/models/test_configs/test_scn.yml"), - "escn": Path("tests/core/models/test_configs/test_escn.yml"), - "escn_hydra": Path("tests/core/models/test_configs/test_escn_hydra.yml"), - "schnet": Path("tests/core/models/test_configs/test_schnet.yml"), - "gemnet_dt": Path("tests/core/models/test_configs/test_gemnet_dt.yml"), - "gemnet_dt_hydra": Path( - "tests/core/models/test_configs/test_gemnet_dt_hydra.yml" - ), - "gemnet_dt_hydra_grad": Path( - "tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml" - ), - "gemnet_oc": Path("tests/core/models/test_configs/test_gemnet_oc.yml"), - "gemnet_oc_hydra": Path( - "tests/core/models/test_configs/test_gemnet_oc_hydra.yml" - ), - "gemnet_oc_hydra_grad": Path( - "tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml" - ), - "dimenet++": Path("tests/core/models/test_configs/test_dpp.yml"), - "dimenet++_hydra": Path("tests/core/models/test_configs/test_dpp_hydra.yml"), - "painn": Path("tests/core/models/test_configs/test_painn.yml"), - "painn_hydra": Path("tests/core/models/test_configs/test_painn_hydra.yml"), - "equiformer_v2": Path("tests/core/models/test_configs/test_equiformerv2.yml"), - "equiformer_v2_hydra": Path( - "tests/core/models/test_configs/test_equiformerv2_hydra.yml" - ), - } - - -@pytest.fixture() -def tutorial_train_src(tutorial_dataset_path): - return tutorial_dataset_path / "s2ef/train_100" - - -@pytest.fixture() -def tutorial_val_src(tutorial_dataset_path): - return tutorial_dataset_path / "s2ef/val_20" - - -def oc20_lmdb_train_and_val_from_paths( - train_src, val_src, test_src=None, otf_norms=False -): - datasets = {} - if train_src is not None: - datasets["train"] = { - "src": train_src, - "format": "lmdb", - "key_mapping": {"y": "energy", "force": "forces"}, - } - if otf_norms is True: - datasets["train"].update( - { - "transforms": { - "element_references": { - "fit": { - "targets": ["energy"], - "batch_size": 4, - "num_batches": 10, - "driver": "gelsd", - } - }, - "normalizer": { - "fit": { - "targets": {"energy": None, "forces": {"mean": 0.0}}, - "batch_size": 4, - "num_batches": 10, - } - }, - } - } - ) - else: - datasets["train"].update( - { - "transforms": { - "normalizer": { - "energy": { - "mean": -0.7554450631141663, - "stdev": 2.887317180633545, - }, - "forces": {"mean": 0.0, "stdev": 2.887317180633545}, - } - } - } - ) - if val_src is not None: - datasets["val"] = {"src": val_src, "format": "lmdb"} - if test_src is not None: - datasets["test"] = {"src": test_src, "format": "lmdb"} - return datasets - - -def get_tensorboard_log_files(logdir): - return glob.glob(f"{logdir}/tensorboard/*/events.out*") - - -def get_tensorboard_log_values(logdir): - tf_event_files = get_tensorboard_log_files(logdir) - assert len(tf_event_files) == 1 - tf_event_file = tf_event_files[0] - acc = EventAccumulator(tf_event_file) - acc.Reload() - return acc - - -def merge_dictionary(d, u): - for k, v in u.items(): - if isinstance(v, collections.abc.Mapping): - d[k] = merge_dictionary(d.get(k, {}), v) - else: - d[k] = v - return d - - -def _run_main( - rundir, - input_yaml, - update_dict_with=None, - update_run_args_with=None, - save_checkpoint_to=None, - save_predictions_to=None, - world_size=0, -): - config_yaml = Path(rundir) / "train_and_val_on_val.yml" - - with open(input_yaml) as yaml_file: - yaml_config = yaml.safe_load(yaml_file) - if update_dict_with is not None: - yaml_config = merge_dictionary(yaml_config, update_dict_with) - yaml_config["backend"] = "gloo" - with open(str(config_yaml), "w") as yaml_file: - yaml.dump(yaml_config, yaml_file) - run_args = { - "run_dir": rundir, - "logdir": f"{rundir}/logs", - "config_yml": config_yaml, - } - if update_run_args_with is not None: - run_args.update(update_run_args_with) - - # run - parser = flags.get_parser() - args, override_args = parser.parse_known_args( - ["--mode", "train", "--seed", "100", "--config-yml", "config.yml", "--cpu"] - ) - for arg_name, arg_value in run_args.items(): - setattr(args, arg_name, arg_value) - config = build_config(args, override_args) - - if world_size > 0: - pg_config = PGConfig( - backend="gloo", world_size=world_size, gp_group_size=1, use_gp=False - ) - spawn_multi_process( - pg_config, - Runner(distributed=True), - init_env_rank_and_launch_test, - config, - ) - else: - Runner()(config) - - if save_checkpoint_to is not None: - checkpoints = glob.glob(f"{rundir}/checkpoints/*/checkpoint.pt") - assert len(checkpoints) == 1 - os.rename(checkpoints[0], save_checkpoint_to) - if save_predictions_to is not None: - predictions_filenames = glob.glob(f"{rundir}/results/*/s2ef_predictions.npz") - assert len(predictions_filenames) == 1 - os.rename(predictions_filenames[0], save_predictions_to) - return get_tensorboard_log_values( - f"{rundir}/logs", - ) - - """ These tests are intended to be as quick as possible and test only that the network is runnable and outputs training+validation to tensorboard output These should catch errors such as shape mismatches or otherways to code wise break a network diff --git a/tests/core/e2e/test_s2efs.py b/tests/core/e2e/test_s2efs.py new file mode 100644 index 0000000000..5032f50958 --- /dev/null +++ b/tests/core/e2e/test_s2efs.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest +from conftest import _run_main + + +# TODO add GemNet! +@pytest.mark.parametrize( + ("model_name", "ddp"), + [ + ("equiformer_v2_hydra", False), + ("escn_hydra", False), + ("equiformer_v2_hydra", True), + ("escn_hydra", True), + ], +) +def test_smoke_s2efs_predict( + model_name, ddp, configs, dummy_binary_dataset_path, tmpdir +): + # train an s2ef model just to have one + input_yaml = configs[model_name] + train_rundir = tmpdir / "train" + train_rundir.mkdir() + checkpoint_path = str(train_rundir / "checkpoint.pt") + training_predictions_filename = str(train_rundir / "train_predictions.npz") + + updates = { + "task": {"strict_load": False}, + "model": { + "backbone": {"max_num_elements": 118}, + "heads": { + "stress": { + "module": "rank2_symmetric_head", + "output_name": "stress", + "use_source_target_embedding": True, + } + }, + }, + "loss_functions": [ + {"energy": {"fn": "mae", "coefficient": 2}}, + {"forces": {"fn": "l2mae", "coefficient": 100}}, + {"stress": {"fn": "mae", "coefficient": 100}}, + ], + "outputs": {"stress": {"level": "system", "irrep_dim": 2}}, + "evaluation_metrics": {"metrics": {"stress": "mae"}}, + "dataset": { + "train": { + "src": str(dummy_binary_dataset_path), + "format": "ase_db", + "a2g_args": {"r_data_keys": ["energy", "forces", "stress"]}, + }, + "val": { + "src": str(dummy_binary_dataset_path), + "format": "ase_db", + "a2g_args": {"r_data_keys": ["energy", "forces", "stress"]}, + }, + }, + } + + acc = _run_main( + rundir=str(train_rundir), + input_yaml=input_yaml, + update_dict_with={ + "optim": { + "max_epochs": 2, + "eval_every": 4, + "batch_size": 5, + "num_workers": 0 if ddp else 2, + }, + **updates, + }, + save_checkpoint_to=checkpoint_path, + save_predictions_to=training_predictions_filename, + world_size=1 if ddp else 0, + ) + assert "train/energy_mae" in acc.Tags()["scalars"] + assert "val/energy_mae" in acc.Tags()["scalars"] + + # now load a checkpoint with an added stress head + # second load the checkpoint and predict + predictions_rundir = Path(tmpdir) / "predict" + predictions_rundir.mkdir() + predictions_filename = str(predictions_rundir / "predictions.npz") + _run_main( + rundir=str(predictions_rundir), + input_yaml=input_yaml, + update_dict_with={ + "task": {"strict_load": False}, + "optim": {"max_epochs": 2, "eval_every": 8, "batch_size": 5}, + **updates, + }, + update_run_args_with={ + "mode": "predict", + "checkpoint": checkpoint_path, + }, + save_predictions_to=predictions_filename, + ) + predictions = np.load(training_predictions_filename) + + for output in input_yaml["outputs"]: + assert output in predictions + + assert predictions["energy"].shape == (20,) + assert predictions["forces"].shape == (20, 3) + assert predictions["stress"].shape == (20, 9) diff --git a/tests/core/models/test_configs/test_equiformerv2_hydra.yml b/tests/core/models/test_configs/test_equiformerv2_hydra.yml index ece0543f16..1852799f5e 100644 --- a/tests/core/models/test_configs/test_equiformerv2_hydra.yml +++ b/tests/core/models/test_configs/test_equiformerv2_hydra.yml @@ -29,7 +29,7 @@ evaluation_metrics: misc: - energy_forces_within_threshold primary_metric: forces_mae - + logger: name: tensorboard @@ -83,9 +83,6 @@ model: module: equiformer_v2_energy_head forces: module: equiformer_v2_force_head - stress: - module: rank2_symmetric_head - output_name: "stress" optim: batch_size: 5 diff --git a/tests/core/modules/conftest.py b/tests/core/modules/conftest.py index 1b1e4ab7e6..0a210639de 100644 --- a/tests/core/modules/conftest.py +++ b/tests/core/modules/conftest.py @@ -1,48 +1,8 @@ -from itertools import product -from random import choice -import pytest -import numpy as np -from pymatgen.core.periodic_table import Element -from pymatgen.core import Structure - -from fairchem.core.datasets import LMDBDatabase, AseDBDataset - +from __future__ import annotations -@pytest.fixture(scope="session") -def dummy_element_refs(): - # create some dummy elemental energies from ionic radii (ignore deuterium and tritium included in pmg) - return np.concatenate( - [[0], [e.average_ionic_radius for e in Element if e.name not in ("D", "T")]] - ) +import pytest @pytest.fixture(scope="session") def max_num_elements(dummy_element_refs): return len(dummy_element_refs) - 1 - - -@pytest.fixture(scope="session") -def dummy_binary_dataset(tmpdir_factory, dummy_element_refs): - # a dummy dataset with binaries with energy that depends on composition only plus noise - all_binaries = list(product(list(Element), repeat=2)) - rng = np.random.default_rng(seed=0) - - tmpdir = tmpdir_factory.mktemp("dataset") - with LMDBDatabase(tmpdir / "dummy.aselmdb") as db: - for _ in range(1000): - elements = choice(all_binaries) - structure = Structure.from_prototype("cscl", species=elements, a=2.0) - energy = ( - sum(e.average_ionic_radius for e in elements) - + 0.05 * rng.random() * dummy_element_refs.mean() - ) - atoms = structure.to_ase_atoms() - db.write(atoms, data={"energy": energy, "forces": rng.random((2, 3))}) - - dataset = AseDBDataset( - config={ - "src": str(tmpdir / "dummy.aselmdb"), - "a2g_args": {"r_data_keys": ["energy", "forces"]}, - } - ) - return dataset From 37cbb270ee260d5dd47262657537f21260243758 Mon Sep 17 00:00:00 2001 From: Misko Date: Fri, 16 Aug 2024 00:32:47 +0000 Subject: [PATCH 07/10] fix tests; move init_weight out of equiformer; add amp property to heads+hydra --- src/fairchem/core/models/base.py | 10 +++- .../models/equiformer_v2/equiformer_v2.py | 56 +++++++++---------- .../equiformer_v2/prediction_heads/rank2.py | 8 ++- src/fairchem/core/models/escn/escn.py | 7 ++- src/fairchem/core/trainers/base_trainer.py | 28 ++++++---- src/fairchem/core/trainers/ocp_trainer.py | 4 +- tests/core/e2e/conftest.py | 2 +- tests/core/e2e/test_s2efs.py | 20 +++++-- 8 files changed, 82 insertions(+), 53 deletions(-) diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index 8ce8f3fcb1..f8144887d7 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -188,6 +188,10 @@ def no_weight_decay(self) -> list: class HeadInterface(metaclass=ABCMeta): + @property + def use_amp_in_head(self): + return False + @abstractmethod def forward( self, data: Batch, emb: dict[str, torch.Tensor] @@ -269,6 +273,10 @@ def forward(self, data: Batch): # Predict all output properties for all structures in the batch for now. out = {} for k in self.output_heads: - out.update(self.output_heads[k](data, emb)) + with torch.autocast( + device_type=self.device, enabled=self.output_heads.use_amp + ): + print("USE AMP", self.output_heads.use_amp) + out.update(self.output_heads[k](data, emb)) return out diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index b1c7214fab..7b89025394 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -3,6 +3,7 @@ import contextlib import logging import math +from functools import partial import torch import torch.nn as nn @@ -54,6 +55,28 @@ _AVG_DEGREE = 23.395238876342773 # IS2RE: 100k, max_radius = 5, max_neighbors = 100 +def eqv2_init_weights(m, weight_init): + if isinstance(m, (torch.nn.Linear, SO3_LinearV2)): + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + if weight_init == "normal": + std = 1 / math.sqrt(m.in_features) + torch.nn.init.normal_(m.weight, 0, std) + elif isinstance(m, torch.nn.LayerNorm): + torch.nn.init.constant_(m.bias, 0) + torch.nn.init.constant_(m.weight, 1.0) + elif isinstance(m, RadialFunction): + m.apply(eqv2_uniform_init_linear_weights) + + +def eqv2_uniform_init_linear_weights(m): + if isinstance(m, torch.nn.Linear): + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + std = 1 / math.sqrt(m.in_features) + torch.nn.init.uniform_(m.weight, -std, std) + + @registry.register_model("equiformer_v2") class EquiformerV2(nn.Module, GraphModelMixin): """ @@ -398,8 +421,8 @@ def __init__( requires_grad=False, ) - self.apply(self.init_weights) - self.apply(self._uniform_init_rad_func_linear_weights) + self.apply(partial(eqv2_init_weights, weight_init=self.weight_init)) + # self.apply(eqv2_uniform_init_rad_func_linear_weights) def _init_gp_partitions( self, @@ -628,29 +651,6 @@ def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec): def num_params(self): return sum(p.numel() for p in self.parameters()) - def init_weights(self, m): - if isinstance(m, (torch.nn.Linear, SO3_LinearV2)): - if m.bias is not None: - torch.nn.init.constant_(m.bias, 0) - if self.weight_init == "normal": - std = 1 / math.sqrt(m.in_features) - torch.nn.init.normal_(m.weight, 0, std) - - elif isinstance(m, torch.nn.LayerNorm): - torch.nn.init.constant_(m.bias, 0) - torch.nn.init.constant_(m.weight, 1.0) - - def _uniform_init_rad_func_linear_weights(self, m): - if isinstance(m, RadialFunction): - m.apply(self._uniform_init_linear_weights) - - def _uniform_init_linear_weights(self, m): - if isinstance(m, torch.nn.Linear): - if m.bias is not None: - torch.nn.init.constant_(m.bias, 0) - std = 1 / math.sqrt(m.in_features) - torch.nn.init.uniform_(m.weight, -std, std) - @torch.jit.ignore def no_weight_decay(self) -> set: no_wd_list = [] @@ -836,8 +836,7 @@ def __init__(self, backbone): backbone.use_grid_mlp, backbone.use_sep_s2_act, ) - self.apply(backbone.init_weights) - self.apply(backbone._uniform_init_rad_func_linear_weights) + self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init)) def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]): node_energy = self.energy_block(emb["node_embedding"]) @@ -881,8 +880,7 @@ def __init__(self, backbone): backbone.use_sep_s2_act, alpha_drop=0.0, ) - self.apply(backbone.init_weights) - self.apply(backbone._uniform_init_rad_func_linear_weights) + self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init)) def forward(self, data: Batch, emb: dict[str, torch.Tensor]): forces = self.force_block( diff --git a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py index 2ee7be29d0..2bbf42eaa0 100644 --- a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py +++ b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py @@ -7,6 +7,8 @@ from __future__ import annotations +from functools import partial + import torch from e3nn import o3 from torch import nn @@ -14,6 +16,7 @@ from fairchem.core.common.registry import registry from fairchem.core.models.base import BackboneInterface, HeadInterface +from fairchem.core.models.equiformer_v2.equiformer_v2 import eqv2_init_weights from fairchem.core.models.equiformer_v2.layer_norm import get_normalization_layer @@ -242,6 +245,7 @@ def __init__( use_source_target_embedding: bool = False, extensive: bool = False, avg_num_nodes: int = 1.0, + default_norm_type: str = "layer_norm_sh", ): """ Args: @@ -260,7 +264,7 @@ def __init__( self.avg_num_nodes = avg_num_nodes self.sphharm_norm = get_normalization_layer( - backbone.norm_type, + getattr(backbone, "norm_type", default_norm_type), lmax=max(backbone.lmax_list), num_channels=1, ) @@ -288,7 +292,7 @@ def __init__( ) # initialize weights - self.block.apply(backbone.init_weights) + self.block.apply(partial(eqv2_init_weights, weight_init="uniform")) def forward( self, data: dict[str, torch.Tensor] | torch.Tensor, emb: dict[str, torch.Tensor] diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 62a582b4cc..47ad718ff2 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -523,7 +523,12 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: x_pt = x_pt.view(-1, self.sphere_channels_all) - return {"sphere_values": x_pt, "sphere_points": self.sphere_points} + return { + "sphere_values": x_pt, + "sphere_points": self.sphere_points, + "node_embedding": x, + "graph": graph, + } @registry.register_model("escn_energy_head") diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 40c7e65de6..135426d42d 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -422,9 +422,11 @@ def load_references_and_normalizers(self): elementref_config, dataset=self.train_dataset, seed=self.config["cmd"]["seed"], - checkpoint_dir=self.config["cmd"]["checkpoint_dir"] - if not self.is_debug - else None, + checkpoint_dir=( + self.config["cmd"]["checkpoint_dir"] + if not self.is_debug + else None + ), ) if norms_config is not None: @@ -432,9 +434,11 @@ def load_references_and_normalizers(self): norms_config, dataset=self.train_dataset, seed=self.config["cmd"]["seed"], - checkpoint_dir=self.config["cmd"]["checkpoint_dir"] - if not self.is_debug - else None, + checkpoint_dir=( + self.config["cmd"]["checkpoint_dir"] + if not self.is_debug + else None + ), element_references=elementrefs, ) @@ -483,15 +487,15 @@ def load_task(self): ][target_name].get("level", "system") if "train_on_free_atoms" not in self.output_targets[subtarget]: self.output_targets[subtarget]["train_on_free_atoms"] = ( - self.config[ - "outputs" - ][target_name].get("train_on_free_atoms", True) + self.config["outputs"][target_name].get( + "train_on_free_atoms", True + ) ) if "eval_on_free_atoms" not in self.output_targets[subtarget]: self.output_targets[subtarget]["eval_on_free_atoms"] = ( - self.config[ - "outputs" - ][target_name].get("eval_on_free_atoms", True) + self.config["outputs"][target_name].get( + "eval_on_free_atoms", True + ) ) # TODO: Assert that all targets, loss fn, metrics defined are consistent diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 26269c6da4..662341bdc5 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -655,7 +655,9 @@ def run_relaxations(self, split="val"): ) gather_results["chunk_idx"] = np.cumsum( [gather_results["chunk_idx"][i] for i in idx] - )[:-1] # np.split does not need last idx, assumes n-1:end + )[ + :-1 + ] # np.split does not need last idx, assumes n-1:end full_path = os.path.join( self.config["cmd"]["results_dir"], "relaxed_positions.npz" diff --git a/tests/core/e2e/conftest.py b/tests/core/e2e/conftest.py index 1278579477..817c2bf9f6 100644 --- a/tests/core/e2e/conftest.py +++ b/tests/core/e2e/conftest.py @@ -183,7 +183,7 @@ def _run_main( ) else: Runner()(config) - + f = glob.glob(f"{rundir}/checkpoints/*/*") if save_checkpoint_to is not None: checkpoints = glob.glob(f"{rundir}/checkpoints/*/checkpoint.pt") assert len(checkpoints) == 1 diff --git a/tests/core/e2e/test_s2efs.py b/tests/core/e2e/test_s2efs.py index 5032f50958..a4c32f7004 100644 --- a/tests/core/e2e/test_s2efs.py +++ b/tests/core/e2e/test_s2efs.py @@ -30,7 +30,7 @@ def test_smoke_s2efs_predict( updates = { "task": {"strict_load": False}, "model": { - "backbone": {"max_num_elements": 118}, + "backbone": {"max_num_elements": 118 + 1}, "heads": { "stress": { "module": "rank2_symmetric_head", @@ -45,17 +45,25 @@ def test_smoke_s2efs_predict( {"stress": {"fn": "mae", "coefficient": 100}}, ], "outputs": {"stress": {"level": "system", "irrep_dim": 2}}, - "evaluation_metrics": {"metrics": {"stress": "mae"}}, + "evaluation_metrics": {"metrics": {"stress": ["mae"]}}, "dataset": { "train": { "src": str(dummy_binary_dataset_path), "format": "ase_db", "a2g_args": {"r_data_keys": ["energy", "forces", "stress"]}, + "sample_n": 20, }, "val": { "src": str(dummy_binary_dataset_path), "format": "ase_db", "a2g_args": {"r_data_keys": ["energy", "forces", "stress"]}, + "sample_n": 5, + }, + "test": { + "src": str(dummy_binary_dataset_path), + "format": "ase_db", + "a2g_args": {"r_data_keys": ["energy", "forces", "stress"]}, + "sample_n": 5, }, }, } @@ -100,9 +108,9 @@ def test_smoke_s2efs_predict( ) predictions = np.load(training_predictions_filename) - for output in input_yaml["outputs"]: + for output in updates["outputs"]: assert output in predictions - assert predictions["energy"].shape == (20,) - assert predictions["forces"].shape == (20, 3) - assert predictions["stress"].shape == (20, 9) + assert predictions["energy"].shape == (5, 1) + assert predictions["forces"].shape == (10, 3) + assert predictions["stress"].shape == (5, 9) From 92a8039554973ea5a3a422c4e3953761497dcbba Mon Sep 17 00:00:00 2001 From: Misko Date: Fri, 16 Aug 2024 00:41:36 +0000 Subject: [PATCH 08/10] add amp to heads and hydra --- src/fairchem/core/models/base.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index f8144887d7..3e1acd12f6 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -189,7 +189,7 @@ def no_weight_decay(self) -> list: class HeadInterface(metaclass=ABCMeta): @property - def use_amp_in_head(self): + def use_amp(self): return False @abstractmethod @@ -241,6 +241,7 @@ def __init__( ): super().__init__() self.otf_graph = otf_graph + self.device = "cpu" backbone_model_name = backbone.pop("model") self.backbone: BackboneInterface = registry.get_model_class( @@ -268,15 +269,19 @@ def __init__( self.output_heads = torch.nn.ModuleDict(self.output_heads) + def to(self, *args, **kwargs): + if "device" in kwargs: + self.device = kwargs["device"] + return super().to(*args, **kwargs) + def forward(self, data: Batch): emb = self.backbone(data) # Predict all output properties for all structures in the batch for now. out = {} for k in self.output_heads: with torch.autocast( - device_type=self.device, enabled=self.output_heads.use_amp + device_type=self.device, enabled=self.output_heads[k].use_amp ): - print("USE AMP", self.output_heads.use_amp) out.update(self.output_heads[k](data, emb)) return out From c49f6332c5e58bb114b680e77b842078812bcc55 Mon Sep 17 00:00:00 2001 From: Misko Date: Fri, 16 Aug 2024 20:23:25 +0000 Subject: [PATCH 09/10] fix import --- tests/core/e2e/test_s2efs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/e2e/test_s2efs.py b/tests/core/e2e/test_s2efs.py index a4c32f7004..94b0862edb 100644 --- a/tests/core/e2e/test_s2efs.py +++ b/tests/core/e2e/test_s2efs.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from conftest import _run_main +from test_e2e_commons import _run_main # TODO add GemNet! From 46bee2f0fcdbbd932ba56434604debb783753320 Mon Sep 17 00:00:00 2001 From: Misko Date: Fri, 16 Aug 2024 22:05:07 +0000 Subject: [PATCH 10/10] update snapshot; fix test seed and change tolerance --- src/fairchem/core/models/equiformer_v2/equiformer_v2.py | 1 - tests/core/models/__snapshots__/test_equiformer_v2.ambr | 4 ++-- tests/core/models/test_rank2_head.py | 3 ++- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index 4d7da1d984..b78f435978 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -424,7 +424,6 @@ def __init__( ) self.apply(partial(eqv2_init_weights, weight_init=self.weight_init)) - # self.apply(eqv2_uniform_init_rad_func_linear_weights) def _init_gp_partitions( self, diff --git a/tests/core/models/__snapshots__/test_equiformer_v2.ambr b/tests/core/models/__snapshots__/test_equiformer_v2.ambr index 5ddf7f2bea..03be8ebdac 100644 --- a/tests/core/models/__snapshots__/test_equiformer_v2.ambr +++ b/tests/core/models/__snapshots__/test_equiformer_v2.ambr @@ -56,7 +56,7 @@ # --- # name: TestEquiformerV2.test_gp.1 Approx( - array([-0.03269595], dtype=float32), + array([0.12408739], dtype=float32), rtol=0.001, atol=0.001 ) @@ -69,7 +69,7 @@ # --- # name: TestEquiformerV2.test_gp.3 Approx( - array([ 0.00208857, -0.00017979, -0.0028318 ], dtype=float32), + array([ 1.4928661e-03, -7.4134863e-05, 2.9909245e-03], dtype=float32), rtol=0.001, atol=0.001 ) diff --git a/tests/core/models/test_rank2_head.py b/tests/core/models/test_rank2_head.py index d198e3403e..c00667806e 100644 --- a/tests/core/models/test_rank2_head.py +++ b/tests/core/models/test_rank2_head.py @@ -33,6 +33,7 @@ def batch(): def test_rank2_head( batch, decompose, edge_level_mlp, use_source_target_embedding, extensive ): + torch.manual_seed(100) # fix network initialization backbone = EquiformerV2Backbone( num_layers=2, sphere_channels=8, @@ -58,7 +59,7 @@ def test_rank2_head( assert r2_out["out_isotropic"].shape[1] == 1 tensor = _reshape_tensor(r2_out["out_isotropic"]) # anisotropic must be traceless - assert torch.diagonal(tensor).sum().item() == pytest.approx(0.0, abs=1e-8) + assert torch.diagonal(tensor).sum().item() == pytest.approx(0.0, abs=2e-8) else: assert "out" in r2_out tensor = r2_out["out"].view(3, 3)