Add istft operation #2029

Open · wants to merge 24 commits into base: main
26 changes: 25 additions & 1 deletion coremltools/converters/mil/frontend/torch/ops.py
@@ -6362,6 +6362,7 @@ def stft(context, node):
    Lowers torch.stft with the dialect op `complex_stft` from complex_dialect_ops.py
    """
    input_data, n_fft, hop_length, win_length, window, normalized, onesided, _ = _get_inputs(context, node, min_expected=2)

    if types.is_complex(input_data.dtype):
        onesided = False  # pytorch defaults onesided to False for complex inputs
    stft_res = mb.complex_stft(
@@ -6371,9 +6372,32 @@
        win_length=win_length,
        window=window,
        normalized=normalized,
        onesided=onesided)
        onesided=onesided
    )
    context.add(stft_res, node.name)

@register_torch_op
def istft(context, node):
    """
    Lowers torch.istft with the dialect op `complex_istft` from complex_dialect_ops.py
    """
    input_data, n_fft, hop_length, win_length, window, center, normalized, onesided, length, _ = _get_inputs(context, node, min_expected=2)

    if types.is_complex(input_data.dtype):
        onesided = False  # pytorch defaults onesided to False for complex inputs
    istft_res = mb.complex_istft(
        input=input_data,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
        length=length,
    )
    context.add(istft_res, node.name)

@register_torch_op(torch_alias=["torchvision::nms"])
def torchvision_nms(context, node):
    inputs = _get_inputs(context, node, expected=3)
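For context, the lowering above must match what torch.istft itself produces. A minimal stft/istft round trip in PyTorch (illustrative values only, not part of this diff):

import torch

# Round trip: istft inverts stft when the window satisfies the
# nonzero-overlap-add (NOLA) constraint, as hann_window does here.
n_fft, hop_length = 16, 4
signal = torch.randn(1, 64)
window = torch.hann_window(n_fft)

spec = torch.stft(signal, n_fft=n_fft, hop_length=hop_length,
                  window=window, center=True, return_complex=True)
recon = torch.istft(spec, n_fft=n_fft, hop_length=hop_length,
                    window=window, center=True, length=signal.shape[-1])
assert recon.shape == signal.shape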
86 changes: 83 additions & 3 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -9540,8 +9540,8 @@ def forward(self, x):
            (2, 3, 4), FftnModel(), backend=backend, compute_unit=compute_unit
        )


class TestSTFT(TorchBaseTest):
    @pytest.mark.slow
    @pytest.mark.parametrize(
        "compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided",
        itertools.product(
@@ -9566,9 +9566,8 @@ def test_stft(self, compute_unit, backend, input_shape, complex, n_fft, hop_leng
        class STFTModel(torch.nn.Module):
            def forward(self, x):
                applied_window = window(win_length) if window and win_length else None
                x = torch.complex(x, x) if complex else x
                x = torch.stft(
                    x,
                    torch.complex(x, x) if complex else x,
                    n_fft=n_fft,
                    hop_length=hop_length,
                    win_length=win_length,
@@ -9588,6 +9587,87 @@ def forward(self, x):
            compute_unit=compute_unit
        )

    @pytest.mark.parametrize(
        "compute_unit, backend, channels, n_fft, num_frames, hop_length, win_length, window, center, normalized, onesided, length, return_complex",
        itertools.product(
            compute_units,
            backends,
            [None, 1, 3],  # channels
            [16, 32],  # n_fft
            [5, 9],  # num_frames
            [None, 5],  # hop_length
            [None, 10, 8],  # win_length
            [None, torch.hann_window],  # window
            [False, True],  # center
            [False, True],  # normalized
            [None, False, True],  # onesided
            [None, "shorter", "larger"],  # length
            [False, True],  # return_complex
        )
    )
    def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_length, win_length, window, center, normalized, onesided, length, return_complex):
        if return_complex and onesided:
            pytest.skip("Complex output is incompatible with onesided")

        if hop_length is None and win_length is not None:
            pytest.skip("If win_length is set, hop_length must also be set, with 0 < hop_length <= win_length")

        # Compute input_shape to generate the test case
        freq = n_fft // 2 + 1 if onesided else n_fft
        input_shape = (channels, freq, num_frames) if channels else (freq, num_frames)

        # If not set, compute the default hop_length so the length cases below can use it
        if hop_length is None:
            hop_length = n_fft // 4

        if length == "shorter":
            length = n_fft // 2 + hop_length * (num_frames - 1)
        elif length == "larger":
            length = n_fft * 3 // 2 + hop_length * (num_frames - 1)

        class ISTFTModel(torch.nn.Module):
            def forward(self, x):
                applied_window = window(win_length) if window and win_length else None
                x = torch.istft(
                    torch.complex(x, x),
                    n_fft=n_fft,
                    hop_length=hop_length,
                    win_length=win_length,
                    window=applied_window,
                    center=center,
                    normalized=normalized,
                    onesided=onesided,
                    length=length,
                    return_complex=return_complex)
                if return_complex:
                    return torch.stack([torch.real(x), torch.imag(x)], dim=0)
                else:
                    return torch.real(x)

        if (center is False and win_length) or (center and win_length and length):
            # PyTorch raises a window-overlap error for these configurations; see
            # https://github.com/pytorch/audio/issues/427#issuecomment-1829593033
            with pytest.raises(RuntimeError, match=r"istft\(.*\) window overlap add min: 1"):
                TorchBaseTest.run_compare_torch(
                    input_shape,
                    ISTFTModel(),
                    backend=backend,
                    compute_unit=compute_unit
                )
        elif length and return_complex:
            with pytest.raises(ValueError, match="New var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>` not a subtype of existing var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>`"):
                TorchBaseTest.run_compare_torch(
                    input_shape,
                    ISTFTModel(),
                    backend=backend,
                    compute_unit=compute_unit
                )
        else:
            TorchBaseTest.run_compare_torch(
                input_shape,
                ISTFTModel(),
                backend=backend,
                compute_unit=compute_unit
            )

if _HAS_TORCH_AUDIO:

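A note on the "shorter" and "larger" cases in the test above: they are chosen to bracket the default overlap-add output length n_fft + hop_length * (num_frames - 1). A quick check with one set of values from the parametrization (n_fft=16, num_frames=5, hop_length left to default):

n_fft, num_frames = 16, 5
hop_length = n_fft // 4  # the default applied when hop_length is None

default = n_fft + hop_length * (num_frames - 1)          # 32
shorter = n_fft // 2 + hop_length * (num_frames - 1)     # 24
larger = n_fft * 3 // 2 + hop_length * (num_frames - 1)  # 40
assert shorter < default < larger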
77 changes: 77 additions & 0 deletions coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py
@@ -861,3 +861,80 @@ def type_inference(self):

        return types.tensor(output_type, tuple(output_shape))

@register_op(namespace="complex")
class complex_istft(Operation):
    """
    Dialect op for 1-D ISTFT.

    Parameters
    ----------
    input: tensor<\*V, complex64> (Required)
        * The complex-valued STFT coefficients to invert, with frames along the last dimension.
    n_fft: const i32 (Required)
        * Size of the Fourier transform.
    hop_length: const i32 (Optional)
        * Stride between window frames of the input tensor.
    win_length: const i32 (Optional)
        * The size of the window frame.
    window: tensor<1, win_length> (Optional)
        * The window to apply when reconstructing the signal from its frames.
    center: const bool (Required)
        * Whether the STFT was computed with centered (padded) frames.
    normalized: const bool (Optional, Default=``false``)
        * Whether the STFT was normalized.
    onesided: const bool (Optional, Default=``true``)
        * Whether the STFT was onesided.
    length: const i32 (Optional)
        * Fixed output length; the result is trimmed or zero-padded to this size.
    return_complex: const bool (Optional, Default=``true``)
        * Whether the output is complex-valued (``complex64``) or real-valued (``fp32``).

    Returns
    -------
    tensor<\*D, T>
        * The reconstructed signal.

    Attributes
    ----------
    T: fp32, complex64

    References
    ----------
    See `torch.istft <https://pytorch.org/docs/2.0/generated/torch.istft.html>`_.
    """

    input_spec = InputSpec(
        input=TensorInputType(type_domain=types.complex),
        n_fft=TensorInputType(const=True, type_domain=types.int32),
        hop_length=TensorInputType(const=True, optional=True, type_domain=types.int32),
        win_length=TensorInputType(const=True, optional=True, type_domain=types.int32),
        window=TensorInputType(const=True, optional=True, type_domain=types.fp32),
        center=TensorInputType(const=True, type_domain=types.bool),
        normalized=TensorInputType(const=True, optional=True, type_domain=types.bool),
        onesided=TensorInputType(const=True, optional=True, type_domain=types.bool),
        length=TensorInputType(const=True, optional=True, type_domain=types.int32),
        return_complex=TensorInputType(const=True, optional=True, type_domain=types.bool),
    )

    def default_inputs(self):
        return DefaultInputs(
            hop_length=None,
            win_length=None,
            window=None,
            normalized=False,
            onesided=True,
            length=None,
            return_complex=True,
        )

    def type_inference(self):
        output_type = types.complex64 if self.return_complex else types.fp32

        # Keep the batch dimension when the input is rank 3; otherwise default to 1.
        output_shape = [self.input.shape[0] if self.input.rank == 3 else 1]

        if self.length:
            output_shape += [self.length]
        else:
            n_frames = self.input.shape[-1]
            hop_length = self.hop_length.val if self.hop_length else self.n_fft.val // 4
            output_shape += [self.n_fft.val + hop_length * (n_frames - 1)]

        return types.tensor(output_type, tuple(output_shape))
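As a sanity check of the shape logic in type_inference above, the same computation in plain Python for one concrete, hypothetical case:

# Mirrors complex_istft.type_inference: batched input of shape
# (batch=3, freq=9, n_frames=5), n_fft=16, hop_length and length unset.
input_shape = (3, 9, 5)
n_fft, hop_length, length = 16, None, None

output_shape = [input_shape[0] if len(input_shape) == 3 else 1]
if length:
    output_shape += [length]
else:
    n_frames = input_shape[-1]
    hop = hop_length if hop_length else n_fft // 4
    output_shape += [n_fft + hop * (n_frames - 1)]  # 16 + 4 * 4 = 32

assert tuple(output_shape) == (3, 32)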