Skip to content

Commit

Permalink
Merge pull request #27 from NVlabs/dev2
Browse files Browse the repository at this point in the history
enable dynamic onnx batch size
  • Loading branch information
ahatamiz authored Jun 23, 2023
2 parents f605722 + bab8b67 commit bb42f88
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
9 changes: 4 additions & 5 deletions fastervit/models/faster_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,9 @@ def window_partition(x, window_size):
return windows


def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], H, W)
def window_reverse(windows, window_size, H, W, B):
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, windows.shape[2], H, W)
return x


Expand Down Expand Up @@ -851,7 +850,7 @@ def forward(self, x):
for bn, blk in enumerate(self.blocks):
x, ct = blk(x, ct)
if self.transformer_block:
x = window_reverse(x, self.window_size, H, W)
x = window_reverse(x, self.window_size, H, W, B)
if self.downsample is None:
return x
return self.downsample(x)
Expand Down
9 changes: 4 additions & 5 deletions fastervit/models/faster_vit_any_res.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,9 @@ def window_partition(x, window_size):
return windows


def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], H, W)
def window_reverse(windows, window_size, H, W, B):
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, windows.shape[2], H, W)
return x


Expand Down Expand Up @@ -874,7 +873,7 @@ def forward(self, x):
x, ct = blk(x, ct)
if self.transformer_block:
x = window_reverse(x, self
.window_size, Hp, Wp)
.window_size, Hp, Wp, B)
if pad_r > 0 or pad_b > 0:
x = x[:, :, :H, :W].contiguous()
if self.downsample is None:
Expand Down

0 comments on commit bb42f88

Please sign in to comment.