Commit

fix pytorch_default_init()
torch.nn.init.trunc_normal_() truncates at the absolute bounds (a, b), which default to (-2, 2),
not at (a * std, b * std).
EIFY committed Nov 27, 2024
1 parent 86d2a0d commit 579a485
Showing 1 changed file with 1 addition and 1 deletion.
algorithmic_efficiency/init_utils.py (1 addition & 1 deletion)
@@ -13,6 +13,6 @@ def pytorch_default_init(module: nn.Module) -> None:
   # Perform lecun_normal initialization.
   fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
   std = math.sqrt(1. / fan_in) / .87962566103423978
-  nn.init.trunc_normal_(module.weight, std=std)
+  nn.init.trunc_normal_(module.weight, std=std, a=-2 * std, b=2 * std)
   if module.bias is not None:
     nn.init.constant_(module.bias, 0.)
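
For illustration only (not part of the commit), a minimal Python sketch of the difference the added a/b arguments make; fan_in = 1024 is an arbitrary value chosen for the example:

import math
import torch
from torch import nn

fan_in = 1024  # arbitrary example value, not from the commit
# Same lecun_normal std as in init_utils.py; 0.87962566103423978 is the std of a
# standard normal truncated to [-2, 2], so dividing by it compensates for the truncation.
std = math.sqrt(1. / fan_in) / .87962566103423978

w_before = torch.empty(fan_in, fan_in)
# Old call: a and b keep their defaults of -2.0 and 2.0, which for std ~ 0.035
# sit roughly 57 standard deviations out, so effectively nothing is truncated.
nn.init.trunc_normal_(w_before, std=std)

w_after = torch.empty(fan_in, fan_in)
# Fixed call: truncation at +/- 2 standard deviations, as lecun_normal intends.
nn.init.trunc_normal_(w_after, std=std, a=-2 * std, b=2 * std)

print((w_before.abs().max() / std).item())  # typically ~4-5, i.e. samples beyond 2 std
print((w_after.abs().max() / std).item())   # <= 2 by construction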
