keypoint_detector.h 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <cstdio>
#include <ctime>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>

#include "ImageProcess.hpp"
#include "Interpreter.hpp"
#include "MNNDefine.h"
#include "Tensor.hpp"

#include "keypoint_postprocess.h"

// NOTE(review): a using-directive in a header leaks every MNN name into all
// includers; kept because out-of-line code may rely on unqualified MNN names.
using namespace MNN;
  29. namespace PaddleDetection {
  // Keypoints predicted for one object.
  struct KeyPointResult {
    // Flattened joints, three floats per joint (x, y, conf); i.e. a
    // (num_joints x 3) layout.
    std::vector<float> keypoints{};
    // Number of joints stored in `keypoints`; -1 means "not set yet".
    int num_joints{-1};
  };
  36. // Visualiztion KeyPoint Result
  37. cv::Mat VisualizeKptsResult(const cv::Mat& img,
  38. const std::vector<KeyPointResult>& results,
  39. const std::vector<int>& colormap,
  40. float threshold = 0.2);
  41. class KeyPointDetector {
  42. public:
  43. explicit KeyPointDetector(const std::string& model_path,
  44. int num_thread = 4,
  45. int input_height = 256,
  46. int input_width = 192,
  47. float score_threshold = 0.3,
  48. const int batch_size = 1,
  49. bool use_dark = true) {
  50. printf("config path: %s",
  51. model_path.substr(0, model_path.find_last_of('/') + 1).c_str());
  52. use_dark_ = use_dark;
  53. in_w = input_width;
  54. in_h = input_height;
  55. threshold_ = score_threshold;
  56. KeyPointDet_interpreter = std::shared_ptr<MNN::Interpreter>(
  57. MNN::Interpreter::createFromFile(model_path.c_str()));
  58. MNN::ScheduleConfig config;
  59. config.type = MNN_FORWARD_CPU;
  60. /*modeNum means gpuMode for GPU usage, Or means numThread for CPU usage.*/
  61. config.numThread = num_thread;
  62. // If type not fount, let it failed
  63. config.backupType = MNN_FORWARD_CPU;
  64. BackendConfig backendConfig;
  65. backendConfig.precision = static_cast<MNN::BackendConfig::PrecisionMode>(1);
  66. config.backendConfig = &backendConfig;
  67. KeyPointDet_session = KeyPointDet_interpreter->createSession(config);
  68. input_tensor =
  69. KeyPointDet_interpreter->getSessionInput(KeyPointDet_session, nullptr);
  70. }
  71. ~KeyPointDetector() {
  72. KeyPointDet_interpreter->releaseModel();
  73. KeyPointDet_interpreter->releaseSession(KeyPointDet_session);
  74. }
  75. // Load Paddle inference model
  76. void LoadModel(std::string model_file, int num_theads);
  77. // Run predictor
  78. void Predict(const std::vector<cv::Mat> imgs,
  79. std::vector<std::vector<float>>& center,
  80. std::vector<std::vector<float>>& scale,
  81. std::vector<KeyPointResult>* result = nullptr);
  82. bool use_dark() { return this->use_dark_; }
  83. inline float get_threshold() { return threshold_; };
  84. // const float mean_vals[3] = { 103.53f, 116.28f, 123.675f };
  85. // const float norm_vals[3] = { 0.017429f, 0.017507f, 0.017125f };
  86. const float mean_vals[3] = {0.f, 0.f, 0.f};
  87. const float norm_vals[3] = {1.f, 1.f, 1.f};
  88. int in_w = 128;
  89. int in_h = 256;
  90. private:
  91. // Postprocess result
  92. void Postprocess(std::vector<float>& output,
  93. std::vector<int>& output_shape,
  94. std::vector<int>& idxout,
  95. std::vector<int>& idx_shape,
  96. std::vector<KeyPointResult>* result,
  97. std::vector<std::vector<float>>& center,
  98. std::vector<std::vector<float>>& scale);
  99. std::vector<float> output_data_;
  100. std::vector<int> idx_data_;
  101. float threshold_;
  102. bool use_dark_;
  103. std::shared_ptr<MNN::Interpreter> KeyPointDet_interpreter;
  104. MNN::Session* KeyPointDet_session = nullptr;
  105. MNN::Tensor* input_tensor = nullptr;
  106. };
  107. } // namespace PaddleDetection