
Something weird between Tensorflow-offline-wpe and numpy-offline-wpe #33

Open

YongyuG opened this issue Sep 5, 2019 · 6 comments

@YongyuG

YongyuG commented Sep 5, 2019

Hi, thanks for your work on nara_wpe. I learned quite a lot from your implementation and your paper.

I tried to integrate the TensorFlow offline WPE into my ASR system.
However, the time spent on a 3.5 s audio file with the TF offline WPE is ~7 s, while the NumPy version only takes ~200 ms.

I ran the TF offline WPE on the GPU. All I did was run the WPE dereverberation processing for all the audio files inside one tf.Session,
so my code is something like

with tf.Session(config=config) as session:
    with tf.device('/gpu:0'):
        <wpe for all the audio file in test_data_set>

But it takes more time than the NumPy version, which confuses me a lot. I expected the TF nara_wpe version to be faster than the NumPy one.

Best

@boeddeker
Member

This difference sounds too large.

Nevertheless, there are some reasons why NumPy is, or can be, faster:

  • The TensorFlow code is similar to an older NumPy implementation that we dropped because the new NumPy implementation is significantly faster.
  • The TensorFlow code cannot be ported to the new NumPy implementation, because the speed gain for NumPy comes from the np.lib.stride_tricks.as_strided function, which has no counterpart in TensorFlow (https://docs.scipy.org/doc/numpy/reference/generated/numpy.lib.stride_tricks.as_strided.html); a minimal sketch follows after this list.
  • The TensorFlow session has a start-up time and needs to copy the data from NumPy to TensorFlow and later back from TensorFlow to NumPy.
  • The CPU solve function from TensorFlow is used; the GPU version does not produce valid output in TensorFlow 1.12:
    with tf.device('/cpu:0'):
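
To make the as_strided point concrete, here is a minimal sketch (illustrative names, not the nara_wpe code) of how it creates an overlapping tap view of a signal without copying any data:

    import numpy as np

    def tap_view(x, taps):
        # x: 1-D array of length T; result: view of shape (T - taps + 1, taps)
        stride, = x.strides
        return np.lib.stride_tricks.as_strided(
            x, shape=(x.shape[0] - taps + 1, taps), strides=(stride, stride)
        )

    x = np.arange(8.0)
    print(tap_view(x, 3))
    # [[0. 1. 2.]
    #  [1. 2. 3.]
    #   ...
    #  [5. 6. 7.]]

Every row is just a shifted view into the same buffer, so building the stacked tap signal is essentially free; TensorFlow would have to materialize such windows as copies.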

Can you provide a complete toy example?
As input, you could use random numbers.
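
For example, a self-contained sketch along these lines would already help (the (frequency, channel, time) shape convention and the wpe call mirror your snippet; the import path assumes the TF variant from nara_wpe.tf_wpe, and the sizes are arbitrary):

    import numpy as np
    import tensorflow as tf
    from nara_wpe.tf_wpe import wpe

    # Random complex "STFT" input with shape (frequency, channel, time)
    F, D, T = 257, 1, 300
    Y = np.random.normal(size=(F, D, T)) + 1j * np.random.normal(size=(F, D, T))

    Y_tf = tf.placeholder(tf.complex128, shape=(None, None, None))
    Z_tf = wpe(Y_tf, taps=10, iterations=3)

    with tf.Session() as session:
        Z = session.run(Z_tf, {Y_tf: Y})
    print(Z.shape)  # should match Y.shape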

Small hint:
Be careful when you use TensorFlow 1.13 or newer. There is a bug where the conjugate is sometimes ignored (tensorflow/tensorflow#27500).

@jheymann85
Member

jheymann85 commented Sep 5, 2019

As an additional note: be careful where you call WPE (=build the graph) and where you call session.run(...). The graph should be built outside any loop.

@YongyuG
Author

YongyuG commented Sep 6, 2019

As an additional note: be careful where you call WPE (=build the graph) and where you call session.run(...). The graph should be built outside any loop.

Thanks for your reply.
However, I want to run the TensorFlow WPE processing on all audio files in a dataset directory, so I have to loop over every single audio file in that directory.

So my code is something like the following. I think maybe I should create TFRecords, or else I would have to restart the session for each single audio file:

    # (needs: os, numpy as np, soundfile as sf, tensorflow as tf,
    #  plus stft/istft and the TF wpe from nara_wpe)
    # Build the graph once, outside the loop over files.
    Y_tf = tf.placeholder(tf.complex128, shape=(None, None, None))
    Z_tf = wpe(Y_tf, taps=taps, iterations=iterations)
    with tf.Session() as session:
        for root, dirs, files in os.walk(dataPath):
            for f in files:
                if not f.endswith('wav'):
                    continue
                in_file_path = os.path.join(root, f)
                old_dir_name = root.split('data')[1].split('/')[1]
                new_root = root.replace(old_dir_name, outDirName)
                out_file_path = os.path.join(new_root, f)
                if not os.path.exists(new_root):
                    os.makedirs(new_root)
                data, sfr = sf.read(in_file_path)
                y = np.reshape(data, (1, data.shape[0]))
                # STFT -> (frequency, channel, time), dereverberate, back to time domain
                Y = stft(y, **stft_options).transpose(2, 0, 1)
                Z = session.run(Z_tf, {Y_tf: Y})
                z = istft(Z.transpose(1, 2, 0), size=512, shift=128)
                sf.write(out_file_path, z[0], sfr)  # write the dereverberated audio

Best

@YongyuG
Author

YongyuG commented Sep 6, 2019

This difference sounds too large.

Nevertheless, there are some reasons why NumPy is, or can be, faster:

* The TensorFlow code is similar to an older NumPy implementation that we dropped because the new NumPy implementation is significantly faster.

* The TensorFlow code cannot be ported to the new NumPy implementation, because the speed gain for NumPy comes from the `np.lib.stride_tricks.as_strided` function, which has no counterpart in TensorFlow (https://docs.scipy.org/doc/numpy/reference/generated/numpy.lib.stride_tricks.as_strided.html).

* The TensorFlow session has a start-up time and needs to copy the data from NumPy to TensorFlow and later back from TensorFlow to NumPy.

* The CPU `solve` function from TensorFlow is used; the GPU version does not produce valid output in TensorFlow 1.12. https://github.com/fgnt/nara_wpe/blob/89db94e4938d4afc68736022819852e1bbea95e2/nara_wpe/tf_wpe.py#L256

Can you provide a complete toy example?
As input, you could use random numbers.

Small hint:
Be careful when you use TensorFlow 1.13 or newer. There is a bug where the conjugate is sometimes ignored (tensorflow/tensorflow#27500).

Thanks for your explanations and details.
My TensorFlow version is 1.13. Did you mean the conjugate computation while calculating R is ignored?
I want to loop over each audio file in a dataset folder.
Moreover, I tried restarting the session for each audio file, which makes the processing even slower.
My example is below, and I run the TF WPE processing on single-channel audio:

    # (needs: os, numpy as np, soundfile as sf, tensorflow as tf,
    #  plus stft/istft and the TF wpe from nara_wpe)
    # Build the graph once, outside the loop over files.
    Y_tf = tf.placeholder(tf.complex128, shape=(None, None, None))
    Z_tf = wpe(Y_tf, taps=taps, iterations=iterations)
    with tf.Session() as session:
        for root, dirs, files in os.walk(dataPath):
            for f in files:
                if not f.endswith('wav'):
                    continue
                in_file_path = os.path.join(root, f)
                old_dir_name = root.split('data')[1].split('/')[1]
                new_root = root.replace(old_dir_name, outDirName)
                out_file_path = os.path.join(new_root, f)
                if not os.path.exists(new_root):
                    os.makedirs(new_root)
                data, sfr = sf.read(in_file_path)
                y = np.reshape(data, (1, data.shape[0]))
                # STFT -> (frequency, channel, time), dereverberate, back to time domain
                Y = stft(y, **stft_options).transpose(2, 0, 1)
                Z = session.run(Z_tf, {Y_tf: Y})
                z = istft(Z.transpose(1, 2, 0), size=512, shift=128)
                sf.write(out_file_path, z[0], sfr)  # write the dereverberated audio

Best

@YongyuG
Author

YongyuG commented Sep 6, 2019

In addition, I looked into the source code of tf_wpe.
As you mentioned, the function get_filter_matrix_conj from tf_wpe.py, which I think computes the G matrix mentioned in the paper, still runs on the CPU:

    elif mode == 'solve':
        with tf.device('/cpu:0'):
            stacked_filter_conj = tf.reshape(
                tf.matrix_solve(
                    tf.tile(correlation_matrix[None, ...], [D, 1, 1]),
                    tf.reshape(correlation_vector, (D, D * taps, 1))
                ),
                (D * D * taps, 1)
            )

And unlike the NumPy implementation, tf_wpe runs the WPE processing per frequency bin, looping over all frequency bins, and I think that is the reason why it performs slower.

So if I really want to make the WPE processing faster (for 3.5 s of audio I would expect about 100 ms), do I need to re-implement the NumPy version in C, or can I also extend tf_wpe.py so that some matrix computations run on the GPU and some sequential computations on the CPU?

Best

@boeddeker
Member

My TensorFlow version is 1.13. Did you mean the conjugate computation while calculating R is ignored?

This can happen. The problem is that TensorFlow performs some graph optimizations in which a conjugate may be ignored. So the solution can be correct or incorrect.

Moreover, I tried restarting the session for each audio file, which makes the processing even slower.

That's what jheymann85 explained. Your code looks correct: one session, and you build the graph only once.

And unlike the NumPy implementation, tf_wpe runs the WPE processing per frequency bin, looping over all frequency bins, and I think that is the reason why it performs slower.

For WPE this shouldn't be a problem. For the NumPy version we also have a "batched" version (wpe_v7) and a loopy/per-frequency version (wpe_v8). Counter-intuitively, the loopy version is in many situations faster than the batched version.
Can you test whether the TF version is faster on your CPU than on your GPU?
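
For instance, a rough sketch like this (it reuses the random Y and the wpe import from the toy example above; timings are only indicative):

    import time

    # Same graph as before, but pinned to the CPU; time one session.run call.
    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            Y_tf = tf.placeholder(tf.complex128, shape=(None, None, None))
            Z_tf = wpe(Y_tf, taps=10, iterations=3)
        with tf.Session() as session:
            start = time.time()
            session.run(Z_tf, {Y_tf: Y})
            print('CPU wall time:', time.time() - start)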

So if I really want to make the WPE processing faster (for 3.5 s of audio I would expect about 100 ms), do I need to re-implement the NumPy version in C, or can I also extend tf_wpe.py so that some matrix computations run on the GPU and some sequential computations on the CPU?

Looking at your example code, it would also be possible to simply use the NumPy WPE code.
Reprogramming the NumPy version in C will be difficult, because we use manipulated strides.
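
I.e., inside your existing loop, the session.run call could be replaced by something like this sketch (here wpe is the NumPy variant from nara_wpe.wpe; taps and iterations mirror your TF call):

    from nara_wpe.wpe import wpe as np_wpe

    # Y has shape (frequency, channel, time), exactly as in your script
    Z = np_wpe(Y, taps=taps, iterations=iterations)
    z = istft(Z.transpose(1, 2, 0), size=512, shift=128)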
