keypoint_detector.cpp 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  #include <algorithm>
  #include <chrono>
  #include <cstdint>
  #include <cstdio>
  // for setprecision
  #include <iomanip>
  #include <memory>
  #include <sstream>
  #include <utility>
  #include <vector>
  #include "keypoint_detector.h"
  19. namespace PaddleDetection {
  20. // Visualiztion MaskDetector results
  21. cv::Mat VisualizeKptsResult(const cv::Mat& img,
  22. const std::vector<KeyPointResult>& results,
  23. const std::vector<int>& colormap,
  24. float threshold) {
  25. const int edge[][2] = {{0, 1},
  26. {0, 2},
  27. {1, 3},
  28. {2, 4},
  29. {3, 5},
  30. {4, 6},
  31. {5, 7},
  32. {6, 8},
  33. {7, 9},
  34. {8, 10},
  35. {5, 11},
  36. {6, 12},
  37. {11, 13},
  38. {12, 14},
  39. {13, 15},
  40. {14, 16},
  41. {11, 12}};
  42. cv::Mat vis_img = img.clone();
  43. for (int batchid = 0; batchid < results.size(); batchid++) {
  44. for (int i = 0; i < results[batchid].num_joints; i++) {
  45. if (results[batchid].keypoints[i * 3] > threshold) {
  46. int x_coord = int(results[batchid].keypoints[i * 3 + 1]);
  47. int y_coord = int(results[batchid].keypoints[i * 3 + 2]);
  48. cv::circle(vis_img,
  49. cv::Point2d(x_coord, y_coord),
  50. 1,
  51. cv::Scalar(0, 0, 255),
  52. 2);
  53. }
  54. }
  55. for (int i = 0; i < results[batchid].num_joints; i++) {
  56. if (results[batchid].keypoints[edge[i][0] * 3] > threshold &&
  57. results[batchid].keypoints[edge[i][1] * 3] > threshold) {
  58. int x_start = int(results[batchid].keypoints[edge[i][0] * 3 + 1]);
  59. int y_start = int(results[batchid].keypoints[edge[i][0] * 3 + 2]);
  60. int x_end = int(results[batchid].keypoints[edge[i][1] * 3 + 1]);
  61. int y_end = int(results[batchid].keypoints[edge[i][1] * 3 + 2]);
  62. cv::line(vis_img,
  63. cv::Point2d(x_start, y_start),
  64. cv::Point2d(x_end, y_end),
  65. colormap[i],
  66. 1);
  67. }
  68. }
  69. }
  70. return vis_img;
  71. }
  72. void KeyPointDetector::Postprocess(std::vector<float>& output,
  73. std::vector<int>& output_shape,
  74. std::vector<int>& idxout,
  75. std::vector<int>& idx_shape,
  76. std::vector<KeyPointResult>* result,
  77. std::vector<std::vector<float>>& center_bs,
  78. std::vector<std::vector<float>>& scale_bs) {
  79. std::vector<float> preds(output_shape[1] * 3, 0);
  80. for (int batchid = 0; batchid < output_shape[0]; batchid++) {
  81. get_final_preds(output,
  82. output_shape,
  83. idxout,
  84. idx_shape,
  85. center_bs[batchid],
  86. scale_bs[batchid],
  87. preds,
  88. batchid,
  89. this->use_dark());
  90. KeyPointResult result_item;
  91. result_item.num_joints = output_shape[1];
  92. result_item.keypoints.clear();
  93. for (int i = 0; i < output_shape[1]; i++) {
  94. result_item.keypoints.emplace_back(preds[i * 3]);
  95. result_item.keypoints.emplace_back(preds[i * 3 + 1]);
  96. result_item.keypoints.emplace_back(preds[i * 3 + 2]);
  97. }
  98. result->push_back(result_item);
  99. }
  100. }
  101. void KeyPointDetector::Predict(const std::vector<cv::Mat> imgs,
  102. std::vector<std::vector<float>>& center_bs,
  103. std::vector<std::vector<float>>& scale_bs,
  104. std::vector<KeyPointResult>* result) {
  105. int batch_size = imgs.size();
  106. KeyPointDet_interpreter->resizeTensor(input_tensor,
  107. {batch_size, 3, in_h, in_w});
  108. KeyPointDet_interpreter->resizeSession(KeyPointDet_session);
  109. auto insize = 3 * in_h * in_w;
  110. // Preprocess image
  111. cv::Mat resized_im;
  112. for (int bs_idx = 0; bs_idx < batch_size; bs_idx++) {
  113. cv::Mat im = imgs.at(bs_idx);
  114. cv::resize(im, resized_im, cv::Size(in_w, in_h));
  115. std::shared_ptr<MNN::CV::ImageProcess> pretreat(
  116. MNN::CV::ImageProcess::create(
  117. MNN::CV::BGR, MNN::CV::RGB, mean_vals, 3, norm_vals, 3));
  118. pretreat->convert(
  119. resized_im.data, in_w, in_h, resized_im.step[0], input_tensor);
  120. }
  121. // Run predictor
  122. auto inference_start = std::chrono::steady_clock::now();
  123. KeyPointDet_interpreter->runSession(KeyPointDet_session);
  124. // Get output tensor
  125. auto out_tensor = KeyPointDet_interpreter->getSessionOutput(
  126. KeyPointDet_session, "conv2d_441.tmp_1");
  127. auto nchwoutTensor = new Tensor(out_tensor, Tensor::CAFFE);
  128. out_tensor->copyToHostTensor(nchwoutTensor);
  129. auto output_shape = nchwoutTensor->shape();
  130. // Calculate output length
  131. int output_size = 1;
  132. for (int j = 0; j < output_shape.size(); ++j) {
  133. output_size *= output_shape[j];
  134. }
  135. output_data_.resize(output_size);
  136. std::copy_n(nchwoutTensor->host<float>(), output_size, output_data_.data());
  137. delete nchwoutTensor;
  138. auto idx_tensor = KeyPointDet_interpreter->getSessionOutput(
  139. KeyPointDet_session, "argmax_0.tmp_0");
  140. auto idxhostTensor = new Tensor(idx_tensor, Tensor::CAFFE);
  141. idx_tensor->copyToHostTensor(idxhostTensor);
  142. auto idx_shape = idxhostTensor->shape();
  143. // Calculate output length
  144. output_size = 1;
  145. for (int j = 0; j < idx_shape.size(); ++j) {
  146. output_size *= idx_shape[j];
  147. }
  148. idx_data_.resize(output_size);
  149. std::copy_n(idxhostTensor->host<int>(), output_size, idx_data_.data());
  150. delete idxhostTensor;
  151. auto inference_end = std::chrono::steady_clock::now();
  152. std::chrono::duration<double> elapsed = inference_end - inference_start;
  153. printf("keypoint inference time: %f s\n", elapsed.count());
  154. // Postprocessing result
  155. Postprocess(output_data_,
  156. output_shape,
  157. idx_data_,
  158. idx_shape,
  159. result,
  160. center_bs,
  161. scale_bs);
  162. }
  163. } // namespace PaddleDetection