Enable autograd graph to propagate after multi-device syncing for loss functions in ddp #2754

Merged (35 commits) on Oct 31, 2024

Conversation

@cw-tan (Contributor) commented Sep 17, 2024

What does this PR do?

Fixes #2745

This is the single-line enhancement proposed in #2745: enable the autograd graph to propagate after the all_gather operation. This is useful for syncing loss functions in a DDP setting.
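
For illustration, a minimal sketch of the idea for the equal-shape case (not the exact torchmetrics implementation; it assumes the process group is already initialized):

import torch
import torch.distributed as dist

def all_gather_with_grad(tensor: torch.Tensor, group=None) -> list:
    """Gather a tensor from every rank while keeping the local autograd graph."""
    world_size = dist.get_world_size(group)
    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor, group=group)
    # all_gather returns detached copies; re-inserting the original tensor into its
    # own rank's slot lets gradients flow through the local contribution
    gathered[dist.get_rank(group)] = tensor
    return gathered

A loss computed from the gathered list, e.g. torch.stack(gathered).sum(), can then be backpropagated on every rank.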

Before submitting
  • Was this discussed/agreed via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃


📚 Documentation preview 📚: https://torchmetrics--2754.org.readthedocs.build/en/2754/

@Borda (Member) commented Sep 17, 2024

That sounds good to me, but can we add a test for this enhancement?

@cw-tan (Contributor, Author) commented Sep 17, 2024

That sounds good to me, but can we add a test for this enhancement?

Thanks for the prompt response @Borda.

I'm thinking that _test_ddp_gather_uneven_tensors (here) and _test_ddp_gather_uneven_tensors_multidim (here) in tests/unittests/bases/test_ddp.py already cover the correctness of gather_all_tensors. I'm not sure what other ddp tests there are, but those tests should tell us whether my change breaks existing functionality. Let me know if you had something else in mind.

I can add a unittest in tests/unittests/bases/test_ddp.py that passes a tensor with requires_grad to gather_all_tensors, computes a scalar from the gathered tensors (as a proxy for a loss), and then computes grads two ways (one going through the all_gather, one that doesn't) and compares them. That would test that the change achieves the desired effect. How does that sound?
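
Roughly along these lines (illustrative sketch with a hypothetical test name; the final test may differ):

import torch

from torchmetrics.utilities.distributed import gather_all_tensors

def _test_ddp_gather_autograd(rank: int, worldsize: int) -> None:
    # assumes setup_ddp has already initialized the process group for this rank
    tensor = torch.full((2,), float(rank + 1), requires_grad=True)

    # gradient computed through the all_gather (only the local slot keeps the graph)
    gathered = gather_all_tensors(tensor)
    loss_synced = torch.stack(gathered).pow(2).sum()
    grad_synced = torch.autograd.grad(loss_synced, tensor)[0]

    # reference gradient that never goes through the collective
    tensor_ref = torch.full((2,), float(rank + 1), requires_grad=True)
    grad_ref = torch.autograd.grad(tensor_ref.pow(2).sum(), tensor_ref)[0]

    assert torch.allclose(grad_synced, grad_ref)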

codecov bot commented Sep 17, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 69%. Comparing base (abdd2c4) to head (5f29c4d).
Report is 1 commit behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #2754    +/-   ##
=======================================
- Coverage      69%     69%    -0%     
=======================================
  Files         344     330    -14     
  Lines       18824   18653   -171     
=======================================
- Hits        12971   12801   -170     
+ Misses       5853    5852     -1     

@Borda (Member) commented Sep 17, 2024

I can add a unittest in tests/unittests/bases/test_ddp.py that passes a tensor with requires_grad to gather_all_tensors, computes a scalar from the gathered tensors (as a proxy for a loss), and then computes grads two ways (one going through the all_gather, one that doesn't) and compares them. That would test that the change achieves the desired effect. How does that sound?

yeah, that sounds good to me :)

@Borda added the enhancement label Sep 17, 2024
@cw-tan force-pushed the all_gather_ad branch 4 times, most recently from 6c926d7 to 1d0dabe on September 18, 2024
@cw-tan (Contributor, Author) commented Sep 18, 2024

Update: to cover both cases, where tensors from different ranks have the same shape and where they have different shapes, the line that puts the original tensor (holding the autograd graph) back into the gathered list was added in two places in the code.

Because of the two cases, I wrote two unittests, one for each. Interestingly, both pass on 2.X stable, but on 1.X LTS the "same shape" test passes while the "different shape" test fails, and on 1.10 (oldest) the "different shape" test passes while the "same shape" test fails 😅. I'll double-check for bugs, but the actual code change is just two lines (and all other tests pass, so existing functionality still works), and the unittests are pretty short. The fact that which test passes depends on the torch version suggests this might be a torch versioning issue, maybe related to DDP behavior. Any thoughts, @Borda?
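
For reference, the different-shape branch pads to a common shape, gathers, and slices back; an illustrative sketch of where the added line goes (not the exact code in this PR):

import torch
import torch.distributed as dist
import torch.nn.functional as F

def _gather_uneven_with_grad(tensor: torch.Tensor, group=None) -> list:
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)

    # exchange shapes so every rank knows how much to un-pad
    local_size = torch.tensor(tensor.shape, device=tensor.device)
    sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(sizes, local_size, group=group)

    # pad the local tensor up to the maximum shape and gather the padded copies
    max_size = torch.stack(sizes).max(dim=0).values
    pad = []
    for dim in range(len(tensor.shape) - 1, -1, -1):
        pad += [0, int(max_size[dim] - tensor.shape[dim])]
    gathered = [torch.zeros(*max_size.tolist(), dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
    dist.all_gather(gathered, F.pad(tensor, pad), group=group)

    # slice every gathered tensor back to its original shape
    result = [g[tuple(slice(0, int(s)) for s in size)] for g, size in zip(gathered, sizes)]

    # the added line: put the original tensor (with its autograd graph) back in
    result[rank] = tensor
    return result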

@Borda (Member) commented Sep 19, 2024

I wrote two unittests, one for each. Interestingly, both pass on 2.X stable, but on 1.X LTS the "same shape" test passes while the "different shape" test fails, and on 1.10 (oldest) the "different shape" test passes while the "same shape" test fails 😅.

That is strange and worth some more investigation...
cc: @SkafteNicki

@SkafteNicki (Member) left a comment

I looked briefly into why the tests do not pass on older versions of PyTorch but could not find a reason.

I think we should only support this for PyTorch > 2.0 and then add this to the documentation.
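
If we go that route, the new tests could be gated roughly like this (illustrative; the version flag below is defined locally, not an existing torchmetrics helper):

from packaging.version import Version

import pytest
import torch

_TORCH_GREATER_EQUAL_2_0 = Version(torch.__version__) >= Version("2.0.0")

@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_0, reason="autograd through all_gather only supported for torch>=2.0")
def test_gather_all_tensors_autograd():
    ...  # run the ddp autograd test only on supported torch versions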

Review comments on src/torchmetrics/utilities/distributed.py and tests/unittests/bases/test_ddp.py (outdated), resolved.
@cw-tan force-pushed the all_gather_ad branch 2 times, most recently from dc35370 to e693ace on October 8, 2024
@Borda requested a review from SkafteNicki on October 8, 2024
@cw-tan force-pushed the all_gather_ad branch 2 times, most recently from ce5dca1 to ffc67f6 on October 8, 2024
@SkafteNicki (Member) left a comment

Seems the two test functions are now included twice in the test_ddp.py file, please check.

Review comments on src/torchmetrics/utilities/distributed.py, resolved.
@SkafteNicki added this to the v1.4.x milestone Oct 9, 2024
@mergify bot removed the has conflicts label Oct 18, 2024
@SkafteNicki modified the milestone from v1.4.x to v1.5.x Oct 21, 2024
@Borda requested a review from baskrahmer October 25, 2024
@SkafteNicki (Member) commented

Alright, I finally sat down to understand what was going on here. The non-deterministic behavior was really strange to me, so I did a lot of debugging and realized that the order of the output of the all_gather operation was not consistent. This meant that sometimes process 0 was the first element in the gathered list, sometimes the second, e.g.

Process 0 gathered tensors: [tensor([0.]), tensor([1.])]  # rank and gathering order match
Process 1 gathered tensors: [tensor([0.]), tensor([1.])]
Process 0 gathered tensors: [tensor([1.]), tensor([0.])]  # rank and gathering order do not match
Process 1 gathered tensors: [tensor([1.]), tensor([0.])]

I found out that the reason for this is that when the setup function (which initializes the DDP group) is called in a different pool.map/pool.starmap call than the one running the test, the processes may not match the expected rank order. This is exactly what we currently do in the test code:

pool = Pool(processes=NUM_PROCESSES)
pool.starmap(setup_ddp, [(rank, NUM_PROCESSES) for rank in range(NUM_PROCESSES)])
pytest.pool = pool

and then pytest.pool.starmap is called whenever we want to run a test in DDP mode. The solution was to call the setup function inside the test function, and then everything is in the expected order. See this commit for details: 48e699b.
This has not been a problem before because we normally reduce all the states in some way, e.g. sum them, and then the order does not matter at all. Hopefully, this also means that it works regardless of PyTorch version.
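
Schematically, the fixed pattern looks like this (a self-contained sketch with simplified stand-ins for the suite's helpers, not the exact code from the commit):

import os

import torch
import torch.distributed as dist
from torch.multiprocessing import Pool

NUM_PROCESSES = 2  # stand-in for the constant used by the test suite

def setup_ddp(rank: int, world_size: int) -> None:
    # simplified stand-in for the suite's setup helper
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def _worker(rank: int, world_size: int) -> None:
    # calling setup inside the worker ties rank `rank` to this very process,
    # so the gathered list comes back in rank order on every run
    setup_ddp(rank, world_size)
    tensor = torch.tensor([float(rank)])
    gathered = [torch.zeros(1) for _ in range(world_size)]
    dist.all_gather(gathered, tensor)
    assert gathered[rank].item() == float(rank)
    dist.destroy_process_group()

if __name__ == "__main__":
    with Pool(processes=NUM_PROCESSES) as pool:
        pool.starmap(_worker, [(rank, NUM_PROCESSES) for rank in range(NUM_PROCESSES)])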

@cw-tan sorry for the headache this must have been to debug. I have significantly simplified the tests you had so I could understand what was going on. Hope this is still fine with you.

@cw-tan (Contributor, Author) commented Oct 31, 2024

@SkafteNicki Fantastic, thank you so much! I'm just excited to see this feature released so I can remove the monkeypatch in my own code that achieves the same effect. I think the docs are the only remaining change -- they still mention this being compatible only with PyTorch 2, but the tests have passed for the other versions as well.

@mergify bot added the ready label Oct 31, 2024
@Borda changed the title from "Enable autograd graph to propagate after multi-device syncing (for loss functions in ddp)" to "Enable autograd graph to propagate after multi-device syncing for loss functions in ddp" Oct 31, 2024
@Borda merged commit d3894e1 into Lightning-AI:master Oct 31, 2024
49 of 52 checks passed
Labels: documentation, enhancement, ready
Projects: none
Linked issues (may be closed by merging this PR): Autograd with DDP
3 participants