rbox_iou.cu

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// The code is based on
// https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/csrc/box_iou_rotated/
#include "paddle/extension.h"
#include "rbox_iou_utils.h"

// 2D block with 32 * 16 = 512 threads per block
const int BLOCK_DIM_X = 32;
const int BLOCK_DIM_Y = 16;
template <typename T>
__global__ void rbox_iou_cuda_kernel(const int rbox1_num, const int rbox2_num,
                                     const T *rbox1_data_ptr,
                                     const T *rbox2_data_ptr,
                                     T *output_data_ptr) {
  // Each block computes a BLOCK_DIM_X x BLOCK_DIM_Y tile of the
  // rbox1_num x rbox2_num IoU matrix.
  // get row_start and col_start
  const int rbox1_block_idx = blockIdx.x * blockDim.x;
  const int rbox2_block_idx = blockIdx.y * blockDim.y;

  // number of valid boxes in this tile (may be partial at the edges)
  const int rbox1_thread_num = min(rbox1_num - rbox1_block_idx, blockDim.x);
  const int rbox2_thread_num = min(rbox2_num - rbox2_block_idx, blockDim.y);

  // stage this tile's boxes (5 values per rotated box) in shared memory
  __shared__ T block_boxes1[BLOCK_DIM_X * 5];
  __shared__ T block_boxes2[BLOCK_DIM_Y * 5];

  // It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
  if (threadIdx.x < rbox1_thread_num && threadIdx.y == 0) {
    block_boxes1[threadIdx.x * 5 + 0] =
        rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 0];
    block_boxes1[threadIdx.x * 5 + 1] =
        rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 1];
    block_boxes1[threadIdx.x * 5 + 2] =
        rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 2];
    block_boxes1[threadIdx.x * 5 + 3] =
        rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 3];
    block_boxes1[threadIdx.x * 5 + 4] =
        rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 4];
  }

  // rbox2_thread_num <= BLOCK_DIM_Y <= BLOCK_DIM_X, so threadIdx.x also covers
  // the rbox2 boxes; reuse the same condition as above (threadIdx.y == 0).
  if (threadIdx.x < rbox2_thread_num && threadIdx.y == 0) {
    block_boxes2[threadIdx.x * 5 + 0] =
        rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 0];
    block_boxes2[threadIdx.x * 5 + 1] =
        rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 1];
    block_boxes2[threadIdx.x * 5 + 2] =
        rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 2];
    block_boxes2[threadIdx.x * 5 + 3] =
        rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 3];
    block_boxes2[threadIdx.x * 5 + 4] =
        rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 4];
  }

  // wait until both tiles are fully loaded before computing IoUs
  __syncthreads();

  if (threadIdx.x < rbox1_thread_num && threadIdx.y < rbox2_thread_num) {
    int offset = (rbox1_block_idx + threadIdx.x) * rbox2_num + rbox2_block_idx +
                 threadIdx.y;
    output_data_ptr[offset] = rbox_iou_single<T>(
        block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
  }
}
#define CHECK_INPUT_GPU(x)                                                     \
  PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")

std::vector<paddle::Tensor> RboxIouCUDAForward(const paddle::Tensor &rbox1,
                                               const paddle::Tensor &rbox2) {
  CHECK_INPUT_GPU(rbox1);
  CHECK_INPUT_GPU(rbox2);

  auto rbox1_num = rbox1.shape()[0];
  auto rbox2_num = rbox2.shape()[0];

  // output[i][j] holds the IoU between rbox1[i] and rbox2[j]
  auto output =
      paddle::empty({rbox1_num, rbox2_num}, rbox1.dtype(), paddle::GPUPlace());

  // one thread per (rbox1, rbox2) pair, tiled into BLOCK_DIM_X x BLOCK_DIM_Y blocks
  const int blocks_x = CeilDiv(rbox1_num, BLOCK_DIM_X);
  const int blocks_y = CeilDiv(rbox2_num, BLOCK_DIM_Y);

  dim3 blocks(blocks_x, blocks_y);
  dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);

  PD_DISPATCH_FLOATING_TYPES(
      rbox1.type(), "rbox_iou_cuda_kernel", ([&] {
        rbox_iou_cuda_kernel<data_t><<<blocks, threads, 0, rbox1.stream()>>>(
            rbox1_num, rbox2_num, rbox1.data<data_t>(), rbox2.data<data_t>(),
            output.data<data_t>());
      }));

  return {output};
}
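
// ---------------------------------------------------------------------------
// Illustrative sketch only: in PaddleDetection the registration of this
// operator lives in a separate source file. The snippet below shows how a
// forward function such as RboxIouCUDAForward is typically exposed through
// Paddle's custom-operator macros from "paddle/extension.h". The op name and
// the input/output names here are assumptions, and a real registration would
// usually also set shape and dtype inference functions.
// ---------------------------------------------------------------------------
PD_BUILD_OP(rbox_iou)
    .Inputs({"RBox1", "RBox2"})
    .Outputs({"Output"})
    .SetKernelFn(PD_KERNEL(RboxIouCUDAForward));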