Add init_random_weights test util (#800)
Summary:
Pull Request resolved: #800

Initialize random weights for AIT constants during model compilation, to prevent identical weights from being compared when testing the accuracy of a PT module against its AIT counterpart.

Reviewed By: henryhu6

Differential Revision: D47031569

fbshipit-source-id: f063a8b13d3a530f7c667ce4b2259f9177bdd4fa
Colin Chan authored and facebook-github-bot committed Jun 27, 2023
1 parent dccb361 commit a20384e
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions python/aitemplate/testing/test_utils.py
@@ -30,6 +30,7 @@
 from aitemplate.testing.detect_target import detect_target
 from aitemplate.utils.graph_utils import get_sorted_ops
 from aitemplate.utils.torch_utils import string_to_torch_dtype
+from torch import nn


 class TestEnv(Enum):
@@ -298,3 +299,18 @@ def get_attn_mask_per_causal_type(
     else:
         raise NotImplementedError(f"Unsupported {causal_type=}!")
     return invalid_attn_mask
+
+
+def init_random_weights(m):
+    if hasattr(m, "weight"):
+        nn.init.uniform_(m.weight)
+    elif (
+        type(m) == nn.Sequential
+        or type(m) == nn.ModuleList
+        or type(m) == nn.SiLU
+        or type(m) == nn.Dropout
+        or type(m) == nn.Identity
+    ):
+        pass
+    else:
+        print("Passed root module: " + str(type(m)))
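As a usage sketch (not part of the commit), a test can pass this initializer to nn.Module.apply, which visits every submodule recursively; the toy model below is hypothetical:

    from torch import nn

    from aitemplate.testing.test_utils import init_random_weights

    # A hypothetical PyTorch module, used only for illustration.
    model = nn.Sequential(
        nn.Linear(16, 32),
        nn.SiLU(),
        nn.Linear(32, 8),
    )

    # apply() calls init_random_weights on every submodule: the Linear
    # layers get uniform random weights, while the Sequential container
    # and the SiLU activation are recognized and skipped.
    model.apply(init_random_weights)

With the constants randomized this way before compilation, an accuracy comparison between the PT module and the compiled AIT module is no longer trivially comparing identical default weights.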
