PoC for reading cuts in background thread in dynamic bucketing #680

Open
wants to merge 6 commits into master

Conversation

pzelasko
Collaborator

@pzelasko pzelasko commented Apr 19, 2022

@danpovey it may address the issue described in #678, but I haven't tested it beyond running the unit tests successfully. I added a background thread for collect_cuts_in_buckets. Threading should be sufficient, as I expect the main process CPU to be mostly idle during forward passes on the GPU. This implementation should be stable, but I don't think it covers every possible multithreading hazard. I might turn this into a full-blown thread-safe queue if you can confirm it helps with the training speed (or ends up in a deadlock when run in real training...)

EDIT: I'm not even 100% sure that the mutex is needed at all...
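For context, a minimal sketch of the idea (illustrative only; the class and method names below are made up and the real DynamicBucketingSampler internals differ): a single-worker ThreadPoolExecutor fills the per-duration bucket deques in the background while the main thread keeps training, with a lock guarding the shared deques.

from collections import deque
from concurrent.futures import ThreadPoolExecutor
from threading import Lock

class _BackgroundBucketFiller:
    # Illustrative only: fill bucket deques from a single background thread.
    def __init__(self, cuts_iter, num_buckets):
        self.cuts_iter = cuts_iter
        self.buckets = [deque() for _ in range(num_buckets)]
        self.lock = Lock()  # guards concurrent access to the bucket deques
        self.pool = ThreadPoolExecutor(max_workers=1)  # mirrors ThreadPoolExecutor(1) in the diff below

    def collect_async(self, n_cuts):
        # Schedule bucket filling without blocking the caller.
        return self.pool.submit(self._collect, n_cuts)

    def _collect(self, n_cuts):
        for _ in range(n_cuts):
            cut = next(self.cuts_iter, None)
            if cut is None:
                break  # source CutSet exhausted
            with self.lock:
                self.buckets[self._find_bucket(cut)].append(cut)

    def _find_bucket(self, cut):
        # Placeholder: the real sampler picks the bucket from the cut's duration.
        return 0

(In CPython, single append/popleft calls on a deque are atomic, which is presumably why the mutex may not strictly be needed; a lock still matters for compound check-then-pop sequences.)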

@csukuangfj
Contributor

Thanks! I will test it and post the training time with and without this PR.

@csukuangfj
Contributor

It throws the following exception:

2022-04-19 16:48:07,508 INFO [train.py:1069] (2/8) Sanity check -- see if any of the batches in epoch 0 would cause OOM.
2022-04-19 16:48:07,611 INFO [asr_datamodule.py:266] (6/8) About to create dev dataset
2022-04-19 16:48:07,623 INFO [asr_datamodule.py:285] (6/8) About to create dev dataloader
2022-04-19 16:48:07,624 INFO [train.py:1069] (6/8) Sanity check -- see if any of the batches in epoch 0 would cause OOM.
2022-04-19 16:48:24,925 INFO [train.py:1010] (0/8) Loading grad scaler state dict
2022-04-19 16:48:24,928 INFO [train.py:1010] (3/8) Loading grad scaler state dict
2022-04-19 16:48:24,929 INFO [train.py:1010] (6/8) Loading grad scaler state dict
2022-04-19 16:48:24,930 INFO [train.py:1010] (7/8) Loading grad scaler state dict
2022-04-19 16:48:24,930 INFO [train.py:1010] (2/8) Loading grad scaler state dict
2022-04-19 16:48:24,932 INFO [train.py:1010] (4/8) Loading grad scaler state dict
2022-04-19 16:48:24,936 INFO [train.py:1010] (5/8) Loading grad scaler state dict
2022-04-19 16:48:24,977 INFO [train.py:1010] (1/8) Loading grad scaler state dict
Traceback (most recent call last):
  File "./pruned_transducer_stateless3/train.py", line 1123, in <module>
    main()
  File "./pruned_transducer_stateless3/train.py", line 1114, in main
    mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
  File "/ceph-fj/fangjun/software/py38/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/ceph-fj/fangjun/software/py38/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/ceph-fj/fangjun/software/py38/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 7 terminated with the following error:
Traceback (most recent call last):
  File "/ceph-fj/fangjun/software/py38/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/ceph-fj/fangjun/open-source-2/icefall-multi-2/egs/librispeech/ASR/pruned_transducer_stateless3/train.py", line 1023, in run
    train_one_epoch(
  File "/ceph-fj/fangjun/open-source-2/icefall-multi-2/egs/librispeech/ASR/pruned_transducer_stateless3/train.py", line 840, in train_one_epoch
    loss_value = tot_loss["loss"] / tot_loss["frames"]
ZeroDivisionError: division by zero

@@ -284,6 +286,9 @@ def __init__(
deque() for _ in range(len(duration_bins) + 1)
]

self._cut_reading_thread = ThreadPoolExecutor(1)
Contributor

Any reason not to use a process pool? Due to the global interpreter lock, only one thread can execute Python code at any given time, I think.

Collaborator Author

Yes, with some setups that use IterableDatasetWrapper you are placing the sampler in a dataloader worker process, and AFAIK you can't spawn a nested process pool there because that process is daemonic.

Anyway, a thread should be sufficient here, as I expect the CPU to be mostly idle when running forward and backward passes on GPUs... The reason it didn't work for you is likely that the thread could not populate the buckets fast enough and the sampler thought they were depleted (a race condition). This can be solved with a proper synchronization mechanism, but unfortunately I don't have the time to add it right now. I'll return to it sometime.
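Such a synchronization mechanism could look roughly like this (a sketch only, not code from this PR; the _BucketState name is made up): the consumer blocks on a condition variable until the background thread has either added cuts or signaled that the source is exhausted, so the buckets never look depleted just because the filler is lagging behind.

from threading import Condition

class _BucketState:
    # Sketch of a producer/consumer handshake that avoids the race described above.
    def __init__(self):
        self.cond = Condition()
        self.num_available = 0
        self.exhausted = False  # set once the source CutSet runs out

    # Producer side (background filling thread).
    def added(self, n=1):
        with self.cond:
            self.num_available += n
            self.cond.notify_all()

    def finished(self):
        with self.cond:
            self.exhausted = True
            self.cond.notify_all()

    # Consumer side (the sampler's __iter__).
    def wait_for_cut(self):
        # Returns True once a cut is available, False only when truly depleted.
        with self.cond:
            while self.num_available == 0 and not self.exhausted:
                self.cond.wait()
            if self.num_available > 0:
                self.num_available -= 1
                return True
            return False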

@danpovey
Collaborator

So do we know how the num-frames could be zero?
I don't know how the forward() could have succeeded if the tensors were empty.

@pzelasko
Collaborator Author

I suppose the sampler yielded an empty CutSet and the dataset somehow didn't crash and collated an empty tensor. It can be fixed with proper synchronization between threads, but to check quickly whether it works, it might be enough to put time.sleep(5) after this line of code, so that the buckets have time to be populated at the start of __iter__ before they are consumed:

self._collect_cuts_in_buckets(self.buffer_size)
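i.e., a quick-and-dirty check (a sketch, not a proper fix) would add the sleep right after that call inside __iter__:

import time

self._collect_cuts_in_buckets(self.buffer_size)  # kick off the initial bucket fill
time.sleep(5)  # crude head start for the background thread; a stand-in for real synchronization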

@pzelasko
Collaborator Author

... when I have more time again, I'll take care of it and test it end-to-end.

@pzelasko
Collaborator Author

@danpovey @csukuangfj please check if it is faster now (I checked that it does synchronize correctly with the latest changes). In quick local testing I could not see any difference, but maybe you will notice some in your setup.

@csukuangfj
Contributor

@danpovey @csukuangfj please check if it is faster now (I checked that it does synchronize correctly with the latest changes). In quick local testing I could not see any difference, but maybe you will notice some in your setup.

I will test it when we have free GPUs.

@SongLi89

SongLi89 commented Feb 6, 2024

Hi, has this PR been merged? I may have a similar problem: reading the cuts is not very fast.

@pzelasko
Collaborator Author

pzelasko commented Feb 9, 2024

No, it hasn’t been merged — I didn’t find any difference with this implementation in quick testing. Can you describe your environment a bit more? What’s your sampler, max_duration, num_workers, data size, are you reading audio or features, etc. Also I recommend running py-spy on your script (or dataloading worker processes) to understand where the time is being spent.
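For reference, typical py-spy invocations against a running training or dataloader worker process look like this (assuming py-spy is installed and <pid> is the process of interest):

pip install py-spy
py-spy dump --pid <pid>                    # one-off stack snapshot
py-spy top --pid <pid>                     # live, top-like view of where time is spent
py-spy record -o profile.svg --pid <pid>   # flame graph recorded over a time window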
