diff --git a/ghostnetv2_pytorch/model/ghostnetv2_torch.py b/ghostnetv2_pytorch/model/ghostnetv2_torch.py index ba58f19..44f5842 100644 --- a/ghostnetv2_pytorch/model/ghostnetv2_torch.py +++ b/ghostnetv2_pytorch/model/ghostnetv2_torch.py @@ -121,7 +121,7 @@ def forward(self, x): x1 = self.primary_conv(x) x2 = self.cheap_operation(x1) out = torch.cat([x1,x2], dim=1) - return out[:,:self.oup,:,:]*F.interpolate(self.gate_fn(res),size=out.shape[-1],mode='nearest') + return out[:,:self.oup,:,:]*F.interpolate(self.gate_fn(res),size=(out.shape[-2],out.shape[-1]),mode='nearest') class GhostBottleneckV2(nn.Module):