From ef2a4bc0b2fcbc637a0c27d852ac4c208f5a77c8 Mon Sep 17 00:00:00 2001
From: rayg1234 <7001989+rayg1234@users.noreply.github.com>
Date: Thu, 15 Aug 2024 11:09:49 -0700
Subject: [PATCH] Activation checkpoint equiformerv2 (#811)

* add back activation checkpoint

* typo

* add act check test

* lint

* rename test
---
 .../models/equiformer_v2/equiformer_v2.py | 56 ++++++++++++++-----
 tests/core/models/test_equiformer_v2.py   | 46 +++++++++++++++
 2 files changed, 87 insertions(+), 15 deletions(-)

diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py
index bda8181c5..72c519cf3 100644
--- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py
+++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py
@@ -157,6 +157,7 @@ def __init__(
         avg_degree: float | None = None,
         use_energy_lin_ref: bool | None = False,
         load_energy_lin_ref: bool | None = False,
+        activation_checkpoint: bool | None = False,
     ):
         if mmax_list is None:
             mmax_list = [2]
@@ -170,6 +171,7 @@ def __init__(
             logging.error("You need to install e3nn==0.4.4 to use EquiformerV2.")
             raise ImportError

+        self.activation_checkpoint = activation_checkpoint
         self.use_pbc = use_pbc
         self.use_pbc_single = use_pbc_single
         self.regress_forces = regress_forces
@@ -803,14 +805,26 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
         ###############################################################

         for i in range(self.num_layers):
-            x = self.blocks[i](
-                x,  # SO3_Embedding
-                graph.atomic_numbers_full,
-                graph.edge_distance,
-                graph.edge_index,
-                batch=data_batch,  # for GraphDropPath
-                node_offset=graph.node_offset,
-            )
+            if self.activation_checkpoint:
+                x = torch.utils.checkpoint.checkpoint(
+                    self.blocks[i],
+                    x,  # SO3_Embedding
+                    graph.atomic_numbers_full,
+                    graph.edge_distance,
+                    graph.edge_index,
+                    data_batch,  # for GraphDropPath
+                    graph.node_offset,
+                    use_reentrant=not self.training,
+                )
+            else:
+                x = self.blocks[i](
+                    x,  # SO3_Embedding
+                    graph.atomic_numbers_full,
+                    graph.edge_distance,
+                    graph.edge_index,
+                    batch=data_batch,  # for GraphDropPath
+                    node_offset=graph.node_offset,
+                )

         # Final layer norm
         x.embedding = self.norm(x.embedding)
@@ -858,6 +872,7 @@ class EquiformerV2ForceHead(nn.Module, HeadInterface):

     def __init__(self, backbone):
         super().__init__()
+        self.activation_checkpoint = backbone.activation_checkpoint
         self.force_block = SO2EquivariantGraphAttention(
             backbone.sphere_channels,
             backbone.attn_hidden_channels,
@@ -885,13 +900,24 @@ def __init__(self, backbone):
         self.apply(backbone._uniform_init_rad_func_linear_weights)

     def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
-        forces = self.force_block(
-            emb["node_embedding"],
-            emb["graph"].atomic_numbers_full,
-            emb["graph"].edge_distance,
-            emb["graph"].edge_index,
-            node_offset=emb["graph"].node_offset,
-        )
+        if self.activation_checkpoint:
+            forces = torch.utils.checkpoint.checkpoint(
+                self.force_block,
+                emb["node_embedding"],
+                emb["graph"].atomic_numbers_full,
+                emb["graph"].edge_distance,
+                emb["graph"].edge_index,
+                emb["graph"].node_offset,
+                use_reentrant=not self.training,
+            )
+        else:
+            forces = self.force_block(
+                emb["node_embedding"],
+                emb["graph"].atomic_numbers_full,
+                emb["graph"].edge_distance,
+                emb["graph"].edge_index,
+                node_offset=emb["graph"].node_offset,
+            )
         forces = forces.embedding.narrow(1, 1, 3)
         forces = forces.view(-1, 3).contiguous()
         if gp_utils.initialized():
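The backbone change above is standard PyTorch activation checkpointing: a wrapped block stores no intermediate activations during the forward pass and recomputes them on the fly during backward, trading extra compute for lower peak memory. The reentrant checkpoint implementation does not forward keyword arguments to the wrapped module, which is why batch=data_batch and node_offset=graph.node_offset become positional in the checkpointed branch, and use_reentrant=not self.training selects the non-reentrant implementation whenever the model is training. A minimal, self-contained sketch of the same pattern, using a toy block stack rather than fairchem code:

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint


    class TinyStack(nn.Module):
        """Toy stand-in for the EquiformerV2 block stack (illustrative only)."""

        def __init__(self, activation_checkpoint: bool = False):
            super().__init__()
            self.activation_checkpoint = activation_checkpoint
            self.blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(4))

        def forward(self, x):
            for block in self.blocks:
                if self.activation_checkpoint:
                    # Recompute this block's activations during backward
                    # instead of storing them; mirrors the patch's
                    # use_reentrant choice.
                    x = torch.utils.checkpoint.checkpoint(
                        block, x, use_reentrant=not self.training
                    )
                else:
                    x = block(x)
            return x


    model = TinyStack(activation_checkpoint=True).train()
    x = torch.randn(8, 16, requires_grad=True)
    model(x).sum().backward()  # gradients match the uncheckpointed stack

Checkpointing per block, as done here, keeps memory roughly proportional to one block's activations times the number of checkpoint boundaries, which is why the patch wraps each of the num_layers blocks individually rather than the whole stack.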
diff --git a/tests/core/models/test_equiformer_v2.py b/tests/core/models/test_equiformer_v2.py
index 3194dd2df..54d58db1c 100644
--- a/tests/core/models/test_equiformer_v2.py
+++ b/tests/core/models/test_equiformer_v2.py
@@ -10,10 +10,12 @@
 import copy
 import io
 import os
+from pathlib import Path

 import pytest
 import requests
 import torch
+import yaml
 from ase.io import read
 from torch.nn.parallel.distributed import DistributedDataParallel

@@ -230,3 +232,47 @@ def sign(x):
     embedding._l_primary(c)
     lp = embedding.embedding.clone()
     (test_matrix_lp == lp).all()
+
+
+def _load_hydra_model():
+    torch.manual_seed(4)
+    with open(Path("tests/core/models/test_configs/test_equiformerv2_hydra.yml")) as yaml_file:
+        yaml_config = yaml.safe_load(yaml_file)
+    model = registry.get_model_class("hydra")(yaml_config["model"]["backbone"], yaml_config["model"]["heads"])
+    model.backbone.num_layers = 1
+    return model
+
+def test_eqv2_hydra_activation_checkpoint():
+    atoms = read(
+        os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"),
+        index=0,
+        format="json",
+    )
+    a2g = AtomsToGraphs(
+        max_neigh=200,
+        radius=6,
+        r_edges=False,
+        r_fixed=True,
+    )
+    data_list = a2g.convert_all([atoms])
+    inputs = data_list_collater(data_list)
+    no_ac_model = _load_hydra_model()
+    ac_model = _load_hydra_model()
+    ac_model.backbone.activation_checkpoint = True
+
+    # for this test both models need the exact same state, and the only way to
+    # do that is to save the RNG state and reset it after stepping the first model
+    start_rng_state = torch.random.get_rng_state()
+    outputs_no_ac = no_ac_model(inputs)
+    torch.autograd.backward(outputs_no_ac["energy"].sum() + outputs_no_ac["forces"].sum())
+
+    # reset the RNG state to the beginning
+    torch.random.set_rng_state(start_rng_state)
+    outputs_ac = ac_model(inputs)
+    torch.autograd.backward(outputs_ac["energy"].sum() + outputs_ac["forces"].sum())
+
+    # assert the gradients are identical between the models with and without checkpointing
+    ac_model_grad_dict = {name: p.grad for name, p in ac_model.named_parameters() if p.grad is not None}
+    no_ac_model_grad_dict = {name: p.grad for name, p in no_ac_model.named_parameters() if p.grad is not None}
+    for name in no_ac_model_grad_dict:
+        assert torch.allclose(no_ac_model_grad_dict[name], ac_model_grad_dict[name], atol=1e-4)
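The test's central trick is RNG bookkeeping: both models must start from identical weights (hence the shared manual seed in _load_hydra_model) and consume identical random draws during the forward pass, so the global RNG state is captured before the first model runs and restored before the second. Checkpointing itself preserves RNG state across recomputation by default (preserve_rng_state=True), which is what makes a gradient-equality check viable even with stochastic layers. A self-contained sketch of the same check against a toy dropout network (illustrative names, not the fairchem test):

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint


    def make_model():
        torch.manual_seed(4)  # same seed -> identical initial weights
        return nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.1), nn.Linear(8, 8))


    plain_model, ac_model = make_model(), make_model()
    x = torch.randn(4, 8)

    # snapshot the global RNG so both models see the same dropout masks
    start_rng_state = torch.random.get_rng_state()
    plain_model(x).sum().backward()

    torch.random.set_rng_state(start_rng_state)
    out = torch.utils.checkpoint.checkpoint(ac_model, x, use_reentrant=False)
    out.sum().backward()

    # gradients must agree between the plain and checkpointed runs;
    # atol mirrors the tolerance used in the test above
    for p_plain, p_ac in zip(plain_model.parameters(), ac_model.parameters()):
        assert torch.allclose(p_plain.grad, p_ac.grad, atol=1e-4)

Setting backbone.num_layers = 1 in _load_hydra_model serves the same purpose as the toy network here: it keeps the forward pass small so the test runs quickly while still crossing a checkpoint boundary.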