Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Is it possible to change the roi pooling to 3D? #10

Open
lihaolin88 opened this issue Oct 12, 2023 · 0 comments
Open

Is it possible to change the roi pooling to 3D? #10

lihaolin88 opened this issue Oct 12, 2023 · 0 comments

Comments

@lihaolin88
Copy link

lihaolin88 commented Oct 12, 2023

Hello, I modified roi_pool_kernel.cu so that it accepts 3D input, but I'm not very familiar with CUDA code. Could anyone help me check whether I made any mistakes? I would really appreciate it!

My input shape is (B, C, H, W, D), and the RoI tensor has shape (num_of_roi, 7), where each row is ordered as (label, min_width, min_depth, min_height, max_width, max_depth, max_height).
And the output I expect is: (num_of_roi, C, pool_size, pool_size, pool_size)
(I don't know why GitHub breaks my code into multiple parts; sorry for the inconvenience.)

`
#include <torch/extension.h>
#include <THC/THCAtomics.cuh>
#include "cuda_helpers.h"

template
global void RoIPoolForward(
const int nthreads,
const T* input,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int depth,
const int pooled_height,
const int pooled_width,
const int pooled_depth,
const T* rois,
T* output,
int* argmax_data) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int pd = (index / (pooled_widthpooled_height)) % pooled_depth;
int c = (index / (pooled_width * pooled_height
pooled_depth)) % channels;
int n = index / (pooled_width * pooled_height* pooled_depth * channels);

const T* offset_rois = rois + n * 7;
int roi_batch_ind = offset_rois[0];
int roi_start_w = round(offset_rois[1] * (284/62));//spatial_scale);  //for spatial need to change
int roi_start_h = round(offset_rois[3] * (266/60));//spatial_scale);  //different side need different number
int roi_start_d = round(offset_rois[2] * (316/124));//spatial_scale);
int roi_end_w = round(offset_rois[4] * (284/62));//spatial_scale);
int roi_end_h = round(offset_rois[6] * (266/60));//spatial_scale);
int roi_end_d = round(offset_rois[5] * (316/124));//spatial_scale);

// Force malformed ROIs to be 1x1
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
int roi_depth = max(roi_end_d - roi_start_d + 1, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
T bin_size_d = static_cast<T>(roi_depth) / static_cast<T>(pooled_depth);

int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int dstart = static_cast<int>(floor(static_cast<T>(pd) * bin_size_d));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
int dend = static_cast<int>(ceil(static_cast<T>(pd + 1) * bin_size_d));

// Add roi offsets and clip to input boundaries
hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height);
wstart = min(max(wstart + roi_start_w, 0), width);
wend = min(max(wend + roi_start_w, 0), width);
dstart = min(max(dstart + roi_start_d, 0), depth);
dend = min(max(dend + roi_start_d, 0), depth);
bool is_empty = (hend <= hstart) || (wend <= wstart) || (dend <= dstart);

// Define an empty pooling region to be zero
T maxval = is_empty ? 0 : -FLT_MAX;
// If nothing is pooled, argmax = -1 causes nothing to be backprop'd
int maxidx = -1;
const T* offset_input =
    input + (roi_batch_ind * channels + c) * height * width * depth;
for (int h = hstart; h < hend; ++h) {
    for (int w = wstart; w < wend; ++w) {
      for (int d = dstart; d < dend; ++d) {

        int input_index = d*width*height + h * width + w; //h*depth*width + w*depth + d; //
        if (offset_input[input_index] > maxval) {
          maxval = offset_input[input_index];
          maxidx = input_index;
        }
    }
  }
}
output[index] = maxval;
argmax_data[index] = maxidx;

}
}

template
global void RoIPoolBackward(
const int nthreads,
const T* grad_output,
const int* argmax_data,
const int channels,
const int height,
const int width,
const int depth,
const int pooled_height,
const int pooled_width,
const int pooled_depth,
T* grad_input,
const T* rois,
const int n_stride,
const int c_stride,
const int h_stride,
const int w_stride,
const int d_stride) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int pd = (index / (pooled_widthpooled_height)) % pooled_depth;
int c = (index / (pooled_width * pooled_height
pooled_depth)) % channels;
int n = index / (pooled_width * pooled_height* pooled_depth * channels);
//int c = (index / pooled_width / pooled_height) % channels;
//int n = index / pooled_width / pooled_height / channels;

const T* offset_rois = rois + n * 7;
int roi_batch_ind = offset_rois[0];
T* grad_input_offset =
    grad_input + ((roi_batch_ind * channels + c) * height * width * depth);

int output_offset = n * n_stride + c * c_stride;
const int* argmax_data_offset =
    argmax_data + (n * channels + c) * pooled_height * pooled_width* pooled_depth;
int argmax = argmax_data_offset[pd * pooled_height * pooled_width + ph * pooled_width + pw];

if (argmax != -1) {
  atomicAdd(
      grad_input_offset + argmax,
      static_cast<T>(
          grad_output[output_offset + ph * h_stride + pw * w_stride + pd * d_stride]));
}

}
}

// Host entry point for the 3D max RoI pooling forward pass.
//
// Parameters:
//   input          CUDA tensor of shape (B, C, H, W, D)
//   rois           CUDA tensor of shape (num_rois, 7); the kernel reads it with
//                  the same scalar type as `input`
//   spatial_scale  forwarded to the kernel
//   output_size    edge length of the cubic pooled output
//
// Returns (output, argmax), both of shape
// (num_rois, C, output_size, output_size, output_size). `output` has the dtype
// of `input`; `argmax` is int32 and is filled by the kernel.
std::tuple<torch::Tensor, torch::Tensor> roi_pool_forward3d_cuda(const torch::Tensor& input,
    const torch::Tensor& rois,
    const float spatial_scale,
    const int output_size) {
  AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
  AT_ASSERTM(rois.is_cuda(), "rois must be a CUDA tensor");
  // The kernel reads 7 values per RoI and indexes the input as a 5-D volume;
  // reject anything else instead of reading out of bounds on the device.
  AT_ASSERTM(input.dim() == 5, "input must have shape (B, C, H, W, D)");
  AT_ASSERTM(
      rois.dim() == 2 && rois.size(1) == 7,
      "rois must have shape (num_rois, 7)");

  const int num_rois = rois.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);
  const int depth = input.size(4);

  const int pooling_height = output_size;
  const int pooling_width = output_size;
  const int pooling_depth = output_size;

  const auto total_size =
      num_rois * pooling_height * pooling_width * pooling_depth * channels;

  auto output = torch::empty(
      {num_rois, channels, pooling_height, pooling_width, pooling_depth},
      input.options());
  auto argmax = torch::zeros(
      {num_rois, channels, pooling_height, pooling_width, pooling_depth},
      input.options().dtype(torch::kInt));

  // Nothing to compute: skip the launch entirely.
  if (output.numel() == 0) {
    return std::make_tuple(output, argmax);
  }

  const dim3 grid(std::min((total_size + 512 - 1) / 512, 4 * 4096));
  const dim3 block(512);

  // Keep the contiguous copies alive in named locals for the whole launch.
  const auto input_contig = input.contiguous();
  const auto rois_contig = rois.contiguous();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "RoIPool_forward", [&] {
    RoIPoolForward<scalar_t><<<grid, block>>>(
        total_size,
        input_contig.data_ptr<scalar_t>(),
        spatial_scale,
        channels,
        height,
        width,
        depth,
        // Kernel parameter order is (pooled_height, pooled_width, pooled_depth).
        pooling_height,
        pooling_width,
        pooling_depth,
        rois_contig.data_ptr<scalar_t>(),
        output.data_ptr<scalar_t>(),
        argmax.data_ptr<int>());
  });

  // Surface launch failures instead of silently returning garbage.
  const cudaError_t launch_status = cudaGetLastError();
  AT_ASSERTM(launch_status == cudaSuccess, cudaGetErrorString(launch_status));

  return std::make_tuple(output, argmax);
}

// Host entry point for the 3D max RoI pooling backward pass.
//
// Parameters:
//   grad        CUDA tensor, gradient w.r.t. the pooled output, shape
//               (num_rois, C, pooled_h, pooled_w, pooled_d); may be
//               non-contiguous (its strides are forwarded to the kernel)
//   argmax      CUDA int32 tensor with the same shape, as produced by the
//               forward pass
//   input_size  1-D int32 CPU tensor holding (B, C, H, W, D) of the forward
//               input; it is read on the host through an accessor
//   rois        CUDA tensor of shape (num_rois, 7); the kernel reads it with
//               the same scalar type as `grad`
//
// Returns the gradient w.r.t. the forward input, shape (B, C, H, W, D).
torch::Tensor roi_pool_backward3d_cuda(const torch::Tensor& grad,
    const torch::Tensor& argmax,
    const torch::Tensor& input_size,
    const torch::Tensor& rois) {
  // Check if input tensors are CUDA tensors
  AT_ASSERTM(grad.is_cuda(), "grad must be a CUDA tensor");
  AT_ASSERTM(rois.is_cuda(), "rois must be a CUDA tensor");
  AT_ASSERTM(argmax.is_cuda(), "argmax must be a CUDA tensor");
  // input_size is dereferenced on the host below, so it must live in host
  // memory and hold exactly the five extents.
  AT_ASSERTM(
      !input_size.is_cuda() && input_size.dim() == 1 &&
          input_size.numel() == 5,
      "input_size must be a 1-D CPU tensor holding (B, C, H, W, D)");
  AT_ASSERTM(
      grad.dim() == 5 && argmax.dim() == 5,
      "grad and argmax must be 5-D tensors");
  AT_ASSERTM(
      rois.dim() == 2 && rois.size(1) == 7,
      "rois must have shape (num_rois, 7)");

  auto input_size_a = input_size.accessor<int, 1>();
  const int batch_size = input_size_a[0];
  const int channels = input_size_a[1];
  const int height = input_size_a[2];
  const int width = input_size_a[3];
  const int depth = input_size_a[4];

  const int num_rois = argmax.size(0);

  // argmax is laid out as (num_rois, C, pooled_h, pooled_w, pooled_d).
  const int pooling_height = argmax.size(2);
  const int pooling_width = argmax.size(3);
  const int pooling_depth = argmax.size(4);

  const auto total_size =
      num_rois * pooling_height * pooling_width * pooling_depth * channels;

  // Same axis order as the forward input, so the result can be used directly
  // as its gradient.
  auto grad_input =
      torch::zeros({batch_size, channels, height, width, depth}, grad.options());

  // handle possibly empty gradients
  if (grad.numel() == 0) {
    return grad_input;
  }

  const dim3 grid(std::min((total_size + 512 - 1) / 512, 4 * 4096));
  const dim3 block(512);

  // get stride values to ensure indexing into gradients is correct.
  const int n_stride = static_cast<int>(grad.stride(0));
  const int c_stride = static_cast<int>(grad.stride(1));
  const int h_stride = static_cast<int>(grad.stride(2));
  const int w_stride = static_cast<int>(grad.stride(3));
  const int d_stride = static_cast<int>(grad.stride(4));

  // Keep the contiguous copies alive in named locals for the whole launch.
  const auto argmax_contig = argmax.contiguous();
  const auto rois_contig = rois.contiguous();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "RoIPool_backward", [&] {
    RoIPoolBackward<scalar_t><<<grid, block>>>(
        static_cast<int>(grad.numel()),
        grad.data_ptr<scalar_t>(),
        argmax_contig.data_ptr<int>(),
        channels,
        height,
        width,
        depth,
        // Kernel parameter order is (pooled_height, pooled_width, pooled_depth).
        pooling_height,
        pooling_width,
        pooling_depth,
        grad_input.data_ptr<scalar_t>(),
        rois_contig.data_ptr<scalar_t>(),
        n_stride,
        c_stride,
        h_stride,
        w_stride,
        d_stride);
  });

  // Surface launch failures instead of silently returning garbage.
  const cudaError_t launch_status = cudaGetLastError();
  AT_ASSERTM(launch_status == cudaSuccess, cudaGetErrorString(launch_status));

  return grad_input;
}
``

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant