keypoint_detector.h 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
#include <ctime>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>

#include <inference_engine.hpp>

#include "keypoint_postprocess.h"
  25. namespace PaddleDetection {
  26. // Object KeyPoint Result
  27. struct KeyPointResult {
  28. // Keypoints: shape(N x 3); N: number of Joints; 3: x,y,conf
  29. std::vector<float> keypoints;
  30. int num_joints = -1;
  31. };
  32. // Visualiztion KeyPoint Result
  33. cv::Mat VisualizeKptsResult(const cv::Mat& img,
  34. const std::vector<KeyPointResult>& results,
  35. const std::vector<int>& colormap,
  36. float threshold = 0.2);
  37. class KeyPointDetector {
  38. public:
  39. explicit KeyPointDetector(const std::string& model_path,
  40. int input_height = 256,
  41. int input_width = 192,
  42. float score_threshold = 0.3,
  43. const int batch_size = 1,
  44. bool use_dark = true) {
  45. use_dark_ = use_dark;
  46. in_w = input_width;
  47. in_h = input_height;
  48. threshold_ = score_threshold;
  49. InferenceEngine::Core ie;
  50. auto model = ie.ReadNetwork(model_path);
  51. // prepare input settings
  52. InferenceEngine::InputsDataMap inputs_map(model.getInputsInfo());
  53. input_name_ = inputs_map.begin()->first;
  54. InferenceEngine::InputInfo::Ptr input_info = inputs_map.begin()->second;
  55. // prepare output settings
  56. InferenceEngine::OutputsDataMap outputs_map(model.getOutputsInfo());
  57. int idx = 0;
  58. for (auto& output_info : outputs_map) {
  59. if (idx == 0) {
  60. output_info.second->setPrecision(InferenceEngine::Precision::FP32);
  61. } else {
  62. output_info.second->setPrecision(InferenceEngine::Precision::FP32);
  63. }
  64. idx++;
  65. }
  66. // get network
  67. network_ = ie.LoadNetwork(model, "CPU");
  68. infer_request_ = network_.CreateInferRequest();
  69. }
  70. // Load Paddle inference model
  71. void LoadModel(std::string model_file, int num_theads);
  72. // Run predictor
  73. void Predict(const std::vector<cv::Mat> imgs,
  74. std::vector<std::vector<float>>& center,
  75. std::vector<std::vector<float>>& scale,
  76. std::vector<KeyPointResult>* result = nullptr);
  77. bool use_dark() { return this->use_dark_; }
  78. inline float get_threshold() { return threshold_; };
  79. int in_w = 128;
  80. int in_h = 256;
  81. private:
  82. // Postprocess result
  83. void Postprocess(std::vector<float>& output,
  84. std::vector<uint64_t>& output_shape,
  85. std::vector<float>& idxout,
  86. std::vector<uint64_t>& idx_shape,
  87. std::vector<KeyPointResult>* result,
  88. std::vector<std::vector<float>>& center,
  89. std::vector<std::vector<float>>& scale);
  90. std::vector<float> output_data_;
  91. std::vector<float> idx_data_;
  92. float threshold_;
  93. bool use_dark_;
  94. InferenceEngine::ExecutableNetwork network_;
  95. InferenceEngine::InferRequest infer_request_;
  96. std::string input_name_;
  97. };
  98. } // namespace PaddleDetection