From a20384e73de4c1a96a8b8656613e9baa09b88e40 Mon Sep 17 00:00:00 2001 From: Colin Chan Date: Tue, 27 Jun 2023 11:58:10 -0700 Subject: [PATCH] Add init_random_weights test util (#800) Summary: Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/800 Initialize random weights for AIT constants during model compilation to prevent identical weights being compared when testing accuracy of PT module vs AIT module. Reviewed By: henryhu6 Differential Revision: D47031569 fbshipit-source-id: f063a8b13d3a530f7c667ce4b2259f9177bdd4fa --- python/aitemplate/testing/test_utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/python/aitemplate/testing/test_utils.py b/python/aitemplate/testing/test_utils.py index 88fb8b999..07e43a6cd 100644 --- a/python/aitemplate/testing/test_utils.py +++ b/python/aitemplate/testing/test_utils.py @@ -30,6 +30,7 @@ from aitemplate.testing.detect_target import detect_target from aitemplate.utils.graph_utils import get_sorted_ops from aitemplate.utils.torch_utils import string_to_torch_dtype +from torch import nn class TestEnv(Enum): @@ -298,3 +299,18 @@ def get_attn_mask_per_causal_type( else: raise NotImplementedError(f"Unsupported {causal_type=}!") return invalid_attn_mask + + +def init_random_weights(m): + if hasattr(m, "weight"): + nn.init.uniform_(m.weight) + elif ( + type(m) == nn.Sequential + or type(m) == nn.ModuleList + or type(m) == nn.SiLU + or type(m) == nn.Dropout + or type(m) == nn.Identity + ): + pass + else: + print("Passed root module: " + str(type(m)))