Skip to content

Commit

Permalink
common: sdpa: enable ONEDNN_ENABLE_PRIMITIVE selection
Browse files Browse the repository at this point in the history
  • Loading branch information
petercad committed Jun 4, 2024
1 parent dc301db commit 25596d2
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 10 deletions.
2 changes: 1 addition & 1 deletion cmake/configuring_primitive_list.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ else()
foreach(impl ${DNNL_ENABLE_PRIMITIVE})
string(TOUPPER ${impl} uimpl)
if(NOT "${uimpl}" MATCHES
"^(BATCH_NORMALIZATION|BINARY|CONCAT|CONVOLUTION|DECONVOLUTION|ELTWISE|INNER_PRODUCT|LAYER_NORMALIZATION|LRN|MATMUL|POOLING|PRELU|REDUCTION|REORDER|RESAMPLING|RNN|SHUFFLE|SOFTMAX|SUM)$")
"^(BATCH_NORMALIZATION|BINARY|CONCAT|CONVOLUTION|DECONVOLUTION|ELTWISE|INNER_PRODUCT|LAYER_NORMALIZATION|LRN|MATMUL|POOLING|PRELU|REDUCTION|REORDER|RESAMPLING|RNN|SDPA|SHUFFLE|SOFTMAX|SUM)$")
message(FATAL_ERROR "Unsupported primitive: ${uimpl}")
endif()
set(BUILD_${uimpl} TRUE)
Expand Down
4 changes: 2 additions & 2 deletions cmake/options.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ set(DNNL_ENABLE_PRIMITIVE "ALL" CACHE STRING
- <PRIMITIVE_NAME>. Includes only the selected primitive to be enabled.
Possible values are: BATCH_NORMALIZATION, BINARY, CONCAT, CONVOLUTION,
DECONVOLUTION, ELTWISE, INNER_PRODUCT, LAYER_NORMALIZATION, LRN, MATMUL,
POOLING, PRELU, REDUCTION, REORDER, RESAMPLING, RNN, SHUFFLE, SOFTMAX,
SUM.
POOLING, PRELU, REDUCTION, REORDER, RESAMPLING, RNN, SDPA, SHUFFLE,
SOFTMAX, SUM.
- <PRIMITIVE_NAME>;<PRIMITIVE_NAME>;... Includes only selected primitives to
be enabled at build time. This is treated as CMake string, thus, semicolon
is a mandatory delimiter between names. This is the way to specify several
Expand Down
10 changes: 5 additions & 5 deletions doc/build/build_options.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ This option supports several values: `ALL` (the default) which enables all
primitives implementations or a set of `BATCH_NORMALIZATION`, `BINARY`,
`CONCAT`, `CONVOLUTION`, `DECONVOLUTION`, `ELTWISE`, `INNER_PRODUCT`,
`LAYER_NORMALIZATION`, `LRN`, `MATMUL`, `POOLING`, `PRELU`, `REDUCTION`,
`REORDER`, `RESAMPLING`, `RNN`, `SHUFFLE`, `SOFTMAX`, `SUM`. When a set is used,
only those selected primitives implementations will be available. Attempting to
use other primitive implementations will end up returning an unimplemented
status when creating primitive descriptor. In order to specify a set, a
CMake-style string should be used, with semicolon delimiters, as in this
`REORDER`, `RESAMPLING`, `RNN`, `SDPA`, `SHUFFLE`, `SOFTMAX`, `SUM`. When a set
is used, only those selected primitives implementations will be available.
Attempting to use other primitive implementations will end up returning an
unimplemented status when creating primitive descriptor. In order to specify a
set, a CMake-style string should be used, with semicolon delimiters, as in this
example:
```
-DONEDNN_ENABLE_PRIMITIVE=CONVOLUTION;MATMUL;REORDER
Expand Down
1 change: 1 addition & 0 deletions include/oneapi/dnnl/dnnl_config.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@
#cmakedefine01 BUILD_REORDER
#cmakedefine01 BUILD_RESAMPLING
#cmakedefine01 BUILD_RNN
#cmakedefine01 BUILD_SDPA
#cmakedefine01 BUILD_SHUFFLE
#cmakedefine01 BUILD_SOFTMAX
#cmakedefine01 BUILD_SUM
Expand Down
7 changes: 7 additions & 0 deletions src/common/impl_registration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,13 @@
{}
#endif

#if BUILD_PRIMITIVE_ALL || BUILD_SDPA
#define REG_SDPA_P(...) __VA_ARGS__
#else
#define REG_SDPA_P(...) \
{}
#endif

#if BUILD_PRIMITIVE_ALL || BUILD_SHUFFLE
#define REG_SHUFFLE_P(...) __VA_ARGS__
#else
Expand Down
4 changes: 2 additions & 2 deletions src/gpu/gpu_sdpa_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ namespace gpu {
namespace {

// clang-format off
constexpr impl_list_item_t impl_list[] = {
constexpr impl_list_item_t impl_list[] = REG_SDPA_P({
GPU_INSTANCE_INTEL(intel::ocl::micro_sdpa_t)
GPU_INSTANCE_INTEL_DEVMODE(intel::ocl::ref_sdpa_t)
nullptr,
};
});
// clang-format on
} // namespace

Expand Down

0 comments on commit 25596d2

Please sign in to comment.