// matched_rbox_iou.cu
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// The code is based on
// https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/csrc/box_iou_rotated/
  17. #include "../rbox_iou/rbox_iou_utils.h"
  18. #include "paddle/extension.h"
  19. template <typename T>
  20. __global__ void
  21. matched_rbox_iou_cuda_kernel(const int rbox_num, const T *rbox1_data_ptr,
  22. const T *rbox2_data_ptr, T *output_data_ptr) {
  23. for (int tid = blockIdx.x * blockDim.x + threadIdx.x; tid < rbox_num;
  24. tid += blockDim.x * gridDim.x) {
  25. output_data_ptr[tid] =
  26. rbox_iou_single<T>(rbox1_data_ptr + tid * 5, rbox2_data_ptr + tid * 5);
  27. }
  28. }
// Abort with a PaddlePaddle error if tensor `x` does not reside on the GPU.
#define CHECK_INPUT_GPU(x)                                                     \
  PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
  31. std::vector<paddle::Tensor>
  32. MatchedRboxIouCUDAForward(const paddle::Tensor &rbox1,
  33. const paddle::Tensor &rbox2) {
  34. CHECK_INPUT_GPU(rbox1);
  35. CHECK_INPUT_GPU(rbox2);
  36. PD_CHECK(rbox1.shape()[0] == rbox2.shape()[0], "inputs must be same dim");
  37. auto rbox_num = rbox1.shape()[0];
  38. auto output = paddle::empty({rbox_num}, rbox1.dtype(), paddle::GPUPlace());
  39. const int thread_per_block = 512;
  40. const int block_per_grid = CeilDiv(rbox_num, thread_per_block);
  41. PD_DISPATCH_FLOATING_TYPES(
  42. rbox1.type(), "matched_rbox_iou_cuda_kernel", ([&] {
  43. matched_rbox_iou_cuda_kernel<
  44. data_t><<<block_per_grid, thread_per_block, 0, rbox1.stream()>>>(
  45. rbox_num, rbox1.data<data_t>(), rbox2.data<data_t>(),
  46. output.data<data_t>());
  47. }));
  48. return {output};
  49. }