diff --git a/python/aitemplate/frontend/nn/attention.py b/python/aitemplate/frontend/nn/attention.py index 7c9f36279..d4e174f9c 100644 --- a/python/aitemplate/frontend/nn/attention.py +++ b/python/aitemplate/frontend/nn/attention.py @@ -257,7 +257,7 @@ def attention(self, x): ) return out - def forward(self, *args, seqlens=None): + def forward(self, *args): """forward pass for calling mha module""" assert len(args) >= 1 x = args[0]