diff --git a/python/aitemplate/testing/test_utils.py b/python/aitemplate/testing/test_utils.py index 88fb8b999..07e43a6cd 100644 --- a/python/aitemplate/testing/test_utils.py +++ b/python/aitemplate/testing/test_utils.py @@ -30,6 +30,7 @@ from aitemplate.testing.detect_target import detect_target from aitemplate.utils.graph_utils import get_sorted_ops from aitemplate.utils.torch_utils import string_to_torch_dtype +from torch import nn class TestEnv(Enum): @@ -298,3 +299,18 @@ def get_attn_mask_per_causal_type( else: raise NotImplementedError(f"Unsupported {causal_type=}!") return invalid_attn_mask + + +def init_random_weights(m): + if hasattr(m, "weight"): + nn.init.uniform_(m.weight) + elif ( + type(m) == nn.Sequential + or type(m) == nn.ModuleList + or type(m) == nn.SiLU + or type(m) == nn.Dropout + or type(m) == nn.Identity + ): + pass + else: + print("Passed root module: " + str(type(m)))