Refactor softmax code for dim=-1 case (#886)
Summary:
Pull Request resolved: #886

This sets the stage for extending the warp reduction code to the dim=-2 case.

Main changes in this diff:

1. blockReduceMax now uses `fast_max` instead of `max`. This is what the dim=-2 reduction code already uses. From looking at the implementation, it seems that `fast_max` is fast because of type specialization and not because it sacrifices accuracy, so I think this is a safe change.
2. The block reduction code had a `NUM` parameter that was always set to 1. I've eliminated it to simplify things and remove a bunch of indirection.
3. The shared memory used by the block / warp reduction code was laid out in rows of 33 elements in order to avoid bank conflicts. However, given that `NUM` was always 1, I don't think bank conflicts are possible in practice. I've therefore changed the shared memory to use rows of 32 elements instead.

Reviewed By: muchulee8, aakhundov

Differential Revision: D47862053

fbshipit-source-id: 477f16ae3f33e2e5f2a858bf156268aef5a25ee1
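The reasoning in item 3 can be checked with a small host-side model. This is only a sketch, not the kernel code: it assumes the usual layout of 32 shared-memory banks holding one 4-byte word each, and it assumes that with `NUM` fixed at 1 a warp only ever reads across a single row. The helper names (`bank`, `max_conflict_degree`) are made up for illustration.

```python
# Simplified model of shared-memory banking (assumption: 32 banks,
# one 4-byte word per bank, row-major array of 4-byte elements).
NUM_BANKS = 32
WARP_SIZE = 32


def bank(row, col, row_stride):
    """Bank index of element [row][col] for a given row length."""
    return (row * row_stride + col) % NUM_BANKS


def max_conflict_degree(accesses, row_stride):
    """Largest number of lanes in one warp that land in the same bank."""
    counts = [0] * NUM_BANKS
    for row, col in accesses:
        counts[bank(row, col, row_stride)] += 1
    return max(counts)


# Row access: lane i touches [0][i]. With a single row this is the
# only pattern available.
row_access = [(0, lane) for lane in range(WARP_SIZE)]

# Column access: lane i touches [i][0]. This needs several rows.
col_access = [(lane, 0) for lane in range(WARP_SIZE)]

for stride in (32, 33):
    print(
        f"stride {stride}: row access -> {max_conflict_degree(row_access, stride)}-way, "
        f"column access -> {max_conflict_degree(col_access, stride)}-way"
    )
```

In this model the 33-element padding only matters for the column pattern, where a stride of 32 sends every lane to the same bank. Row access is conflict-free with either stride, which is consistent with dropping the padding once only one row per reduction is in play.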
1 parent 318111f, commit 029ba1c
Showing 1 changed file with 59 additions and 76 deletions.