Extremely slow processing on CPU #98
I think this is an issue related to the tf.signal FFT implementation. It seems to be using only a single CPU core, and it's extremely slow. Can we do anything to improve it?

PS: Thank you for your awesome work!
Hi 😊 thanks for the issue and the context. One thing we could do is wrap numpy's FFT function and use it selectively based on the hardware, but that doesn't seem very simple.
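For illustration only, here is a minimal sketch of that idea, assuming the real FFT is all we need and that "no visible GPU" is a good enough proxy for running on CPU. The helper names (`adaptive_rfft`, `_numpy_rfft`) are hypothetical, not anything in the library, and note that `tf.py_function` loses static shape information, which is part of why this isn't simple:

```python
import numpy as np
import tensorflow as tf

def _numpy_rfft(framed_signals, fft_length):
    # Runs eagerly inside tf.py_function, so .numpy() is available here.
    return np.fft.rfft(framed_signals.numpy(), n=int(fft_length), axis=-1).astype(np.complex64)

def adaptive_rfft(framed_signals, fft_length):
    # On GPU, keep the native tf.signal implementation.
    if tf.config.list_physical_devices('GPU'):
        return tf.signal.rfft(framed_signals, fft_length=[fft_length])
    # On CPU, fall back to numpy's FFT, which tends to be much faster
    # than TF's CPU kernel (at the cost of losing static shapes).
    return tf.py_function(
        func=lambda x: _numpy_rfft(x, fft_length),
        inp=[framed_signals],
        Tout=tf.complex64,
    )
```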
That's why I closed pull request #72 at the time; I would recommend using larger batches instead.
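For illustration (not from the thread, and the shapes are made up): batching amortizes the per-call overhead, so one call over a whole batch is typically much faster on CPU than many single-item calls.

```python
import numpy as np
import tensorflow as tf

# 64 one-second clips at 22.05 kHz, processed in a single call
# rather than 64 separate single-item calls.
signals = np.random.randn(64, 22050).astype(np.float32)
spectrograms = tf.signal.stft(signals, frame_length=2048, frame_step=512)
```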
Hi there! If you want, we can use the workaround that @zaccharieramzi found: it speeds up the current implementation by up to 10x, at least for my use case, until TensorFlow gives us a better approach. To do so, we have to include its implementation somewhere in the code:

```python
import multiprocessing
from functools import partial

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops.signal import shape_ops, fft_ops, window_ops, spectral_ops


def parallel_stft(signals, frame_length, frame_step, fft_length=None,
                  window_fn=window_ops.hann_window,
                  pad_end=False, name=None):
    with ops.name_scope(name, 'stft', [signals, frame_length, frame_step]):
        signals = ops.convert_to_tensor(signals, name='signals')
        signals.shape.with_rank_at_least(1)
        frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
        frame_length.shape.assert_has_rank(0)
        frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
        frame_step.shape.assert_has_rank(0)

        if fft_length is None:
            fft_length = spectral_ops._enclosing_power_of_two(frame_length)
        else:
            fft_length = ops.convert_to_tensor(fft_length, name='fft_length')

        framed_signals = shape_ops.frame(
            signals, frame_length, frame_step, pad_end=pad_end)

        if window_fn is not None:
            window = window_fn(frame_length, dtype=framed_signals.dtype)
            framed_signals *= window

        # Map the FFT over the leading axis so the per-item FFTs can run
        # in parallel across CPU cores instead of on a single one.
        return tf.map_fn(
            partial(fft_ops.rfft, fft_length=[fft_length]),
            framed_signals,
            fn_output_signature=tf.complex64,
            parallel_iterations=multiprocessing.cpu_count(),  # or however many parallel ops you see fit
        )
```

After that, change the call in the STFT layer from tf.signal.stft to parallel_stft.
@JPery Doesn't sound like a bad idea. To maximize its utility, we'd want this to work (i) automatically when it's running on CPU, (ii) without any complex configuration, and (iii) when there is more than one item in the batch, maybe inside the existing STFT layer. Or, a conservative approach is to create another layer, maybe one inherited from STFT; see the sketch below. Either way, we need carefully tested code, but I don't think I'll have any time to work on this for at least 1-2 months. I'd love to review a PR :)
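For illustration only, here is a minimal sketch of that conservative approach as a plain Keras layer delegating to the `parallel_stft` workaround above. The class name and constructor defaults are hypothetical, not the library's actual API:

```python
import tensorflow as tf

class ParallelSTFTLayer(tf.keras.layers.Layer):
    """Hypothetical layer wrapping the map_fn-based parallel_stft above."""

    def __init__(self, frame_length=2048, frame_step=512, fft_length=None, **kwargs):
        super().__init__(**kwargs)
        self.frame_length = frame_length
        self.frame_step = frame_step
        self.fft_length = fft_length

    def call(self, signals):
        # signals: (batch, time) float32 waveforms.
        return parallel_stft(
            signals, self.frame_length, self.frame_step, self.fft_length)
```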