Support dim != 1 for softmax w/o using permute (#845)
Summary:
Pull Request resolved: #845

This is a port of PyTorch's softmax implementation.

Notable differences:
* We use fast_max & fast_exp instead of std::max and std::exp
* We don't use higher-precision accumulator types (the existing dim=-1 softmax code doesn't appear to do this either)
* We propagate the reduction dim size & inner size as compile-time constants (see the sketch before the diff below)

We appear to be marginally slower than PyTorch for small batch sizes and marginally faster for large ones.

I have named this new softmax implementation "softmaxGeneral" since it can handle arbitrary reduction dimensions, even though we only use it for the `dim > 1` case.

Reviewed By: aakhundov

Differential Revision: D47732875

fbshipit-source-id: f57fb81921a8f6da48abe5b61542e09f3142e4ed
int3 authored and facebook-github-bot committed Aug 8, 2023
1 parent b65dab4 commit 318111f
Showing 4 changed files with 254 additions and 51 deletions.
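For readers unfamiliar with the layout, the kernel and codegen below split the input shape around the reduction dim into an outer size (a runtime kernel argument), the reduction dim size (template constant DimSize), and an inner size (template constant InnerSize). A rough Python sketch of that decomposition, not part of this commit:

import math

def split_softmax_shape(shape, reduction_dim):
    # outer_size stays a runtime argument; dim_size and inner_size are baked
    # into the generated code as the DimSize / InnerSize template parameters.
    outer_size = math.prod(shape[:reduction_dim])
    dim_size = shape[reduction_dim]
    inner_size = math.prod(shape[reduction_dim + 1:])
    return outer_size, dim_size, inner_size

# e.g. softmax over dim=1 of an [8, 128, 16, 16] tensor:
print(split_softmax_shape([8, 128, 16, 16], 1))  # (8, 128, 256)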
119 changes: 119 additions & 0 deletions python/aitemplate/backend/cuda/softmax/softmax.cuh
@@ -94,6 +94,13 @@ __inline__ __device__ bfloat16 fast_max(const bfloat16 a, const bfloat16 b) {

#endif

template <typename T>
struct FastMax {
__device__ __forceinline__ T operator()(T a, T b) const {
return fast_max(a, b);
}
};

template <typename T>
__inline__ __device__ T Inf();

@@ -854,4 +861,116 @@ void LaunchSoftmaxK1Middle(
SOFTMAX_LAUNCH_CHECK();
}

// Note that it's not a complete block-wide reduction.
// Only threads that share threadIdx.y reduce values.
template <typename T, template <typename> class ReduceOp>
__forceinline__ __device__ T softmax_general_block_reduce_x(T* shared, T val) {
ReduceOp<T> r;
shared += threadIdx.y * blockDim.x;

__syncthreads();

shared[threadIdx.x] = val;

// NOTE: loop starts with __syncthreads()
int offset = blockDim.x / 2;
while (offset > 0) {
__syncthreads();
if (threadIdx.x < offset)
shared[threadIdx.x] =
r(shared[threadIdx.x], shared[threadIdx.x + offset]);
offset /= 2;
}

__syncthreads();

return shared[0];
}
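A host-side Python model of the reduction above (illustrative only, not from the diff): each __syncthreads()-separated step halves `offset`, and only the lanes with threadIdx.x below `offset` combine their value with a partner `offset` slots away. It assumes a power-of-two lane count, which holds because DimThreads is built by doubling.

def block_reduce_x_model(values, op):
    # `values` stands in for the shared-memory slice written by the
    # blockDim.x threads that share one threadIdx.y.
    shared = list(values)
    offset = len(shared) // 2
    while offset > 0:
        for tid in range(offset):          # threads with threadIdx.x < offset
            shared[tid] = op(shared[tid], shared[tid + offset])
        offset //= 2
    return shared[0]

print(block_reduce_x_model([3.0, 7.0, 1.0, 5.0], max))                 # 7.0
print(block_reduce_x_model([1.0, 2.0, 3.0, 4.0], lambda a, b: a + b))  # 10.0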

template <typename T, size_t DimSize, size_t InnerSize, size_t DimThreads>
__global__ void softmax_general(const T* input, T* output, size_t outer_size) {
extern __shared__ unsigned char smem[];
auto sdata = reinterpret_cast<T*>(smem);
const uint32_t outer_stride = InnerSize * DimSize;
const uint32_t dim_stride = InnerSize;

for (uint32_t outer_index = blockIdx.x; outer_index < outer_size;
outer_index += gridDim.x) {
const uint32_t outer_offset = outer_index * outer_stride;
for (uint32_t inner_index = blockIdx.y * blockDim.y + threadIdx.y;
inner_index < InnerSize;
inner_index += blockDim.y * gridDim.y) {
const uint32_t data_offset = outer_offset + inner_index;
T max_input = std::numeric_limits<T>::lowest();
// DimThreads == blockDim.x, but using DimThreads here is actually a perf
// regression
for (uint32_t d = threadIdx.x; d < DimSize; d += blockDim.x) {
const T value = input[data_offset + d * dim_stride];
max_input = fast_max(max_input, value);
}
if constexpr (DimThreads > 1)
max_input =
softmax_general_block_reduce_x<T, FastMax>(sdata, max_input);

T sum = 0;
for (uint32_t d = threadIdx.x; d < DimSize; d += blockDim.x)
sum += fast_exp(input[data_offset + d * dim_stride] - max_input);
if constexpr (DimThreads > 1)
sum = softmax_general_block_reduce_x<T, std::plus>(sdata, sum);

for (uint32_t d = threadIdx.x; d < DimSize; d += blockDim.x)
output[data_offset + d * dim_stride] =
fast_exp(input[data_offset + d * dim_stride] - max_input) / sum;
}
}
}
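For reference, a NumPy sketch (not part of the diff) of what the kernel computes: an ordinary softmax along the middle axis of the (outer_size, DimSize, InnerSize) view of a contiguous input, using the same three passes — max, exp-sum, normalize:

import numpy as np

def softmax_general_reference(x, outer_size, dim_size, inner_size):
    v = x.reshape(outer_size, dim_size, inner_size)
    m = v.max(axis=1, keepdims=True)       # pass 1: max_input
    e = np.exp(v - m)                      # fast_exp(input - max_input)
    s = e.sum(axis=1, keepdims=True)       # pass 2: sum
    return (e / s).reshape(x.shape)        # pass 3: normalize

x = np.random.rand(8, 128, 16, 16).astype(np.float32)
out = softmax_general_reference(x, 8, 128, 16 * 16)   # softmax over dim=1
assert np.allclose(out.sum(axis=1), 1.0, atol=1e-5)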

template <size_t InnerThreads, size_t InnerSize>
inline dim3 softmax_general_get_grid_size(
size_t max_active_blocks,
size_t outer_size) {
// First, tile as many blocks as we can over the y axis (block.y ==
// InnerThreads)
size_t inner_blocks = (InnerSize + InnerThreads - 1) / InnerThreads;
if (inner_blocks > max_active_blocks)
inner_blocks = max_active_blocks;
// Fill the x axis with as many blocks as we can fit (a little more is ok too)
size_t outer_blocks = (max_active_blocks + inner_blocks - 1) / inner_blocks;
if (outer_blocks > outer_size)
outer_blocks = outer_size;
return dim3(outer_blocks, inner_blocks);
}
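The same heuristic in plain Python with a worked example (illustrative, not from the diff): blocks are tiled over the inner elements first, then the remaining occupancy budget is spread over the outer dimension — the ceiling division may slightly exceed max_active_blocks before the clamp to outer_size.

def grid_size_model(inner_threads, inner_size, max_active_blocks, outer_size):
    inner_blocks = min((inner_size + inner_threads - 1) // inner_threads,
                       max_active_blocks)
    outer_blocks = min((max_active_blocks + inner_blocks - 1) // inner_blocks,
                       outer_size)
    return outer_blocks, inner_blocks      # dim3(outer_blocks, inner_blocks)

# e.g. InnerThreads=32, InnerSize=256, 132 resident blocks, outer_size=64:
print(grid_size_model(32, 256, 132, 64))   # (17, 8)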

// This implementation of softmax can handle arbitrary reduction dimensions, but
// is less efficient than the specialized kernels above that reduce only over
// the last dimension.
template <
typename T,
size_t DimSize,
size_t InnerSize,
size_t DimThreads,
size_t InnerThreads>
void LaunchSoftmaxGeneral(
const T* input,
T* output,
size_t outer_size,
int multiprocessorCount,
cudaStream_t stream) {
int block_size = DimThreads * InnerThreads;
size_t smem_size = DimThreads == 1 ? 0 : block_size * sizeof(T);
int max_active_blocks;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
softmax_general<T, DimSize, InnerSize, DimThreads>,
block_size,
smem_size);
max_active_blocks *= multiprocessorCount;
dim3 grid = softmax_general_get_grid_size<InnerThreads, InnerSize>(
max_active_blocks, outer_size);
dim3 block(DimThreads, InnerThreads);
softmax_general<T, DimSize, InnerSize, DimThreads>
<<<grid, block, smem_size, stream>>>(input, output, outer_size);
SOFTMAX_LAUNCH_CHECK();
}

#endif
86 changes: 69 additions & 17 deletions python/aitemplate/backend/cuda/softmax/softmax.py
@@ -15,7 +15,9 @@
"""
Softmax codegen for CUDA.
"""
from __future__ import annotations

import math
import os
from typing import Any, Dict

@@ -30,7 +32,7 @@
# pylint: disable=C0301, C0116


# input size: [M, K]
# input size: [M, K] where M == outer size (product of outer dims) and K == reduction dim
# We put if else condition here to avoid long compilation time.
# i.e. for each K, we only need to compile one of the implementation, not all.
#
@@ -44,7 +46,14 @@
{{func_signature}}
{
{{shape_functions}}
size_t m = M;
{{func_impl}}
}
"""
)

FUNC_IMPL_INNER_SIZE_EQ_1 = jinja2.Template(
"""
const size_t M = outer_size;
bool success = true;
// For threshold K, please refer to this post: https://fb.quip.com/HCfIAbpWB0qi
@@ -100,22 +109,33 @@
LaunchSoftmaxK1Middle<{{dtype}}, {{K}}>(static_cast<const {{dtype}}*>(input), static_cast<{{dtype}}*>(output), M, stream);
{% elif K > 1408 %}
// K > 1408
LaunchSoftmaxBlockAll<{{dtype}}, {{dtype}}, {{K}}>( (const {{dtype}}*) input, ({{dtype}}*) output, m, stream, &success);
LaunchSoftmaxBlockAll<{{dtype}}, {{dtype}}, {{K}}>( (const {{dtype}}*) input, ({{dtype}}*) output, M, stream, &success);
{% endif %}
{% endif %}
if (!success) {
softmaxBlockNocache<{{dtype}}><<<m, 1024, 0, stream>>>(({{dtype}}*)input, ({{dtype}}*)output, m, {{K}});
softmaxBlockNocache<{{dtype}}><<<M, 1024, 0, stream>>>(({{dtype}}*)input, ({{dtype}}*)output, M, {{K}});
}
}
"""
)

FUNC_IMPL_GENERAL = jinja2.Template(
"""
LaunchSoftmaxGeneral<{{dtype}}, {{dim_size}}, {{inner_size}}, {{dim_threads}}, {{inner_threads}}>(
(const {{dtype}}*) input,
({{dtype}}*) output,
outer_size,
multiprocessor_count,
stream
);
"""
)

SHAPE_FUNCTIONS = jinja2.Template(
"""
int64_t M = 1;
int64_t outer_size = 1;
{% for idx in range(reduction_dim) %}
M *= *in_{{idx}};
outer_size *= *in_{{idx}};
{% endfor %}
"""
)
@@ -127,6 +147,7 @@
{% for idx in range(reduction_dim) %}
int64_t* in_{{idx}},
{% endfor %}
int multiprocessor_count,
cudaStream_t stream)
""",
trim_blocks=True,
@@ -146,6 +167,7 @@
{% for name in outer_dim_names %}
{{indent}} &{{name}},
{% endfor %}
{{indent}} device_properties_.multiProcessorCount,
{{indent}} stream
{{indent}});
""",
@@ -175,30 +197,60 @@ def find_tile_size(k: int) -> int:
return m


def _softmax_general_block_size(dim_size: int, inner_size: int) -> tuple[int, int]:
MAX_THREADS_PER_BLOCK = 1024
inner_threads = min(inner_size, MAX_THREADS_PER_BLOCK)
dim_threads = 1
if inner_threads <= 64 and dim_size >= 64:
while (
inner_threads * dim_threads <= MAX_THREADS_PER_BLOCK
and dim_threads <= dim_size
):
dim_threads *= 2
dim_threads //= 2
return dim_threads, inner_threads
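A few worked examples for the heuristic above (illustrative, assuming the function just defined is in scope): the reduction dim is split across threadIdx.x only when the inner extent needs at most 64 threads and the reduction dim has at least 64 elements; otherwise every thread goes to inner elements.

assert _softmax_general_block_size(dim_size=512, inner_size=32) == (32, 32)
assert _softmax_general_block_size(dim_size=512, inner_size=1000) == (1, 1000)
assert _softmax_general_block_size(dim_size=8, inner_size=4) == (1, 4)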


@registry.reg("cuda.softmax.gen_function")
def softmax_gen_function(func_attrs: Dict[str, Any]) -> str:
dim = func_attrs["dim"]
shapes = func_attrs["inputs"][0]._attrs["shape"]
reduction_dim = func_attrs["dim"]
shape = func_attrs["inputs"][0]._attrs["shape"]

assert isinstance(
shapes[dim], IntImm
), "softmax requires reduction dim to be static"
assert all(
isinstance(dim, IntImm) for dim in shape[reduction_dim:]
), "softmax requires reduction dim & inner dims to be static"

k = shapes[dim].value()
dim_size = shape[reduction_dim].value()

backend_spec = CUDASpec()
elem_input_type = backend_spec.dtype_to_backend_type(
func_attrs["inputs"][0]._attrs["dtype"]
)

inner_size = math.prod(dim.value() for dim in shape[reduction_dim + 1 :])
if inner_size == 1:
func_impl = FUNC_IMPL_INNER_SIZE_EQ_1.render(
dtype=elem_input_type,
m=find_tile_size(dim_size),
K=dim_size,
)
else:
dim_threads, inner_threads = _softmax_general_block_size(dim_size, inner_size)
func_impl = FUNC_IMPL_GENERAL.render(
dtype=elem_input_type,
dim_size=dim_size,
inner_size=inner_size,
dim_threads=dim_threads,
inner_threads=inner_threads,
)

return FUNC_TEMPLATE.render(
custom_libs=Target.current().get_custom_libs(
os.path.dirname(__file__), "softmax.cuh"
),
func_signature=get_func_signature(func_attrs),
shape_functions=SHAPE_FUNCTIONS.render(reduction_dim=dim),
dtype=elem_input_type,
K=k,
m=find_tile_size(k),
func_impl=func_impl,
shape_functions=SHAPE_FUNCTIONS.render(reduction_dim=reduction_dim),
)


17 changes: 6 additions & 11 deletions python/aitemplate/compiler/ops/softmax/softmax.py
@@ -37,7 +37,6 @@
Tensor,
)
from aitemplate.compiler.ops.softmax.cache_entry import NormQueryEntry, NormRecordEntry
from aitemplate.compiler.ops.tensor.permute import permute

from aitemplate.testing import detect_target

Expand Down Expand Up @@ -205,16 +204,12 @@ def __call__(self, x: Tensor, dim: int = None) -> Tensor:
"flattening input tensor before normalization is not supported yet"
)
dim = wrap_dim(dim, x._rank())
tail_shapes = x.shape()[dim + 1 :]
# The backend only supports reduction over the last non-1 dimension, so if we want
# to reduce over other dimensions we have to permute the tensor first.
if not all(isinstance(s, IntImm) and s.value() == 1 for s in tail_shapes):
perm_shape = list(range(x._rank()))
perm_shape[dim] = x._rank() - 1
perm_shape[-1] = dim
x_perm = permute()(x, perm_shape)
x_perm_softmax = softmax()(x_perm, dim=-1)
return permute()(x_perm_softmax, perm_shape)

inner_dims = x.shape()[dim + 1 :]
if not all(isinstance(d, IntImm) for d in inner_dims):
raise NotImplementedError(
"inner dims must all be static; {dim=}, {x.shape()=}"
)

self._attrs["inputs"] = [x]
self._attrs["dim"] = dim