roi_align_rotated.cu
// This code is referenced from:
// https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/common/cuda/roi_align_rotated_cuda_kernel.cuh
#include <cassert>
#include <cmath>
#include <vector>

#include <cuda.h>

#include "paddle/extension.h"

#define CUDA_1D_KERNEL_LOOP(i, n)                                \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);   \
       i += blockDim.x * gridDim.x)

#define THREADS_PER_BLOCK 512

inline int GET_BLOCKS(const int N) {
  int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
  int max_block_num = 4096;
  return min(optimal_block_num, max_block_num);
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
static __inline__ __device__ double atomicAdd(double *address, double val) {
  unsigned long long int *address_as_ull = (unsigned long long int *)address;
  unsigned long long int old = *address_as_ull, assumed;
  if (val == 0.0)
    return __longlong_as_double(old);
  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
  return __longlong_as_double(old);
}
#endif
template <typename T>
__device__ T bilinear_interpolate(const T *input, const int height,
                                  const int width, T y, T x,
                                  const int index /* index for debug only*/) {
  // deal with cases that inverse elements are out of feature map boundary
  if (y < -1.0 || y > height || x < -1.0 || x > width)
    return 0;

  if (y <= 0)
    y = 0;
  if (x <= 0)
    x = 0;

  int y_low = (int)y;
  int x_low = (int)x;
  int y_high;
  int x_high;

  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = (T)y_low;
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = (T)x_low;
  } else {
    x_high = x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;
  // do bilinear interpolation
  T v1 = input[y_low * width + x_low];
  T v2 = input[y_low * width + x_high];
  T v3 = input[y_high * width + x_low];
  T v4 = input[y_high * width + x_high];
  T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

  return val;
}
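
// Worked example of the interpolation weights above: for a sampling point
// (y, x) = (1.3, 2.7) inside the feature map, the four neighbours are
// (y_low, x_low) = (1, 2) and (y_high, x_high) = (2, 3), with
// ly = 0.3, lx = 0.7, hy = 0.7, hx = 0.3, giving
//   w1 = hy * hx = 0.21, w2 = hy * lx = 0.49,
//   w3 = ly * hx = 0.09, w4 = ly * lx = 0.21,
// which sum to 1. bilinear_interpolate_gradient below computes the same
// weights so the backward pass can scatter the incoming gradient onto the
// same four neighbouring cells.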
template <typename T>
__device__ void
bilinear_interpolate_gradient(const int height, const int width, T y, T x,
                              T &w1, T &w2, T &w3, T &w4, int &x_low,
                              int &x_high, int &y_low, int &y_high,
                              const int index /* index for debug only*/) {
  // deal with cases that inverse elements are out of feature map boundary
  if (y < -1.0 || y > height || x < -1.0 || x > width) {
    // empty
    w1 = w2 = w3 = w4 = 0.;
    x_low = x_high = y_low = y_high = -1;
    return;
  }

  if (y <= 0)
    y = 0;
  if (x <= 0)
    x = 0;

  y_low = (int)y;
  x_low = (int)x;

  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = (T)y_low;
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = (T)x_low;
  } else {
    x_high = x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;

  // reference in forward
  // T v1 = input[y_low * width + x_low];
  // T v2 = input[y_low * width + x_high];
  // T v3 = input[y_high * width + x_low];
  // T v4 = input[y_high * width + x_high];
  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

  w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

  return;
}
/*** Forward ***/
template <typename scalar_t>
__global__ void roi_align_rotated_cuda_forward_kernel(
    const int nthreads, const scalar_t *bottom_data,
    const scalar_t *bottom_rois, const scalar_t spatial_scale,
    const int sample_num, const bool aligned, const bool clockwise,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, scalar_t *top_data) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
    int roi_batch_ind = offset_bottom_rois[0];
    // Do not use rounding; this implementation detail is critical
    scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
    scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
    scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
    scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
    scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
    // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
    scalar_t theta = offset_bottom_rois[5];
    if (clockwise) {
      theta = -theta; // If clockwise, the angle needs to be reversed.
    }
    if (!aligned) { // for backward-compatibility only
      // Force malformed ROIs to be 1x1
      roi_width = max(roi_width, (scalar_t)1.);
      roi_height = max(roi_height, (scalar_t)1.);
    }
    scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
                          static_cast<scalar_t>(pooled_height);
    scalar_t bin_size_w =
        static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);

    const scalar_t *offset_bottom_data =
        bottom_data + (roi_batch_ind * channels + c) * height * width;

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sample_num > 0)
                             ? sample_num
                             : ceilf(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w =
        (sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);

    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
    // Appropriate translation needs to be applied after.
    scalar_t roi_start_h = -roi_height / 2.0;
    scalar_t roi_start_w = -roi_width / 2.0;
    scalar_t cosscalar_theta = cos(theta);
    scalar_t sinscalar_theta = sin(theta);

    // We do average (integral) pooling inside a bin
    const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4

    scalar_t output_val = 0.;
    for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1
      const scalar_t yy =
          roi_start_h + ph * bin_size_h +
          static_cast<scalar_t>(iy + .5f) * bin_size_h /
              static_cast<scalar_t>(roi_bin_grid_h); // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const scalar_t xx = roi_start_w + pw * bin_size_w +
                            static_cast<scalar_t>(ix + .5f) * bin_size_w /
                                static_cast<scalar_t>(roi_bin_grid_w);

        // Rotate by theta (counterclockwise) around the center and translate
        scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h;
        scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w;

        scalar_t val = bilinear_interpolate<scalar_t>(
            offset_bottom_data, height, width, y, x, index);
        output_val += val;
      }
    }
    output_val /= count;

    top_data[index] = output_val;
  }
}
/*** Backward ***/
template <typename scalar_t>
__global__ void roi_align_rotated_backward_cuda_kernel(
    const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
    const scalar_t spatial_scale, const int sample_num, const bool aligned,
    const bool clockwise, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    scalar_t *bottom_diff) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
    int roi_batch_ind = offset_bottom_rois[0];

    // Do not round
    scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
    scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
    scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
    scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
    scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
    // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
    scalar_t theta = offset_bottom_rois[5];
    if (clockwise) {
      theta = -theta; // If clockwise, the angle needs to be reversed.
    }
    if (!aligned) { // for backward-compatibility only
      // Force malformed ROIs to be 1x1
      roi_width = max(roi_width, (scalar_t)1.);
      roi_height = max(roi_height, (scalar_t)1.);
    }
    scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
                          static_cast<scalar_t>(pooled_height);
    scalar_t bin_size_w =
        static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);

    scalar_t *offset_bottom_diff =
        bottom_diff + (roi_batch_ind * channels + c) * height * width;

    int top_offset = (n * channels + c) * pooled_height * pooled_width;
    const scalar_t *offset_top_diff = top_diff + top_offset;
    const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sample_num > 0)
                             ? sample_num
                             : ceilf(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w =
        (sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);

    // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
    // Appropriate translation needs to be applied after.
    scalar_t roi_start_h = -roi_height / 2.0;
    scalar_t roi_start_w = -roi_width / 2.0;
    scalar_t cosTheta = cos(theta);
    scalar_t sinTheta = sin(theta);

    // We do average (integral) pooling inside a bin
    const scalar_t count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4

    for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1
      const scalar_t yy =
          roi_start_h + ph * bin_size_h +
          static_cast<scalar_t>(iy + .5f) * bin_size_h /
              static_cast<scalar_t>(roi_bin_grid_h); // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const scalar_t xx = roi_start_w + pw * bin_size_w +
                            static_cast<scalar_t>(ix + .5f) * bin_size_w /
                                static_cast<scalar_t>(roi_bin_grid_w);

        // Rotate by theta around the center and translate
        scalar_t y = yy * cosTheta - xx * sinTheta + roi_center_h;
        scalar_t x = yy * sinTheta + xx * cosTheta + roi_center_w;

        scalar_t w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;

        bilinear_interpolate_gradient<scalar_t>(height, width, y, x, w1, w2,
                                                w3, w4, x_low, x_high, y_low,
                                                y_high, index);

        scalar_t g1 = top_diff_this_bin * w1 / count;
        scalar_t g2 = top_diff_this_bin * w2 / count;
        scalar_t g3 = top_diff_this_bin * w3 / count;
        scalar_t g4 = top_diff_this_bin * w4 / count;

        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          atomicAdd(offset_bottom_diff + y_low * width + x_low, g1);
          atomicAdd(offset_bottom_diff + y_low * width + x_high, g2);
          atomicAdd(offset_bottom_diff + y_high * width + x_low, g3);
          atomicAdd(offset_bottom_diff + y_high * width + x_high, g4);
        } // if
      }   // ix
    }     // iy
  }       // CUDA_1D_KERNEL_LOOP
} // RoIAlignBackward
std::vector<paddle::Tensor>
RoIAlignRotatedCUDAForward(const paddle::Tensor &input,
                           const paddle::Tensor &rois, int aligned_height,
                           int aligned_width, float spatial_scale,
                           int sampling_ratio, bool aligned, bool clockwise) {
  auto num_rois = rois.shape()[0];

  auto channels = input.shape()[1];
  auto height = input.shape()[2];
  auto width = input.shape()[3];

  auto output =
      paddle::empty({num_rois, channels, aligned_height, aligned_width},
                    input.type(), paddle::GPUPlace());
  auto output_size = output.numel();

  PD_DISPATCH_FLOATING_TYPES(
      input.type(), "roi_align_rotated_cuda_forward_kernel", ([&] {
        roi_align_rotated_cuda_forward_kernel<
            data_t><<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
            output_size, input.data<data_t>(), rois.data<data_t>(),
            static_cast<data_t>(spatial_scale), sampling_ratio, aligned,
            clockwise, channels, height, width, aligned_height, aligned_width,
            output.data<data_t>());
      }));

  return {output};
}
std::vector<paddle::Tensor> RoIAlignRotatedCUDABackward(
    const paddle::Tensor &input, const paddle::Tensor &rois,
    const paddle::Tensor &grad_output, int aligned_height, int aligned_width,
    float spatial_scale, int sampling_ratio, bool aligned, bool clockwise) {
  auto num_rois = rois.shape()[0];

  auto batch_size = input.shape()[0];
  auto channels = input.shape()[1];
  auto height = input.shape()[2];
  auto width = input.shape()[3];

  auto grad_input = paddle::full({batch_size, channels, height, width}, 0.0,
                                 input.type(), paddle::GPUPlace());

  const int output_size = num_rois * aligned_height * aligned_width * channels;

  PD_DISPATCH_FLOATING_TYPES(
      grad_output.type(), "roi_align_rotated_backward_cuda_kernel", ([&] {
        roi_align_rotated_backward_cuda_kernel<
            data_t><<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
            output_size, grad_output.data<data_t>(), rois.data<data_t>(),
            spatial_scale, sampling_ratio, aligned, clockwise, channels, height,
            width, aligned_height, aligned_width, grad_input.data<data_t>());
      }));

  return {grad_input};
}
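
// Registration sketch: the two host wrappers above are typically exposed to
// Python through Paddle's custom-operator API, roughly as shown below. This
// is a minimal, illustrative sketch only: the op name, input/output names,
// and attribute list are assumptions, and the real registration (including
// shape/dtype inference functions, omitted here) normally lives in a
// companion source file.

PD_BUILD_OP(roi_align_rotated)
    .Inputs({"Input", "Rois"})
    .Outputs({"Output"})
    .Attrs({"aligned_height: int", "aligned_width: int",
            "spatial_scale: float", "sampling_ratio: int", "aligned: bool",
            "clockwise: bool"})
    .SetKernelFn(PD_KERNEL(RoIAlignRotatedCUDAForward));

PD_BUILD_GRAD_OP(roi_align_rotated)
    .Inputs({"Input", "Rois", paddle::Grad("Output")})
    .Outputs({paddle::Grad("Input")})
    .Attrs({"aligned_height: int", "aligned_width: int",
            "spatial_scale: float", "sampling_ratio: int", "aligned: bool",
            "clockwise: bool"})
    .SetKernelFn(PD_KERNEL(RoIAlignRotatedCUDABackward));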