// tflite_wrapper.cpp
#include "tflite_wrapper.hpp"

#include <cassert>  // assert() — used throughout but previously pulled in only transitively
#include <cstring>  // memcpy

#include <iostream>

#include <opencv2/opencv.hpp>

#include "utils/log_util.h"
#ifdef __ANDROID__
#include <tensorflow/lite/delegates/gpu/delegate.h>
#else
#include "plat_delegate.h"
#endif
// https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_c
  11. TfliteWrapper::TfliteWrapper(const char *modelPath) {
  12. #ifdef DEBUG
  13. std::cout << "TfliteInfer()" << std::endl;
  14. #endif
  15. #ifdef __ANDROID__
  16. model = tflite::FlatBufferModel::BuildFromFile(modelPath);
  17. assert(model);
  18. tflite::ops::builtin::BuiltinOpResolver docResolver;
  19. tflite::InterpreterBuilder(*model, docResolver)(&interpreter);
  20. TfLiteGpuDelegateOptionsV2 options;
  21. options.is_precision_loss_allowed = 1;
  22. options.inference_preference =
  23. TFLITE_GPU_INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER;
  24. options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY;
  25. options.inference_priority2 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_MEMORY_USAGE;
  26. options.inference_priority3 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION;
  27. // options.experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_GL_ONLY;
  28. options.max_delegated_partitions = 1;
  29. options.model_token = nullptr;
  30. options.serialization_dir = nullptr;
  31. delegate = TfLiteGpuDelegateV2Create(&options);
  32. assert(interpreter->ModifyGraphWithDelegate(delegate) == kTfLiteOk);
  33. interpreter->AllocateTensors();
  34. #else
  35. model = TfLiteModelCreateFromFile(modelPath);
  36. TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
  37. delegate = (TfLiteDelegate *)create_delegate();
  38. TfLiteInterpreterOptionsAddDelegate(options, delegate);
  39. // Create the interpreter.
  40. interpreter = TfLiteInterpreterCreate(model, options);
  41. // Allocate tensors and populate the input tensor data.
  42. TfLiteInterpreterAllocateTensors(interpreter);
  43. TfLiteInterpreterOptionsDelete(options);
  44. #endif
  45. }
  46. void TfliteWrapper::Invoke(const void *input_data, size_t input_data_size, void *output_data,
  47. size_t output_data_size) {
  48. //#ifdef DEBUG
  49. // int64 startTime = cv::getTickCount();
  50. //#endif
  51. #ifdef __ANDROID__
  52. auto* input = interpreter->typed_input_tensor<float>(0);
  53. memcpy(input, input_data, input_data_size);
  54. interpreter->Invoke();
  55. auto* output = interpreter->typed_output_tensor<float>(0);
  56. memcpy(output_data, output, output_data_size);
  57. #else
  58. TfLiteTensor *input_tensor = TfLiteInterpreterGetInputTensor(interpreter, 0);
  59. const TfLiteTensor *output_tensor = TfLiteInterpreterGetOutputTensor(interpreter, 0);
  60. TfLiteTensorCopyFromBuffer(input_tensor, input_data, input_data_size);
  61. auto status = TfLiteInterpreterInvoke(interpreter);
  62. assert(status == TfLiteStatus::kTfLiteOk);
  63. TfLiteTensorCopyToBuffer(output_tensor, output_data, output_data_size);
  64. #endif
  65. //#ifdef DEBUG
  66. // LOGW("tflite", "Invoke time=%f", (cv::getTickCount() - startTime) / cv::getTickFrequency());
  67. //#endif
  68. }
  69. TfliteWrapper::~TfliteWrapper() {
  70. #ifdef DEBUG
  71. std::cout << "~TfliteInfer()" << std::endl;
  72. #endif
  73. #ifdef __ANDROID__
  74. TfLiteGpuDelegateV2Delete(delegate);
  75. #else
  76. delete_delegate(delegate);
  77. TfLiteModelDelete(model);
  78. TfLiteInterpreterDelete(interpreter);
  79. #endif
  80. }
  81. void TfliteWrapper::getInputShape(uint32_t *shape, size_t size) {
  82. #ifdef __ANDROID__
  83. auto *data = interpreter->input_tensor(0)->dims->data;
  84. for (size_t i = 0; i < size; i++) {
  85. shape[i] = *(data + i);
  86. }
  87. #else
  88. auto *inputTensorPtr = TfLiteInterpreterGetInputTensor(interpreter, 0);
  89. auto dimsNum = TfLiteTensorNumDims(inputTensorPtr);
  90. assert(size == dimsNum);
  91. for (int i = 0; i < dimsNum; i++) {
  92. shape[i] = TfLiteTensorDim(inputTensorPtr, i);
  93. }
  94. #endif
  95. }