diff --git a/algorithmic_efficiency/init_utils.py b/algorithmic_efficiency/init_utils.py
index 66ed041ce..185480cc7 100644
--- a/algorithmic_efficiency/init_utils.py
+++ b/algorithmic_efficiency/init_utils.py
@@ -13,6 +13,6 @@ def pytorch_default_init(module: nn.Module) -> None:
   # Perform lecun_normal initialization.
   fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
   std = math.sqrt(1. / fan_in) / .87962566103423978
-  nn.init.trunc_normal_(module.weight, std=std)
+  nn.init.trunc_normal_(module.weight, std=std, a=-2 * std, b=2 * std)
   if module.bias is not None:
     nn.init.constant_(module.bias, 0.)