Skip to content

Commit

Permalink
Implement Cuda::ResampleGradient() #92
Browse files Browse the repository at this point in the history
  • Loading branch information
onurulgen committed Jan 22, 2024
1 parent cbdea7c commit d2bfbe1
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 3 deletions.
2 changes: 1 addition & 1 deletion niftyreg_build_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
387
388
210 changes: 208 additions & 2 deletions reg-lib/cuda/CudaResampling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
*/

#include "CudaResampling.hpp"
#include "_reg_common_cuda_kernels.cu"

/* *************************************************************** */
namespace NiftyReg::Cuda {
Expand Down Expand Up @@ -78,7 +79,7 @@ void ResampleImage(const nifti_image *floatingImage,
auto maskTexturePtr = Cuda::CreateTextureObject(maskCuda, activeVoxelNumber, cudaChannelFormatKindSigned, 1);
auto deformationFieldTexture = *deformationFieldTexturePtr;
auto maskTexture = *maskTexturePtr;
// Bind the real to voxel matrix to the texture
// Get the real to voxel matrix
const mat44& floatingMatrix = floatingImage->sform_code > 0 ? floatingImage->sto_ijk : floatingImage->qto_ijk;

for (int t = 0; t < warpedImage->nt * warpedImage->nu; t++) {
Expand Down Expand Up @@ -166,7 +167,7 @@ void GetImageGradient(const nifti_image *floatingImage,
auto deformationFieldTexturePtr = Cuda::CreateTextureObject(deformationFieldCuda, activeVoxelNumber, cudaChannelFormatKindFloat, 4);
auto floatingTexture = *floatingTexturePtr;
auto deformationFieldTexture = *deformationFieldTexturePtr;
// Bind the real to voxel matrix to the texture
// Get the real to voxel matrix
const mat44& floatingMatrix = floatingImage->sform_code > 0 ? floatingImage->sto_ijk : floatingImage->qto_ijk;

thrust::for_each_n(thrust::device, thrust::make_counting_iterator(0), activeVoxelNumber, [
Expand Down Expand Up @@ -232,5 +233,210 @@ void GetImageGradient(const nifti_image *floatingImage,
template void GetImageGradient<false>(const nifti_image*, const float*, const float4*, float4*, const size_t, const int, float, const int);
template void GetImageGradient<true>(const nifti_image*, const float*, const float4*, float4*, const size_t, const int, float, const int);
/* *************************************************************** */
template<bool is3d>
static float3 GetRealImageSpacing(const nifti_image *image) {
float3 spacing{};
float indexVoxel1[3]{}, indexVoxel2[3], realVoxel1[3], realVoxel2[3];
reg_mat44_mul(&image->sto_xyz, indexVoxel1, realVoxel1);

indexVoxel2[1] = indexVoxel2[2] = 0; indexVoxel2[0] = 1;
reg_mat44_mul(&image->sto_xyz, indexVoxel2, realVoxel2);
spacing.x = sqrtf(Square(realVoxel1[0] - realVoxel2[0]) + Square(realVoxel1[1] - realVoxel2[1]) + Square(realVoxel1[2] - realVoxel2[2]));

indexVoxel2[0] = indexVoxel2[2] = 0; indexVoxel2[1] = 1;
reg_mat44_mul(&image->sto_xyz, indexVoxel2, realVoxel2);
spacing.y = sqrtf(Square(realVoxel1[0] - realVoxel2[0]) + Square(realVoxel1[1] - realVoxel2[1]) + Square(realVoxel1[2] - realVoxel2[2]));

if constexpr (is3d) {
indexVoxel2[0] = indexVoxel2[1] = 0; indexVoxel2[2] = 1;
reg_mat44_mul(&image->sto_xyz, indexVoxel2, realVoxel2);
spacing.z = sqrtf(Square(realVoxel1[0] - realVoxel2[0]) + Square(realVoxel1[1] - realVoxel2[1]) + Square(realVoxel1[2] - realVoxel2[2]));
}

return spacing;
}
/* *************************************************************** */
template<bool is3d> struct Gradient { using Type = float3; };
template<> struct Gradient<false> { using Type = float2; };
/* *************************************************************** */
template<bool is3d>
void ResampleGradient(const nifti_image *floatingImage,
const float4 *floatingImageCuda,
const nifti_image *warpedImage,
float4 *warpedImageCuda,
const nifti_image *deformationField,
const float4 *deformationFieldCuda,
const int *maskCuda,
const size_t activeVoxelNumber,
const int interpolation,
const float paddingValue) {
if (interpolation != 1)
NR_FATAL_ERROR("Only linear interpolation is supported");

const size_t voxelNumber = NiftiImage::calcVoxelNumber(floatingImage, 3);
const int3 floatingDims = make_int3(floatingImage->nx, floatingImage->ny, floatingImage->nz);
const int3 defFieldDims = make_int3(deformationField->nx, deformationField->ny, deformationField->nz);
auto floatingTexturePtr = Cuda::CreateTextureObject(floatingImageCuda, voxelNumber, cudaChannelFormatKindFloat, 4);
auto deformationFieldTexturePtr = Cuda::CreateTextureObject(deformationFieldCuda, activeVoxelNumber, cudaChannelFormatKindFloat, 4);
auto maskTexturePtr = Cuda::CreateTextureObject(maskCuda, activeVoxelNumber, cudaChannelFormatKindSigned, 1);
auto floatingTexture = *floatingTexturePtr;
auto deformationFieldTexture = *deformationFieldTexturePtr;
auto maskTexture = *maskTexturePtr;

// Get the real to voxel matrix
const mat44& floatingMatrix = floatingImage->sform_code != 0 ? floatingImage->sto_ijk : floatingImage->qto_ijk;

// The spacing is computed if the sform is defined
const float3 realSpacing = warpedImage->sform_code > 0 ? GetRealImageSpacing<is3d>(warpedImage) :
make_float3(warpedImage->dx, warpedImage->dy, warpedImage->dz);

// Reorientation matrix is assessed in order to remove the rigid component
const mat33 reorient = nifti_mat33_inverse(nifti_mat33_polar(reg_mat44_to_mat33(&deformationField->sto_xyz)));

thrust::for_each_n(thrust::device, thrust::make_counting_iterator(0), activeVoxelNumber, [
warpedImageCuda, floatingTexture, deformationFieldTexture, maskTexture, floatingMatrix, floatingDims, defFieldDims, realSpacing, reorient, paddingValue
]__device__(const int index) {
// Get the real world deformation in the floating space
const int voxel = tex1Dfetch<int>(maskTexture, index);
const float4 realDeformation = tex1Dfetch<float4>(deformationFieldTexture, index);

// Get the voxel-based deformation in the floating space and compute the linear interpolation
int3 previous;
float xBasis[2], yBasis[2], zBasis[2];
TransformInterpolate<float, is3d>(floatingMatrix, realDeformation, previous, xBasis, yBasis, zBasis);

typename Gradient<is3d>::Type gradientValue{};
if constexpr (is3d) {
for (char c = 0; c < 2; c++) {
const int z = previous.z + c;
if (-1 < z && z < floatingDims.z) {
for (char b = 0; b < 2; b++) {
const int y = previous.y + b;
if (-1 < y && y < floatingDims.y) {
for (char a = 0; a < 2; a++) {
const int x = previous.x + a;
const float weight = xBasis[a] * yBasis[b] * zBasis[c];
if (-1 < x && x < floatingDims.x) {
const int floIndex = (z * floatingDims.y + y) * floatingDims.x + x;
const float3 intensity = make_float3(tex1Dfetch<float4>(floatingTexture, floIndex));
gradientValue = gradientValue + intensity * weight;
} else gradientValue = gradientValue + paddingValue * weight;
}
} else gradientValue = gradientValue + paddingValue * yBasis[b] * zBasis[c];
}
} else gradientValue = gradientValue + paddingValue * zBasis[c];
}
} else {
for (char b = 0; b < 2; b++) {
const int y = previous.y + b;
if (-1 < y && y < floatingDims.y) {
for (char a = 0; a < 2; a++) {
const int x = previous.x + a;
const float weight = xBasis[a] * yBasis[b];
if (-1 < x && x < floatingDims.x) {
const int floIndex = y * floatingDims.x + x;
const float2 intensity = make_float2(tex1Dfetch<float4>(floatingTexture, floIndex));
gradientValue = gradientValue + intensity * weight;
} else gradientValue = gradientValue + paddingValue * weight;
}
} else gradientValue = gradientValue + paddingValue * yBasis[b];
}
}

// Compute the Jacobian matrix
constexpr float basis[] = { 1.f, 0.f };
constexpr float deriv[] = { -1.f, 1.f };
auto [x, y, z] = reg_indexToDims_cuda<is3d>(voxel, defFieldDims);
mat33 jacMat{};
for (char c = 0; c < (is3d ? 2 : 1); c++) {
if constexpr (is3d) {
previous.z = z + c;
zBasis[0] = basis[c];
zBasis[1] = deriv[c];
// Boundary conditions along z - slidding
if (z == defFieldDims.z - 1) {
if (c == 1)
previous.z -= 2;
zBasis[0] = fabs(zBasis[0] - 1);
zBasis[1] *= -1;
}
}
for (char b = 0; b < 2; b++) {
previous.y = y + b;
yBasis[0] = basis[b];
yBasis[1] = deriv[b];
// Boundary conditions along y - slidding
if (y == defFieldDims.y - 1) {
if (b == 1)
previous.y -= 2;
yBasis[0] = fabs(yBasis[0] - 1);
yBasis[1] *= -1;
}
for (char a = 0; a < 2; a++) {
previous.x = x + a;
xBasis[0] = basis[a];
xBasis[1] = deriv[a];
// Boundary conditions along x - slidding
if (x == defFieldDims.x - 1) {
if (a == 1)
previous.x -= 2;
xBasis[0] = fabs(xBasis[0] - 1);
xBasis[1] *= -1;
}

// Compute the basis function values
const float3 weight = make_float3(xBasis[1] * yBasis[0] * (is3d ? zBasis[0] : 1),
xBasis[0] * yBasis[1] * (is3d ? zBasis[0] : 1),
is3d ? xBasis[0] * yBasis[0] * zBasis[1] : 0);

// Get the deformation field values
const int defIndex = ((is3d ? previous.z * defFieldDims.y : 0) + previous.y) * defFieldDims.x + previous.x;
const float4 defFieldValue = tex1Dfetch<float4>(deformationFieldTexture, defIndex);

// Symmetric difference to compute the derivatives
jacMat.m[0][0] += weight.x * defFieldValue.x;
jacMat.m[0][1] += weight.y * defFieldValue.x;
jacMat.m[1][0] += weight.x * defFieldValue.y;
jacMat.m[1][1] += weight.y * defFieldValue.y;
if constexpr (is3d) {
jacMat.m[0][2] += weight.z * defFieldValue.x;
jacMat.m[1][2] += weight.z * defFieldValue.y;
jacMat.m[2][0] += weight.x * defFieldValue.z;
jacMat.m[2][1] += weight.y * defFieldValue.z;
jacMat.m[2][2] += weight.z * defFieldValue.z;
}
}
}
}
// reorient and scale the Jacobian matrix
jacMat = reg_mat33_mul_cuda(reorient, jacMat);
jacMat.m[0][0] /= realSpacing.x;
jacMat.m[0][1] /= realSpacing.y;
jacMat.m[1][0] /= realSpacing.x;
jacMat.m[1][1] /= realSpacing.y;
if constexpr (is3d) {
jacMat.m[0][2] /= realSpacing.z;
jacMat.m[1][2] /= realSpacing.z;
jacMat.m[2][0] /= realSpacing.x;
jacMat.m[2][1] /= realSpacing.y;
jacMat.m[2][2] /= realSpacing.z;
}

// Modulate the gradient scalar values
float4 warpedValue{};
if constexpr (is3d) {
warpedValue.x = jacMat.m[0][0] * gradientValue.x + jacMat.m[0][1] * gradientValue.y + jacMat.m[0][2] * gradientValue.z;
warpedValue.y = jacMat.m[1][0] * gradientValue.x + jacMat.m[1][1] * gradientValue.y + jacMat.m[1][2] * gradientValue.z;
warpedValue.z = jacMat.m[2][0] * gradientValue.x + jacMat.m[2][1] * gradientValue.y + jacMat.m[2][2] * gradientValue.z;
} else {
warpedValue.x = jacMat.m[0][0] * gradientValue.x + jacMat.m[0][1] * gradientValue.y;
warpedValue.y = jacMat.m[1][0] * gradientValue.x + jacMat.m[1][1] * gradientValue.y;
}
warpedImageCuda[voxel] = warpedValue;
});
}
template void ResampleGradient<false>(const nifti_image*, const float4*, const nifti_image*, float4*, const nifti_image*, const float4*, const int*, const size_t, const int, const float);
template void ResampleGradient<true>(const nifti_image*, const float4*, const nifti_image*, float4*, const nifti_image*, const float4*, const int*, const size_t, const int, const float);
/* *************************************************************** */
} // namespace NiftyReg::Cuda
/* *************************************************************** */
12 changes: 12 additions & 0 deletions reg-lib/cuda/CudaResampling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,17 @@ void GetImageGradient(const nifti_image *floatingImage,
float paddingValue,
const int activeTimePoint);
/* *************************************************************** */
template<bool is3d>
void ResampleGradient(const nifti_image *floatingImage,
const float4 *floatingImageCuda,
const nifti_image *warpedImage,
float4 *warpedImageCuda,
const nifti_image *deformationField,
const float4 *deformationFieldCuda,
const int *maskCuda,
const size_t activeVoxelNumber,
const int interpolation,
const float paddingValue);
/* *************************************************************** */
} // namespace NiftyReg::Cuda
/* *************************************************************** */

0 comments on commit d2bfbe1

Please sign in to comment.