Skip to content

Commit

Permalink
Implement CudaCompute::UpdateVelocityField() #92
Browse files Browse the repository at this point in the history
  • Loading branch information
onurulgen committed Jan 16, 2024
1 parent 540f10b commit b34de37
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 6 deletions.
2 changes: 1 addition & 1 deletion niftyreg_build_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
384
385
25 changes: 20 additions & 5 deletions reg-lib/cuda/CudaCompute.cu
Original file line number Diff line number Diff line change
Expand Up @@ -294,12 +294,27 @@ void CudaCompute::ExponentiateGradient(Content& conBwIn) {
dynamic_cast<CudaDefContent&>(con).UpdateVoxelBasedMeasureGradient();
}
/* *************************************************************** */
Cuda::UniquePtr<float4> CudaCompute::ScaleGradient(const float4 *transGradCuda, const size_t voxelNumber, const float scale) {
float4 *scaledGradient;
Cuda::Allocate(&scaledGradient, voxelNumber);
Cuda::MultiplyValue(voxelNumber, transGradCuda, scaledGradient, scale);
return Cuda::UniquePtr<float4>(scaledGradient);
}
/* *************************************************************** */
void CudaCompute::UpdateVelocityField(float scale, bool optimiseX, bool optimiseY, bool optimiseZ) {
// TODO Implement this for CUDA
// Use CPU temporarily
Compute::UpdateVelocityField(scale, optimiseX, optimiseY, optimiseZ);
// Transfer the data back to the CUDA device
dynamic_cast<CudaF3dContent&>(con).UpdateControlPointGrid();
if (!optimiseX && !optimiseY && !optimiseZ) return;

CudaF3dContent& con = dynamic_cast<CudaF3dContent&>(this->con);
const nifti_image *controlPointGrid = con.F3dContent::GetControlPointGrid();
const size_t voxelNumber = NiftiImage::calcVoxelNumber(controlPointGrid, 3);
auto scaledGradientCudaPtr = ScaleGradient(con.GetTransformationGradientCuda(), voxelNumber, scale);

// Reset the gradient along the axes if appropriate
if (controlPointGrid->nu < 3) optimiseZ = true;
Cuda::SetGradientToZero(scaledGradientCudaPtr.get(), voxelNumber, !optimiseX, !optimiseY, !optimiseZ);

// Update the velocity field
Cuda::AddImages(controlPointGrid, con.GetControlPointGridCuda(), scaledGradientCudaPtr.get());
}
/* *************************************************************** */
void CudaCompute::BchUpdate(float scale, int bchUpdateValue) {
Expand Down
1 change: 1 addition & 0 deletions reg-lib/cuda/CudaCompute.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ class CudaCompute: public Compute {

private:
void ConvolveImage(const nifti_image*, float4*);
Cuda::UniquePtr<float4> ScaleGradient(const float4*, const size_t, const float);
};
39 changes: 39 additions & 0 deletions reg-lib/cuda/CudaTools.cu
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,15 @@ void MultiplyValue(const size_t count, float4 *arrayCuda, const float multiplier
});
}
/* *************************************************************** */
void MultiplyValue(const size_t count, const float4 *arrayCuda, float4 *arrayOutCuda, const float multiplier) {
auto arrayTexturePtr = Cuda::CreateTextureObject(arrayCuda, count, cudaChannelFormatKindFloat, 4);
auto arrayTexture = *arrayTexturePtr;
thrust::for_each_n(thrust::device, thrust::make_counting_iterator(0), count, [=]__device__(const int index) {
float4 val = tex1Dfetch<float4>(arrayTexture, index);
arrayOutCuda[index] = val * multiplier;
});
}
/* *************************************************************** */
float SumReduction(float *arrayCuda, const size_t size) {
thrust::device_ptr<float> dptr(arrayCuda);
return thrust::reduce(thrust::device, dptr, dptr + size, 0.f, thrust::plus<float>());
Expand Down Expand Up @@ -367,5 +376,35 @@ float GetMaxValue(const nifti_image *img, const float4 *imgCuda, const int timeP
return GetMinMaxValue<false>(img, imgCuda, timePoint);
}
/* *************************************************************** */
template<bool xAxis, bool yAxis, bool zAxis>
void SetGradientToZero(float4 *gradCuda, const size_t voxelNumber) {
auto gradTexturePtr = Cuda::CreateTextureObject(gradCuda, voxelNumber, cudaChannelFormatKindFloat, 4);
auto gradTexture = *gradTexturePtr;
thrust::for_each_n(thrust::device, thrust::make_counting_iterator(0), voxelNumber, [gradCuda, gradTexture]__device__(const int index) {
if constexpr (xAxis && yAxis && zAxis) {
gradCuda[index] = make_float4(0.f, 0.f, 0.f, 0.f);
} else {
float4 val = tex1Dfetch<float4>(gradTexture, index);
if constexpr (xAxis) val.x = 0;
if constexpr (yAxis) val.y = 0;
if constexpr (zAxis) val.z = 0;
gradCuda[index] = val;
}
});
}
/* *************************************************************** */
void SetGradientToZero(float4 *gradCuda, const size_t voxelNumber, const bool xAxis, const bool yAxis, const bool zAxis) {
if (!xAxis && !yAxis && !zAxis) return;
decltype(SetGradientToZero<true, true, true>) *setGradientToZero;
if (xAxis && yAxis && zAxis) setGradientToZero = SetGradientToZero<true, true, true>;
else if (xAxis && yAxis) setGradientToZero = SetGradientToZero<true, true, false>;
else if (xAxis && zAxis) setGradientToZero = SetGradientToZero<true, false, true>;
else if (yAxis && zAxis) setGradientToZero = SetGradientToZero<false, true, true>;
else if (xAxis) setGradientToZero = SetGradientToZero<true, false, false>;
else if (yAxis) setGradientToZero = SetGradientToZero<false, true, false>;
else if (zAxis) setGradientToZero = SetGradientToZero<false, false, true>;
setGradientToZero(gradCuda, voxelNumber);
}
/* *************************************************************** */
} // namespace NiftyReg::Cuda
/* *************************************************************** */
8 changes: 8 additions & 0 deletions reg-lib/cuda/CudaTools.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ void AddValue(const size_t count, float4 *arrayCuda, const float value);
/* *************************************************************** */
void MultiplyValue(const size_t count, float4 *arrayCuda, const float value);
/* *************************************************************** */
void MultiplyValue(const size_t count, const float4 *arrayCuda, float4 *arrayOutCuda, const float value);
/* *************************************************************** */
float SumReduction(float *arrayCuda, const size_t size);
/* *************************************************************** */
float MaxReduction(float *arrayCuda, const size_t size);
Expand All @@ -61,5 +63,11 @@ float GetMinValue(const nifti_image *img, const float4 *imgCuda, const int timeP
/* *************************************************************** */
float GetMaxValue(const nifti_image *img, const float4 *imgCuda, const int timePoint = -1);
/* *************************************************************** */
void SetGradientToZero(float4 *gradCuda,
const size_t voxelNumber,
const bool xAxis,
const bool yAxis,
const bool zAxis);
/* *************************************************************** */
} // namespace NiftyReg::Cuda
/* *************************************************************** */

0 comments on commit b34de37

Please sign in to comment.