From d9476811eb1d35247003f4d8b57cf7107b3b12d7 Mon Sep 17 00:00:00 2001 From: Adnan Akhundov Date: Tue, 4 Jul 2023 10:45:31 -0700 Subject: [PATCH] Fix linter error (#812) Summary: Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/812 ATT Reviewed By: kadeng Differential Revision: D47211516 fbshipit-source-id: 5ae2a0ce358d817ba8d60d83c13ac5d40f36fdef --- python/aitemplate/testing/test_utils.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/python/aitemplate/testing/test_utils.py b/python/aitemplate/testing/test_utils.py index 07e43a6cd..82c1fe95d 100644 --- a/python/aitemplate/testing/test_utils.py +++ b/python/aitemplate/testing/test_utils.py @@ -30,7 +30,6 @@ from aitemplate.testing.detect_target import detect_target from aitemplate.utils.graph_utils import get_sorted_ops from aitemplate.utils.torch_utils import string_to_torch_dtype -from torch import nn class TestEnv(Enum): @@ -303,13 +302,13 @@ def get_attn_mask_per_causal_type( def init_random_weights(m): if hasattr(m, "weight"): - nn.init.uniform_(m.weight) + torch.nn.init.uniform_(m.weight) elif ( - type(m) == nn.Sequential - or type(m) == nn.ModuleList - or type(m) == nn.SiLU - or type(m) == nn.Dropout - or type(m) == nn.Identity + type(m) == torch.nn.Sequential + or type(m) == torch.nn.ModuleList + or type(m) == torch.nn.SiLU + or type(m) == torch.nn.Dropout + or type(m) == torch.nn.Identity ): pass else: