Refactor stochastic rounding state in TBE (pytorch#2312)
Summary:
Pull Request resolved: pytorch#2312

This diff fixes an issue with `StochasticRoundingRNGState`: the state was
previously allocated as a local variable inside a member function
(`set_stochastic_rounding`), but its address was stored and then dereferenced
after the function had returned, resulting in illegal memory access. The
allocation is now hoisted out to the calling kernel and passed into the
`WeightRow` constructor, so the state object stays alive for every access.
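
For illustration only, a minimal standalone sketch of the lifetime bug and the fix (hypothetical RNGState/RowBefore/RowAfter types, not the FBGEMM API): the old pattern keeps the address of a function-local object, which dangles as soon as the function returns; the new pattern has the caller own the state and hand its address to the constructor.

    #include <cstdint>

    // Stand-in for StochasticRoundingRNGState (hypothetical, illustration only).
    struct RNGState {
      uint64_t counter = 0;
    };

    struct RowBefore {
      RNGState* state_ = nullptr;

      // Old pattern: stores the address of a local variable.
      void set_stochastic_rounding(uint64_t seed) {
        RNGState state;   // destroyed when this function returns
        state.counter = seed;
        state_ = &state;  // dangling pointer afterwards -> illegal access
      }
    };

    struct RowAfter {
      RNGState* state_ = nullptr;

      // New pattern: the caller owns the state; the constructor only wires it up.
      RowAfter(RNGState* state, uint64_t seed) : state_(state) {
        if (state_ != nullptr) {
          state_->counter = seed;
        }
      }
    };

    void caller() {
      RNGState state;  // lives for the whole enclosing scope
      RowAfter row(&state, /*seed=*/42);
      // row.state_ stays valid for every later access in this scope
    }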

Reviewed By: jspark1105

Differential Revision: D53462989

fbshipit-source-id: 9b962bcdc901f6ff62388c2a02ec6ea3068844fe
sryap authored and facebook-github-bot committed Feb 9, 2024
1 parent 86ea895 commit 8664c84
Showing 9 changed files with 62 additions and 71 deletions.
@@ -154,8 +154,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
    auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(
        const_cast<emb_t*>(&weights[idx_j * D_emb]),
        nullptr,
-       D,
-       nullptr);
+       D);
    float2 qparams;
    if (std::is_same<emb_t, uint8_t>::value) {
      qparams = weight_row.load_qparams();
@@ -174,8 +173,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
    auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(
        const_cast<emb_t*>(&weights[idx_j * D_emb]),
        nullptr,
-       D,
-       nullptr);
+       D);
    float2 qparams;
    if (std::is_same<emb_t, uint8_t>::value) {
      qparams = weight_row.load_qparams();
6 changes: 3 additions & 3 deletions fbgemm_gpu/codegen/embedding_common_code_generator.py
@@ -1090,7 +1090,7 @@ def lamb() -> Dict[str, Any]:
    split_precomputation = """
    at::acc_type<cache_t, true> weight_sum_sq = 0.0;
    at::acc_type<cache_t, true> rtw_sum_sq = 0.0;
-   auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D, nullptr);
+   auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D);
    float2 qparams;
    if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
      qparams = weight_row.load_qparams();
@@ -1187,7 +1187,7 @@ def partial_rowwise_lamb() -> Dict[str, Any]:
    at::acc_type<cache_t, true> weight_sum_sq = 0.0;
    at::acc_type<cache_t, true> rtw_sum_sq = 0.0;
-   auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D, nullptr);
+   auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D);
    float2 qparams;
    if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
      qparams = weight_row.load_qparams();
@@ -1375,7 +1375,7 @@ def lars_sgd() -> Dict[str, Any]:
    at::acc_type<cache_t, true> weight_sum_sq = 0.0;
    at::acc_type<cache_t, true> grad_sum_sq = 0.0;
-   auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D, nullptr);
+   auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D);
    float2 qparams;
    if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
      qparams = weight_row.load_qparams();
@@ -153,8 +153,7 @@ batch_index_select_dim0_codegen_forward_small_kernel(
    auto weight_row_emb = WeightRow<emb_t, cache_t, cache_t>(
        const_cast<emb_t*>(&weights[idx_j * D_emb]),
        nullptr,
-       D,
-       nullptr);
+       D);
    [[maybe_unused]] float2 qparams_emb;
    if (std::is_same<emb_t, uint8_t>::value) {
      qparams_emb = weight_row_emb.load_qparams();
@@ -166,8 +165,7 @@
      auto weight_row_cache = WeightRow<emb_t, cache_t, cache_t>(
          const_cast<emb_t*>(&weights[idx_j * D_emb]),
          const_cast<cache_t*>(&lxu_cache_weights[cache_idx_j][0]),
-         D,
-         nullptr);
+         D);
      Vec4T<cache_t> weight = weight_row_cache.load(d, qparams_cache);
      weight.store(&output[output_j][d]);
    } else {
@@ -73,15 +73,16 @@ DEVICE_INLINE void split_{{ optimizer }}_table_update_kernel(
  struct SharedMemory<Vec4T<at::acc_type<cache_t, true>>> weight_update_buffer;
  Vec4T<at::acc_type<cache_t, true>>* shared_weight_update_row =
      is_int8 ? weight_update_buffer.getPointer() : nullptr;

+ StochasticRoundingRNGState state;
  auto weight_row_template =
      WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(
-         weights, cache_weights, D, nullptr);
-
- weight_row_template.set_stochastic_rounding(
-     stochastic_rounding,
-     stochastic_rounding_philox_args,
-     threadIdx.x + run_id * blockDim.x
- );
+         weights,
+         cache_weights,
+         D,
+         stochastic_rounding ? &state : nullptr,
+         &stochastic_rounding_philox_args,
+         threadIdx.x + run_id * blockDim.x);
+
  float2 qparams_template;
  if (is_int8 && !cache_weights) {
60 changes: 31 additions & 29 deletions fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
@@ -1334,15 +1334,41 @@ template <typename emb_t, typename cache_t, typename dst_t>
// TODO: pass in dimension info and calculate qparams for rowwise integer
// quantization
struct WeightRow {
+  // Constructor for no stochastic rounding
+  DEVICE_INLINE WeightRow(emb_t* row, cache_t* cache_row, int dim)
+      : row_(row),
+        cache_row_(cache_row),
+        dim_(dim),
+        stoc_rounding_state_(nullptr) {}
+
+  // Constructor for stochastic rounding
  DEVICE_INLINE WeightRow(
      emb_t* row,
      cache_t* cache_row,
      int dim,
-      StochasticRoundingRNGState* stoc_rounding_state)
-      : row_(row),
-        cache_row_(cache_row),
-        dim_(dim),
-        stoc_rounding_state_(stoc_rounding_state) {}
+      StochasticRoundingRNGState* stoc_rounding_state,
+      const at::PhiloxCudaState* stochastic_rounding_philox_args,
+      const uint64_t salt_value)
+      : row_(row), cache_row_(cache_row), dim_(dim) {
+    // Set the internal stoc_rounding_state_
+    stoc_rounding_state_ = stoc_rounding_state;
+
+    if constexpr (!std::is_same_v<emb_t, float>) {
+      if (stoc_rounding_state != nullptr) {
+        const auto stochastic_rounding_seeds =
+            at::cuda::philox::unpack(*stochastic_rounding_philox_args);
+
+        stochastic_rounding_init(
+            std::get<0>(stochastic_rounding_seeds) ^
+                std::get<1>(stochastic_rounding_seeds),
+            // The salt value should be different for every *run* and every
+            // *thread*.
+            salt_value,
+            stoc_rounding_state);
+      }
+    }
+  }
+
  emb_t* row_;
  cache_t* cache_row_;
  int dim_;
@@ -1466,30 +1492,6 @@ struct WeightRow {
      evict_cache(d, qparams);
    }
  }
-
-  DEVICE_INLINE void set_stochastic_rounding(
-      const bool stochastic_rounding,
-      const at::PhiloxCudaState stochastic_rounding_philox_args,
-      const uint64_t salt_value) {
-    if constexpr (!std::is_same_v<emb_t, float>) {
-      if (stochastic_rounding) {
-        StochasticRoundingRNGState state;
-        const auto stochastic_rounding_seeds =
-            at::cuda::philox::unpack(stochastic_rounding_philox_args);
-
-        stochastic_rounding_init(
-            std::get<0>(stochastic_rounding_seeds) ^
-                std::get<1>(stochastic_rounding_seeds),
-            // The salt value should be different for every *run* and every
-            // *thread*.
-            salt_value,
-            &state);
-
-        // Set the internal stoc_rounding_state_
-        stoc_rounding_state_ = &state;
-      }
-    }
-  }
};

template <typename emb_t, typename cache_t, typename dst_t, bool uses_cache>
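Condensed from the hunks in this commit (simplified, not a complete kernel), the call-site pattern changes from constructing a WeightRow and then calling set_stochastic_rounding() to declaring the RNG state in the surrounding kernel scope and passing everything to the constructor:

    // Before: the state was created (and a dangling pointer kept) inside
    // set_stochastic_rounding().
    auto weight_row = WeightRow<emb_t, cache_t, cache_t>(
        weights, cache_weights, D, nullptr);
    weight_row.set_stochastic_rounding(
        stochastic_rounding, stochastic_rounding_philox_args, salt_value);

    // After: the state lives in the kernel scope that outlives weight_row;
    // passing nullptr disables stochastic rounding.
    StochasticRoundingRNGState state;
    auto weight_row = WeightRow<emb_t, cache_t, cache_t>(
        weights,
        cache_weights,
        D,
        stochastic_rounding ? &state : nullptr,
        &stochastic_rounding_philox_args,
        salt_value);  // should differ for every run and every thread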
11 changes: 4 additions & 7 deletions fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu
@@ -116,15 +116,13 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel(
    if constexpr (std::is_same_v<emb_t, uint8_t>) {
      D_emb += kINT8QparamsBytes;
    }
+   StochasticRoundingRNGState state;
    auto weight_row = WeightRow<emb_t, cache_t, cache_t>(
        &weights[weights_offset_current + idx_current * D_emb + 0],
        &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0],
        D_current,
-       nullptr);
-
-   weight_row.set_stochastic_rounding(
-       stochastic_rounding,
-       stochastic_rounding_philox_args,
+       stochastic_rounding ? &state : nullptr,
+       &stochastic_rounding_philox_args,
        (blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
         threadIdx.x) *
                kWarpSize +
@@ -142,8 +140,7 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel(
    auto weight_row_emb = WeightRow<emb_t, cache_t, cache_t>(
        &weights[weights_offset_insert + idx_insert * D_emb + 0],
        nullptr,
-       D_insert,
-       nullptr);
+       D_insert);

    weight_row_emb.warp_copy_to_cache(
        &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0],
19 changes: 8 additions & 11 deletions fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu
@@ -78,6 +78,9 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel(
    BitonicSort<int64_t, int32_t, 1, Comparator<int64_t>>::sort(costs, slots);
    const int32_t sorted_slot = slots[0];
    const int64_t sorted_lru_cost = costs[0];
+   const auto stoc_rounding_salt = kWarpSize *
+       (blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
+        threadIdx.x);

    for (int32_t l = 0; l < min(SL, kWarpSize); ++l) {
      const int32_t insert_slot = shfl_sync(sorted_slot, l);
@@ -120,19 +123,14 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel(
        D_emb += kINT8QparamsBytes;
      }

+     StochasticRoundingRNGState state;
      auto weight_row = WeightRow<emb_t, cache_t, cache_t>(
          &weights[weights_offset_current + idx_current * D_emb + 0],
          &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0],
          D_current,
-         nullptr);
-
-     weight_row.set_stochastic_rounding(
-         stochastic_rounding,
-         stochastic_rounding_philox_args,
-         (blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
-          threadIdx.x) *
-                 kWarpSize +
-             l);
+         stochastic_rounding ? &state : nullptr,
+         &stochastic_rounding_philox_args,
+         stoc_rounding_salt + l);

      weight_row.warp_evict_cache(D_current, blockDim.x, threadIdx.x);
    }
@@ -145,8 +143,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel(
    auto weight_row_emb = WeightRow<emb_t, cache_t, cache_t>(
        &weights[weights_offset_insert + idx_insert * D_emb + 0],
        nullptr,
-       D_insert,
-       nullptr);
+       D_insert);

    weight_row_emb.warp_copy_to_cache(
        &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0],
8 changes: 3 additions & 5 deletions fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu
@@ -55,15 +55,13 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel(
    if constexpr (std::is_same_v<emb_t, uint8_t>) {
      D_emb += kINT8QparamsBytes;
    }
+   StochasticRoundingRNGState state;
    auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(
        &weights[weights_offset_current + idx_current * D_emb + 0],
        &lxu_cache_weights[b][0],
        D_current,
-       nullptr);
-
-   weight_row.set_stochastic_rounding(
-       stochastic_rounding,
-       stochastic_rounding_philox_args,
+       stochastic_rounding ? &state : nullptr,
+       &stochastic_rounding_philox_args,
        blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
            threadIdx.x);

@@ -179,7 +179,7 @@ __global__ __launch_bounds__(kMaxThreads) void reset_weight_momentum_kernel(

    auto weight_row_template =
        WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(
-           weights, cache_weights, D, nullptr);
+           weights, cache_weights, D);

    // reset momentum1
    const int32_t d = (i % chunk4s_per_row) * 4;
