From d2bfbe193e4b21ebc9ed38766638cd50c379116c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Onur=20=C3=9Clgen?= Date: Mon, 22 Jan 2024 15:51:25 +0000 Subject: [PATCH] Implement Cuda::ResampleGradient() #92 --- niftyreg_build_version.txt | 2 +- reg-lib/cuda/CudaResampling.cu | 210 +++++++++++++++++++++++++++++++- reg-lib/cuda/CudaResampling.hpp | 12 ++ 3 files changed, 221 insertions(+), 3 deletions(-) diff --git a/niftyreg_build_version.txt b/niftyreg_build_version.txt index 32890dbd..2c60641d 100644 --- a/niftyreg_build_version.txt +++ b/niftyreg_build_version.txt @@ -1 +1 @@ -387 +388 diff --git a/reg-lib/cuda/CudaResampling.cu b/reg-lib/cuda/CudaResampling.cu index ee2deab5..6cde737d 100644 --- a/reg-lib/cuda/CudaResampling.cu +++ b/reg-lib/cuda/CudaResampling.cu @@ -11,6 +11,7 @@ */ #include "CudaResampling.hpp" +#include "_reg_common_cuda_kernels.cu" /* *************************************************************** */ namespace NiftyReg::Cuda { @@ -78,7 +79,7 @@ void ResampleImage(const nifti_image *floatingImage, auto maskTexturePtr = Cuda::CreateTextureObject(maskCuda, activeVoxelNumber, cudaChannelFormatKindSigned, 1); auto deformationFieldTexture = *deformationFieldTexturePtr; auto maskTexture = *maskTexturePtr; - // Bind the real to voxel matrix to the texture + // Get the real to voxel matrix const mat44& floatingMatrix = floatingImage->sform_code > 0 ? floatingImage->sto_ijk : floatingImage->qto_ijk; for (int t = 0; t < warpedImage->nt * warpedImage->nu; t++) { @@ -166,7 +167,7 @@ void GetImageGradient(const nifti_image *floatingImage, auto deformationFieldTexturePtr = Cuda::CreateTextureObject(deformationFieldCuda, activeVoxelNumber, cudaChannelFormatKindFloat, 4); auto floatingTexture = *floatingTexturePtr; auto deformationFieldTexture = *deformationFieldTexturePtr; - // Bind the real to voxel matrix to the texture + // Get the real to voxel matrix const mat44& floatingMatrix = floatingImage->sform_code > 0 ? floatingImage->sto_ijk : floatingImage->qto_ijk; thrust::for_each_n(thrust::device, thrust::make_counting_iterator(0), activeVoxelNumber, [ @@ -232,5 +233,210 @@ void GetImageGradient(const nifti_image *floatingImage, template void GetImageGradient(const nifti_image*, const float*, const float4*, float4*, const size_t, const int, float, const int); template void GetImageGradient(const nifti_image*, const float*, const float4*, float4*, const size_t, const int, float, const int); /* *************************************************************** */ +template +static float3 GetRealImageSpacing(const nifti_image *image) { + float3 spacing{}; + float indexVoxel1[3]{}, indexVoxel2[3], realVoxel1[3], realVoxel2[3]; + reg_mat44_mul(&image->sto_xyz, indexVoxel1, realVoxel1); + + indexVoxel2[1] = indexVoxel2[2] = 0; indexVoxel2[0] = 1; + reg_mat44_mul(&image->sto_xyz, indexVoxel2, realVoxel2); + spacing.x = sqrtf(Square(realVoxel1[0] - realVoxel2[0]) + Square(realVoxel1[1] - realVoxel2[1]) + Square(realVoxel1[2] - realVoxel2[2])); + + indexVoxel2[0] = indexVoxel2[2] = 0; indexVoxel2[1] = 1; + reg_mat44_mul(&image->sto_xyz, indexVoxel2, realVoxel2); + spacing.y = sqrtf(Square(realVoxel1[0] - realVoxel2[0]) + Square(realVoxel1[1] - realVoxel2[1]) + Square(realVoxel1[2] - realVoxel2[2])); + + if constexpr (is3d) { + indexVoxel2[0] = indexVoxel2[1] = 0; indexVoxel2[2] = 1; + reg_mat44_mul(&image->sto_xyz, indexVoxel2, realVoxel2); + spacing.z = sqrtf(Square(realVoxel1[0] - realVoxel2[0]) + Square(realVoxel1[1] - realVoxel2[1]) + Square(realVoxel1[2] - realVoxel2[2])); + } + + return spacing; +} +/* *************************************************************** */ +template struct Gradient { using Type = float3; }; +template<> struct Gradient { using Type = float2; }; +/* *************************************************************** */ +template +void ResampleGradient(const nifti_image *floatingImage, + const float4 *floatingImageCuda, + const nifti_image *warpedImage, + float4 *warpedImageCuda, + const nifti_image *deformationField, + const float4 *deformationFieldCuda, + const int *maskCuda, + const size_t activeVoxelNumber, + const int interpolation, + const float paddingValue) { + if (interpolation != 1) + NR_FATAL_ERROR("Only linear interpolation is supported"); + + const size_t voxelNumber = NiftiImage::calcVoxelNumber(floatingImage, 3); + const int3 floatingDims = make_int3(floatingImage->nx, floatingImage->ny, floatingImage->nz); + const int3 defFieldDims = make_int3(deformationField->nx, deformationField->ny, deformationField->nz); + auto floatingTexturePtr = Cuda::CreateTextureObject(floatingImageCuda, voxelNumber, cudaChannelFormatKindFloat, 4); + auto deformationFieldTexturePtr = Cuda::CreateTextureObject(deformationFieldCuda, activeVoxelNumber, cudaChannelFormatKindFloat, 4); + auto maskTexturePtr = Cuda::CreateTextureObject(maskCuda, activeVoxelNumber, cudaChannelFormatKindSigned, 1); + auto floatingTexture = *floatingTexturePtr; + auto deformationFieldTexture = *deformationFieldTexturePtr; + auto maskTexture = *maskTexturePtr; + + // Get the real to voxel matrix + const mat44& floatingMatrix = floatingImage->sform_code != 0 ? floatingImage->sto_ijk : floatingImage->qto_ijk; + + // The spacing is computed if the sform is defined + const float3 realSpacing = warpedImage->sform_code > 0 ? GetRealImageSpacing(warpedImage) : + make_float3(warpedImage->dx, warpedImage->dy, warpedImage->dz); + + // Reorientation matrix is assessed in order to remove the rigid component + const mat33 reorient = nifti_mat33_inverse(nifti_mat33_polar(reg_mat44_to_mat33(&deformationField->sto_xyz))); + + thrust::for_each_n(thrust::device, thrust::make_counting_iterator(0), activeVoxelNumber, [ + warpedImageCuda, floatingTexture, deformationFieldTexture, maskTexture, floatingMatrix, floatingDims, defFieldDims, realSpacing, reorient, paddingValue + ]__device__(const int index) { + // Get the real world deformation in the floating space + const int voxel = tex1Dfetch(maskTexture, index); + const float4 realDeformation = tex1Dfetch(deformationFieldTexture, index); + + // Get the voxel-based deformation in the floating space and compute the linear interpolation + int3 previous; + float xBasis[2], yBasis[2], zBasis[2]; + TransformInterpolate(floatingMatrix, realDeformation, previous, xBasis, yBasis, zBasis); + + typename Gradient::Type gradientValue{}; + if constexpr (is3d) { + for (char c = 0; c < 2; c++) { + const int z = previous.z + c; + if (-1 < z && z < floatingDims.z) { + for (char b = 0; b < 2; b++) { + const int y = previous.y + b; + if (-1 < y && y < floatingDims.y) { + for (char a = 0; a < 2; a++) { + const int x = previous.x + a; + const float weight = xBasis[a] * yBasis[b] * zBasis[c]; + if (-1 < x && x < floatingDims.x) { + const int floIndex = (z * floatingDims.y + y) * floatingDims.x + x; + const float3 intensity = make_float3(tex1Dfetch(floatingTexture, floIndex)); + gradientValue = gradientValue + intensity * weight; + } else gradientValue = gradientValue + paddingValue * weight; + } + } else gradientValue = gradientValue + paddingValue * yBasis[b] * zBasis[c]; + } + } else gradientValue = gradientValue + paddingValue * zBasis[c]; + } + } else { + for (char b = 0; b < 2; b++) { + const int y = previous.y + b; + if (-1 < y && y < floatingDims.y) { + for (char a = 0; a < 2; a++) { + const int x = previous.x + a; + const float weight = xBasis[a] * yBasis[b]; + if (-1 < x && x < floatingDims.x) { + const int floIndex = y * floatingDims.x + x; + const float2 intensity = make_float2(tex1Dfetch(floatingTexture, floIndex)); + gradientValue = gradientValue + intensity * weight; + } else gradientValue = gradientValue + paddingValue * weight; + } + } else gradientValue = gradientValue + paddingValue * yBasis[b]; + } + } + + // Compute the Jacobian matrix + constexpr float basis[] = { 1.f, 0.f }; + constexpr float deriv[] = { -1.f, 1.f }; + auto [x, y, z] = reg_indexToDims_cuda(voxel, defFieldDims); + mat33 jacMat{}; + for (char c = 0; c < (is3d ? 2 : 1); c++) { + if constexpr (is3d) { + previous.z = z + c; + zBasis[0] = basis[c]; + zBasis[1] = deriv[c]; + // Boundary conditions along z - slidding + if (z == defFieldDims.z - 1) { + if (c == 1) + previous.z -= 2; + zBasis[0] = fabs(zBasis[0] - 1); + zBasis[1] *= -1; + } + } + for (char b = 0; b < 2; b++) { + previous.y = y + b; + yBasis[0] = basis[b]; + yBasis[1] = deriv[b]; + // Boundary conditions along y - slidding + if (y == defFieldDims.y - 1) { + if (b == 1) + previous.y -= 2; + yBasis[0] = fabs(yBasis[0] - 1); + yBasis[1] *= -1; + } + for (char a = 0; a < 2; a++) { + previous.x = x + a; + xBasis[0] = basis[a]; + xBasis[1] = deriv[a]; + // Boundary conditions along x - slidding + if (x == defFieldDims.x - 1) { + if (a == 1) + previous.x -= 2; + xBasis[0] = fabs(xBasis[0] - 1); + xBasis[1] *= -1; + } + + // Compute the basis function values + const float3 weight = make_float3(xBasis[1] * yBasis[0] * (is3d ? zBasis[0] : 1), + xBasis[0] * yBasis[1] * (is3d ? zBasis[0] : 1), + is3d ? xBasis[0] * yBasis[0] * zBasis[1] : 0); + + // Get the deformation field values + const int defIndex = ((is3d ? previous.z * defFieldDims.y : 0) + previous.y) * defFieldDims.x + previous.x; + const float4 defFieldValue = tex1Dfetch(deformationFieldTexture, defIndex); + + // Symmetric difference to compute the derivatives + jacMat.m[0][0] += weight.x * defFieldValue.x; + jacMat.m[0][1] += weight.y * defFieldValue.x; + jacMat.m[1][0] += weight.x * defFieldValue.y; + jacMat.m[1][1] += weight.y * defFieldValue.y; + if constexpr (is3d) { + jacMat.m[0][2] += weight.z * defFieldValue.x; + jacMat.m[1][2] += weight.z * defFieldValue.y; + jacMat.m[2][0] += weight.x * defFieldValue.z; + jacMat.m[2][1] += weight.y * defFieldValue.z; + jacMat.m[2][2] += weight.z * defFieldValue.z; + } + } + } + } + // reorient and scale the Jacobian matrix + jacMat = reg_mat33_mul_cuda(reorient, jacMat); + jacMat.m[0][0] /= realSpacing.x; + jacMat.m[0][1] /= realSpacing.y; + jacMat.m[1][0] /= realSpacing.x; + jacMat.m[1][1] /= realSpacing.y; + if constexpr (is3d) { + jacMat.m[0][2] /= realSpacing.z; + jacMat.m[1][2] /= realSpacing.z; + jacMat.m[2][0] /= realSpacing.x; + jacMat.m[2][1] /= realSpacing.y; + jacMat.m[2][2] /= realSpacing.z; + } + + // Modulate the gradient scalar values + float4 warpedValue{}; + if constexpr (is3d) { + warpedValue.x = jacMat.m[0][0] * gradientValue.x + jacMat.m[0][1] * gradientValue.y + jacMat.m[0][2] * gradientValue.z; + warpedValue.y = jacMat.m[1][0] * gradientValue.x + jacMat.m[1][1] * gradientValue.y + jacMat.m[1][2] * gradientValue.z; + warpedValue.z = jacMat.m[2][0] * gradientValue.x + jacMat.m[2][1] * gradientValue.y + jacMat.m[2][2] * gradientValue.z; + } else { + warpedValue.x = jacMat.m[0][0] * gradientValue.x + jacMat.m[0][1] * gradientValue.y; + warpedValue.y = jacMat.m[1][0] * gradientValue.x + jacMat.m[1][1] * gradientValue.y; + } + warpedImageCuda[voxel] = warpedValue; + }); +} +template void ResampleGradient(const nifti_image*, const float4*, const nifti_image*, float4*, const nifti_image*, const float4*, const int*, const size_t, const int, const float); +template void ResampleGradient(const nifti_image*, const float4*, const nifti_image*, float4*, const nifti_image*, const float4*, const int*, const size_t, const int, const float); +/* *************************************************************** */ } // namespace NiftyReg::Cuda /* *************************************************************** */ diff --git a/reg-lib/cuda/CudaResampling.hpp b/reg-lib/cuda/CudaResampling.hpp index 1366ccc7..7f6bbac8 100644 --- a/reg-lib/cuda/CudaResampling.hpp +++ b/reg-lib/cuda/CudaResampling.hpp @@ -38,5 +38,17 @@ void GetImageGradient(const nifti_image *floatingImage, float paddingValue, const int activeTimePoint); /* *************************************************************** */ +template +void ResampleGradient(const nifti_image *floatingImage, + const float4 *floatingImageCuda, + const nifti_image *warpedImage, + float4 *warpedImageCuda, + const nifti_image *deformationField, + const float4 *deformationFieldCuda, + const int *maskCuda, + const size_t activeVoxelNumber, + const int interpolation, + const float paddingValue); +/* *************************************************************** */ } // namespace NiftyReg::Cuda /* *************************************************************** */