Pad conv2d channel dim weight when 5 < CI < 8 (#849)
Summary:
Pull Request resolved: #849

CUDA conv channel dim weights need to be aligned to a multiple of 2/4/8. If CI < 4, pad to 4; if 4 < CI < 8, pad to 8.
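
A minimal sketch of that rule as a standalone helper (the function is hypothetical, for illustration only; the real logic lives in `_choose_conv2d_op` in the diff below):

```python
def padded_channel_dim(ci: int) -> int:
    """Channel dim after padding, per the 2/4/8 alignment rule above."""
    if ci < 4:
        return 4  # pad small channel dims up to 4
    if 4 < ci < 8:
        return 8  # pad 5, 6, or 7 up to 8 (this commit's change)
    if ci % 2 != 0:
        raise RuntimeError(
            f"Conv2d is not implemented for input channel dim {ci}: "
            "it needs to be aligned to a multiple of 2/4/8"
        )
    return ci  # already 4, or an even value >= 8: no padding needed


assert [padded_channel_dim(c) for c in (1, 3, 4, 6, 7, 8, 10)] == [4, 4, 4, 8, 8, 8, 10]
```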

Reviewed By: henryhu6

Differential Revision: D47776430

fbshipit-source-id: cca790f9ff01651f85bbd78d30ca8eb8840c307b
Colin Chan authored and facebook-github-bot committed Jul 27, 2023
1 parent 065aef4 commit fed47a6
Showing 2 changed files with 7 additions and 2 deletions.
3 changes: 3 additions & 0 deletions fx2ait/fx2ait/converters/ait_converters.py
```diff
@@ -1250,6 +1250,9 @@ def _choose_conv2d_op(
     if last_dim < 4:
         weight = pad_last_dim(len(weight._attrs["shape"]), 4)(weight)
         x = pad_last_dim(len(x._attrs["shape"]), 4)(x)
+    elif last_dim > 4 and last_dim < 8:
+        weight = pad_last_dim(len(weight._attrs["shape"]), 8)(weight)
+        x = pad_last_dim(len(x._attrs["shape"]), 8)(x)
     elif last_dim % 2 != 0:
         return RuntimeError(
             f"Conv2d is not implemented for input channel dim {last_dim}: it needs to be aligned to a multiple of 2/4/8"
```
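
The `pad_last_dim` calls above pad the trailing (channel) dimension of both the weight and the input to the chosen alignment. A rough PyTorch equivalent of that effect, assuming zero-fill and NHWC layout (`pad_to` is a hypothetical helper, not the AIT op):

```python
import torch
import torch.nn.functional as F


def pad_to(t: torch.Tensor, align: int) -> torch.Tensor:
    # Zero-fill the last (channel) dim on the right up to `align`.
    extra = align - t.shape[-1]
    return F.pad(t, (0, extra)) if extra > 0 else t


x = torch.randn(1, 224, 224, 7)  # CI = 7 hits the new 4 < CI < 8 branch
print(pad_to(x, 8).shape)        # torch.Size([1, 224, 224, 8])
```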
6 changes: 4 additions & 2 deletions fx2ait/fx2ait/test/converters/test_ait_conv2d.py
```diff
@@ -30,6 +30,7 @@ class TestConv2dConverter(AITTestCase):
             param("non_unary_params", 3, 2, padding=1, bias=False),
             param("dilation", 1, dilation=2),
             param("multi_group", 1, 1, 1, 1, 3, bias=True),
+            param("in_channel_padding_gt_4_lt_8", 1, in_channel=7),
         ]
     )
     def test_conv2d(
@@ -40,21 +41,22 @@ def test_conv2d(
         padding=0,
         dilation=1,
         groups=1,
+        in_channel=3,
         bias=True,
     ):
         class TestModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.conv = torch.nn.Conv2d(
-                    3, 36, kernel_size, stride, padding, dilation, groups, bias
+                    in_channel, 36, kernel_size, stride, padding, dilation, groups, bias
                 )
                 self.relu = torch.nn.ReLU()
 
             def forward(self, x):
                 return self.relu(self.conv(x))
 
         model = TestModule().cuda().half()
-        inputs = [torch.randn(1, 3, 224, 224).cuda().half()]
+        inputs = [torch.randn(1, in_channel, 224, 224).cuda().half()]
         self.run_test(
             model,
             inputs,
```
