Skip to content

Commit

Permalink
Fix linter error (#812)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #812

ATT

Reviewed By: kadeng

Differential Revision: D47211516

fbshipit-source-id: 5ae2a0ce358d817ba8d60d83c13ac5d40f36fdef
  • Loading branch information
aakhundov authored and facebook-github-bot committed Jul 4, 2023
1 parent 039bb9f commit d947681
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions python/aitemplate/testing/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from aitemplate.testing.detect_target import detect_target
from aitemplate.utils.graph_utils import get_sorted_ops
from aitemplate.utils.torch_utils import string_to_torch_dtype
from torch import nn


class TestEnv(Enum):
Expand Down Expand Up @@ -303,13 +302,13 @@ def get_attn_mask_per_causal_type(

def init_random_weights(m):
if hasattr(m, "weight"):
nn.init.uniform_(m.weight)
torch.nn.init.uniform_(m.weight)
elif (
type(m) == nn.Sequential
or type(m) == nn.ModuleList
or type(m) == nn.SiLU
or type(m) == nn.Dropout
or type(m) == nn.Identity
type(m) == torch.nn.Sequential
or type(m) == torch.nn.ModuleList
or type(m) == torch.nn.SiLU
or type(m) == torch.nn.Dropout
or type(m) == torch.nn.Identity
):
pass
else:
Expand Down

0 comments on commit d947681

Please sign in to comment.