Add quantized_maxpool_2d for xpu #1049
base: main
Conversation
The specific unit test cases have passed in the latest CI:
2024-11-07T09:10:50.0722836Z quantization/core/test_quantized_op_xpu.py::TestQuantizedOpsXPU::test_max_pool2d_nhwc_xpu PASSED [ 48%]
2024-11-07T09:10:50.1551494Z quantization/core/test_quantized_op_xpu.py::TestQuantizedOpsXPU::test_max_pool2d_pt2e_xpu PASSED [ 50%]
2024-11-07T09:10:50.5847030Z quantization/core/test_quantized_op_xpu.py::TestQuantizedOpsXPU::test_max_pool2d_xpu PASSED [ 51%]
w_start += dW_;

// Stock PyTorch's CPU implementation uses vectorized instructions
// across channels, such as AVX-512. We use a plain for-loop directly.
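For illustration, the contrast this comment describes (vectorization across channels on CPU vs. a plain loop here) could look roughly like the following. This is a minimal, hypothetical sketch: `scalar_t`, the NHWC layout, and the window-bound variables are placeholders, not the exact code in this PR.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical per-output-pixel inner loop: iterate channels with a plain
// for-loop instead of AVX-512-style vectorization across channels.
// scalar_t, the NHWC layout, and the window bounds are placeholders.
template <typename scalar_t>
void max_pool2d_window_plain(
    const scalar_t* input,     // base of input[n], NHWC layout
    scalar_t* output,          // base of output[n][oh][ow], length nChannels
    int64_t nChannels,
    int64_t iW,
    int64_t h_start, int64_t h_end,
    int64_t w_start, int64_t w_end) {
  for (int64_t c = 0; c < nChannels; ++c) {
    scalar_t maxval = std::numeric_limits<scalar_t>::lowest();
    for (int64_t ih = h_start; ih < h_end; ++ih) {
      for (int64_t iw = w_start; iw < w_end; ++iw) {
        const scalar_t val = input[(ih * iW + iw) * nChannels + c];
        maxval = val > maxval ? val : maxval;
      }
    }
    output[c] = maxval;
  }
}
```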
Just for confirmation: the work-item indexing is optimized for channels-last (CL), right? I mean, the tensor is channels-last, and the inner-most dim of the work-item is on channels.
It's not optimized for channels-last. The inner-most dim is nbatch, as in the DilatedMaxPool2d implementation. On our GPU we do not have a vectorized method, so I use an unrolled loop to simulate it. If the compiler optimizes the loop into vectorized code, it will help. I just followed the stock implementation.
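To make the described layout concrete, here is a minimal sketch of decomposing a flattened work-item id with the batch index as the inner-most (fastest-varying) component, as the comment above suggests for the DilatedMaxPool2d-style kernels. All names and the exact ordering are illustrative assumptions, not the kernel code in this PR.

```cpp
#include <cstdint>

// Hypothetical decomposition of a flattened work-item id where the batch
// index varies fastest (inner-most), then channel, then output width/height.
// Names and ordering are illustrative only.
inline void decode_work_item(
    int64_t linear_id,
    int64_t nbatch, int64_t nChannels, int64_t oW,
    int64_t& n, int64_t& c, int64_t& ow, int64_t& oh) {
  n = linear_id % nbatch;             // inner-most: batch
  int64_t rest = linear_id / nbatch;
  c = rest % nChannels;               // then channel
  rest /= nChannels;
  ow = rest % oW;                     // then output width
  oh = rest / oW;                     // outer-most: output height
}
```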
Great, thanks for the explanation. I am OK with keeping this implementation; functionality is the priority for now.
Hi @EikanWang, FYI: if perf is also important for us at the moment, could we add further optimization after this PR?
Please use TORCH_CHECK for the dtype check before launching the kernels.
Currently we only support the uint8 (Byte) datatype, following the stock PyTorch CPU implementation.
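A minimal sketch of such a guard before the kernel launch, assuming the uint8 (Byte) requirement stated above; the function name and signature are placeholders, not the actual ones in this PR.

```cpp
#include <ATen/ATen.h>

// Hypothetical dtype guard, called before launching the XPU kernel:
// reject anything other than uint8 (Byte) input, per the note above.
static void check_qmaxpool2d_dtype_xpu(const at::Tensor& input) {
  TORCH_CHECK(
      input.scalar_type() == at::ScalarType::Byte,
      "quantized_max_pool2d (XPU): only uint8 (Byte) input is supported, got ",
      input.scalar_type());
}
```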
Waiting for #921 to be merged.