Skip to content

Commit

Permalink
Merge pull request #819 from EIFY/torch-init-fix
Browse files Browse the repository at this point in the history
fix pytorch_default_init()
  • Loading branch information
priyakasimbeg authored Dec 12, 2024
2 parents 90959e1 + 579a485 commit fe90379
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion algorithmic_efficiency/init_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ def pytorch_default_init(module: nn.Module) -> None:
# Perform lecun_normal initialization.
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
std = math.sqrt(1. / fan_in) / .87962566103423978
nn.init.trunc_normal_(module.weight, std=std)
nn.init.trunc_normal_(module.weight, std=std, a=-2 * std, b=2 * std)
if module.bias is not None:
nn.init.constant_(module.bias, 0.)

0 comments on commit fe90379

Please sign in to comment.