diff --git a/python/aitemplate/testing/test_utils.py b/python/aitemplate/testing/test_utils.py index 07e43a6cd..82c1fe95d 100644 --- a/python/aitemplate/testing/test_utils.py +++ b/python/aitemplate/testing/test_utils.py @@ -30,7 +30,6 @@ from aitemplate.testing.detect_target import detect_target from aitemplate.utils.graph_utils import get_sorted_ops from aitemplate.utils.torch_utils import string_to_torch_dtype -from torch import nn class TestEnv(Enum): @@ -303,13 +302,13 @@ def get_attn_mask_per_causal_type( def init_random_weights(m): if hasattr(m, "weight"): - nn.init.uniform_(m.weight) + torch.nn.init.uniform_(m.weight) elif ( - type(m) == nn.Sequential - or type(m) == nn.ModuleList - or type(m) == nn.SiLU - or type(m) == nn.Dropout - or type(m) == nn.Identity + type(m) == torch.nn.Sequential + or type(m) == torch.nn.ModuleList + or type(m) == torch.nn.SiLU + or type(m) == torch.nn.Dropout + or type(m) == torch.nn.Identity ): pass else: