Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Generalize pytorch content for non-native device execution #66

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

ankurneog
Copy link

@ankurneog ankurneog commented Aug 12, 2024

Motivation

This PR is for review of the following RFC : Modify PyTorch framework UTs so that non-cuda devices such as intel Gaudi and intel XPU is able to harness the content and improve quality.
https://github.com/pytorch/rfcs/pull/66/files

@facebook-github-bot
Copy link
Contributor

Hi @ankurneog!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!


This will also ensure greater participation for content enhancement.

## **Proposed Implementation**
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the idea makes a lot of sense.
I think we would need more details and how this interracts with existing features like the device-generic tests (https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/common_device_type.py) that already work for privateuse1 partially btw and the opinfo consistency tests.

Also I don't think we want to aim at running the full test suite with the side device available but select specific device-dependent tests that need to be ran for each device we support.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@albanD : thanks for your comment , yes I believe we need some extensive hooks to enable this , for eg. the one we introduced with
https://github.com/pytorch/pytorch/pull/128584/files#diff-d183f2afc51d6a59bc70094e8f476d2468c45e415500f6eb60abad955e065156R531
@onlyNativeDeviceTypesAnd(["hpu"])
The other devices can add to to such list, if it supports the TC.

we can modify other hooks like skipIfDevice in similar fashion.

The common_device_type is useable if we replace the onlyNativeDeviceType decorator.
It was widely used in the initial files but all recent files are not using it (eg: dynamo/distributed) and instead directly make .cuda() calls.

However these content shouldn't be too difficult to migrate.

I believe in general we should ensure new test content uses the common_device_type framework, and open up the content for "non-native" device execution.

* Dtypes for a device should be dynamically loaded per op based on a common dictionary, instead of using different variables per device , eg: dtypesIfCuda
* Miscelleneous decorators such as @skipIfCuda should be generalized @skipIfDevice
* Extend use of instantiate_device_type for all content, so that developers are forced to use generalized device code rather than using "cuda" or "cpu"
* Generalize common distributed content , so that it can be extended for non nccl backends such as intel's hccl and ccl
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @wconstab if there is anyone on the distributed side that can help with designing this?

* Miscelleneous decorators such as @skipIfCuda should be generalized @skipIfDevice
* Extend use of instantiate_device_type for all content, so that developers are forced to use generalized device code rather than using "cuda" or "cpu"
* Generalize common distributed content , so that it can be extended for non nccl backends such as intel's hccl and ccl
* Generalize the dynamo content for specific backends which other devices might want to verify with existing content.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what this means? But I guess Dynamo tests should happen the same way as other tests (split between device-generic and others and we only re-run the device-generic tests)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@albanD : yes you are right, what i meant is - for the devices there should be flexibility to execute the TC based backends the device supports ( eg: aot_eager, cudagraphs,aot_ts) and also ability to add custom backends which are not part of the tree for eg: intel's hpu_backend

* Extend use of instantiate_device_type for all content, so that developers are forced to use generalized device code rather than using "cuda" or "cpu"
* Generalize common distributed content , so that it can be extended for non nccl backends such as intel's hccl and ccl
* Generalize the dynamo content for specific backends which other devices might want to verify with existing content, the backends should always be extracted from
a list that is abstracted out and the list can be appended per device per TC.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good list of items to generalize the test cases. Does the proposal just focus on the devices having dedicated device tags installed in PyTorch core or also support the PrivateUse1 device which is used to extend PyTorch with any out of the tree devices?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jgong5 : Thanks for your comment, Since i investigated mostly in lines to support intel Gaudi, which has dedicated device tag, i have not checked the impact or support needed for PrivateUse1 devices

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps @FFFrog @Yikun have more thoughts on this?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jgong5 @ankurneog Sorry for the late reply.

In theory, in the test framework of PyTorch, dedicated keys are almost the same as public keys (PrivateUse1), and PrivateUse1 is already supported in the test framework.

First of all, I can`t agree more with this proposal, because Ascend NPU is currently facing the above-described problems; The solution proposed by @ankurneog can solve most of the problems we have encountered; We are currently sorting out all the problems encountered, and will add them to this RFC later, and hope that the new stuff we will add will help the proposal be more complete.

By the way, if possible, we can work together to complete this proposal and make it land in PyTorch :D

@ankurneog ankurneog marked this pull request as ready for review August 20, 2024 02:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants