diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index 72c519cf36..b303945602 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -637,6 +637,8 @@ def _init_weights(self, m): if self.weight_init == "normal": std = 1 / math.sqrt(m.in_features) torch.nn.init.normal_(m.weight, 0, std) + elif self.weight_init == "uniform": + self._uniform_init_linear_weights(m) elif isinstance(m, torch.nn.LayerNorm): torch.nn.init.constant_(m.bias, 0) @@ -647,7 +649,7 @@ def _uniform_init_rad_func_linear_weights(self, m): m.apply(self._uniform_init_linear_weights) def _uniform_init_linear_weights(self, m): - if isinstance(m, torch.nn.Linear): + if isinstance(m, (torch.nn.Linear, SO3_LinearV2)): if m.bias is not None: torch.nn.init.constant_(m.bias, 0) std = 1 / math.sqrt(m.in_features) diff --git a/tests/core/models/__snapshots__/test_equiformer_v2.ambr b/tests/core/models/__snapshots__/test_equiformer_v2.ambr index 9c595426ab..5ddf7f2bea 100644 --- a/tests/core/models/__snapshots__/test_equiformer_v2.ambr +++ b/tests/core/models/__snapshots__/test_equiformer_v2.ambr @@ -56,7 +56,7 @@ # --- # name: TestEquiformerV2.test_gp.1 Approx( - array([0.12408741], dtype=float32), + array([-0.03269595], dtype=float32), rtol=0.001, atol=0.001 ) @@ -69,7 +69,7 @@ # --- # name: TestEquiformerV2.test_gp.3 Approx( - array([ 1.4928658e-03, -7.4134972e-05, 2.9909210e-03], dtype=float32), + array([ 0.00208857, -0.00017979, -0.0028318 ], dtype=float32), rtol=0.001, atol=0.001 )