add AMD support
upsj committed Dec 19, 2024
1 parent 0ad691f commit d9c6fe5
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions common/cuda_hip/factorization/cholesky_kernels.cpp
@@ -58,14 +58,17 @@ __global__ __launch_bounds__(default_block_size) void mst_initialize_worklist(
     IndexType* __restrict__ worklist_edge_ids,
     IndexType* __restrict__ worklist_counter)
 {
+    using atomic_type = std::conditional_t<std::is_same_v<IndexType, int32>,
+                                           int32, unsigned long long>;
     const auto i = thread::get_thread_id_flat<IndexType>();
     if (i >= size) {
         return;
     }
     const auto row = rows[i];
     const auto col = cols[i];
     if (col < row) {
-        const auto out_i = atomic_add_relaxed(worklist_counter, 1);
+        const auto out_i = static_cast<IndexType>(atomicAdd(
+            reinterpret_cast<atomic_type*>(worklist_counter), atomic_type{1}));
         worklist_sources[out_i] = row;
         worklist_targets[out_i] = col;
         worklist_edge_ids[out_i] = i;
@@ -100,9 +100,12 @@ __device__ IndexType mst_find_relaxed(const IndexType* parents, IndexType node)
 template <typename IndexType>
 __device__ void guarded_atomic_min(IndexType* ptr, IndexType value)
 {
+    using atomic_type = std::conditional_t<std::is_same_v<IndexType, int32>,
+                                           int32, unsigned long long>;
     // only execute the atomic if we know that it might have an effect
     if (load_relaxed_local(ptr) > value) {
-        atomic_min_relaxed(ptr, value);
+        atomicMin(reinterpret_cast<atomic_type*>(ptr),
+                  static_cast<atomic_type>(value));
     }
 }
 
@@ -118,6 +124,8 @@ __global__ __launch_bounds__(default_block_size) void mst_find_minimum(
     IndexType* __restrict__ worklist_edge_ids,
     IndexType* __restrict__ worklist_counter)
 {
+    using atomic_type = std::conditional_t<std::is_same_v<IndexType, int32>,
+                                           int32, unsigned long long>;
     const auto i = thread::get_thread_id_flat<IndexType>();
     if (i >= size) {
         return;
@@ -128,7 +136,8 @@ __global__ __launch_bounds__(default_block_size) void mst_find_minimum(
     const auto source_rep = mst_find(parents, source);
     const auto target_rep = mst_find(parents, target);
     if (source_rep != target_rep) {
-        const auto out_i = atomic_add_relaxed(worklist_counter, 1);
+        const auto out_i = static_cast<IndexType>(atomicAdd(
+            reinterpret_cast<atomic_type*>(worklist_counter), atomic_type{1}));
         worklist_sources[out_i] = source_rep;
         worklist_targets[out_i] = target_rep;
         worklist_edge_ids[out_i] = edge_id;
@@ -149,6 +158,8 @@ __global__ __launch_bounds__(default_block_size) void mst_join_edges(
     IndexType* __restrict__ out_sources, IndexType* __restrict__ out_targets,
     IndexType* __restrict__ out_counter)
 {
+    using atomic_type = std::conditional_t<std::is_same_v<IndexType, int32>,
+                                           int32, unsigned long long>;
     const auto i = thread::get_thread_id_flat<IndexType>();
     if (i >= size) {
         return;
@@ -166,9 +177,6 @@ __global__ __launch_bounds__(default_block_size) void mst_join_edges(
     bool repeat = false;
     do {
         repeat = false;
-        using atomic_type =
-            std::conditional_t<std::is_same_v<IndexType, int32>, int32,
-                               unsigned long long>;
         auto old_parent =
             atomicCAS(reinterpret_cast<atomic_type*>(parents + old_rep),
                       static_cast<atomic_type>(old_rep),
@@ -180,7 +188,8 @@ __global__ __launch_bounds__(default_block_size) void mst_join_edges(
             repeat = true;
         }
     } while (repeat);
-    const auto out_i = atomic_add_relaxed(out_counter, 1);
+    const auto out_i = static_cast<IndexType>(atomicAdd(
+        reinterpret_cast<atomic_type*>(out_counter), atomic_type{1}));
     out_sources[out_i] = edge_sources[edge_id];
     out_targets[out_i] = edge_targets[edge_id];
 }
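
As background on the pattern this commit applies throughout the file: CUDA and HIP both ship atomicAdd and atomicMin overloads for int and unsigned long long, but their coverage of signed 64-bit types differs, so the kernels map IndexType to a supported atomic_type via std::conditional_t and reinterpret the pointer before each call. The sketch below isolates that technique in a self-contained program; the names atomic_add_one and compact_negative and the main() harness are illustrative assumptions, not part of the Ginkgo sources.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>
#include <type_traits>

// CUDA/HIP atomics: int and unsigned long long are portably supported, so a
// signed 64-bit counter is viewed as unsigned long long for the call. The
// counters here are nonnegative, so the unsigned view preserves addition.
template <typename IndexType>
__device__ IndexType atomic_add_one(IndexType* counter)
{
    using atomic_type =
        std::conditional_t<std::is_same_v<IndexType, std::int32_t>,
                           std::int32_t, unsigned long long>;
    return static_cast<IndexType>(atomicAdd(
        reinterpret_cast<atomic_type*>(counter), atomic_type{1}));
}

// Each thread whose input passes the filter claims a unique output slot,
// mirroring the worklist compaction in the kernels above.
template <typename IndexType>
__global__ void compact_negative(const IndexType* in, IndexType size,
                                 IndexType* out, IndexType* counter)
{
    const auto i =
        static_cast<IndexType>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (i >= size) {
        return;
    }
    if (in[i] < 0) {
        out[atomic_add_one(counter)] = in[i];
    }
}

int main()
{
    constexpr std::int64_t size = 1024;
    std::int64_t *in, *out, *counter;
    cudaMallocManaged(&in, size * sizeof(std::int64_t));
    cudaMallocManaged(&out, size * sizeof(std::int64_t));
    cudaMallocManaged(&counter, sizeof(std::int64_t));
    for (std::int64_t i = 0; i < size; i++) {
        in[i] = (i % 3 == 0) ? -i - 1 : i;  // every third entry is negative
    }
    *counter = 0;
    compact_negative<<<(size + 255) / 256, 256>>>(in, size, out, counter);
    cudaDeviceSynchronize();
    std::printf("collected %lld negative entries\n",
                static_cast<long long>(*counter));
    cudaFree(in);
    cudaFree(out);
    cudaFree(counter);
}

Since the values involved are nonnegative and unsigned addition wraps identically to signed arithmetic, the reinterpret_cast changes no observable behavior.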
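The guard in guarded_atomic_min deserves a note of its own: a plain load first skips the far more expensive atomic read-modify-write whenever it cannot have an effect. A hypothetical standalone version, with Ginkgo's load_relaxed_local helper replaced by a volatile read purely to keep the sketch self-contained:

template <typename IndexType>
__device__ void atomic_min_guarded(IndexType* ptr, IndexType value)
{
    using atomic_type =
        std::conditional_t<std::is_same_v<IndexType, std::int32_t>,
                           std::int32_t, unsigned long long>;
    // Cheap read first. *ptr only ever decreases under atomicMin, so an
    // observed value already <= value proves the atomic would be a no-op.
    if (*reinterpret_cast<const volatile IndexType*>(ptr) > value) {
        atomicMin(reinterpret_cast<atomic_type*>(ptr),
                  static_cast<atomic_type>(value));
    }
}

The race is benign in both directions: a stale large value merely triggers a redundant atomic, and a small value can only be observed after some thread actually stored it, at which point the stored minimum is already no larger than the observed value.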
