gplaidmlkernel.hpp 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html.
  4. //
  5. // Copyright (C) 2019 Intel Corporation
  6. //
  7. #ifndef OPENCV_GAPI_GPLAIDMLKERNEL_HPP
  8. #define OPENCV_GAPI_GPLAIDMLKERNEL_HPP
  9. #include <opencv2/gapi/gkernel.hpp>
  10. #include <opencv2/gapi/garg.hpp>
  // Forward declaration of the PlaidML EDSL tensor type. Within this header the
  // type is only used through references and pointers, so a declaration suffices.
  namespace plaidml {
  namespace edsl {

  class Tensor;

  } // namespace edsl
  } // namespace plaidml
  18. namespace cv
  19. {
  20. namespace gapi
  21. {
  22. namespace plaidml
  23. {
  24. GAPI_EXPORTS cv::gapi::GBackend backend();
  25. } // namespace plaidml
  26. } // namespace gapi
  27. struct GPlaidMLContext
  28. {
  29. // Generic accessor API
  30. template<typename T>
  31. const T& inArg(int input) { return m_args.at(input).get<T>(); }
  32. // Syntax sugar
  33. const plaidml::edsl::Tensor& inTensor(int input)
  34. {
  35. return inArg<plaidml::edsl::Tensor>(input);
  36. }
  37. plaidml::edsl::Tensor& outTensor(int output)
  38. {
  39. return *(m_results.at(output).get<plaidml::edsl::Tensor*>());
  40. }
  41. std::vector<GArg> m_args;
  42. std::unordered_map<std::size_t, GArg> m_results;
  43. };
  44. class GAPI_EXPORTS GPlaidMLKernel
  45. {
  46. public:
  47. using F = std::function<void(GPlaidMLContext &)>;
  48. GPlaidMLKernel() = default;
  49. explicit GPlaidMLKernel(const F& f) : m_f(f) {}
  50. void apply(GPlaidMLContext &ctx) const
  51. {
  52. GAPI_Assert(m_f);
  53. m_f(ctx);
  54. }
  55. protected:
  56. F m_f;
  57. };
  58. namespace detail
  59. {
  60. template<class T> struct plaidml_get_in;
  61. template<> struct plaidml_get_in<cv::GMat>
  62. {
  63. static const plaidml::edsl::Tensor& get(GPlaidMLContext& ctx, int idx)
  64. {
  65. return ctx.inTensor(idx);
  66. }
  67. };
  68. template<class T> struct plaidml_get_in
  69. {
  70. static T get(GPlaidMLContext &ctx, int idx) { return ctx.inArg<T>(idx); }
  71. };
  72. template<class T> struct plaidml_get_out;
  73. template<> struct plaidml_get_out<cv::GMat>
  74. {
  75. static plaidml::edsl::Tensor& get(GPlaidMLContext& ctx, int idx)
  76. {
  77. return ctx.outTensor(idx);
  78. }
  79. };
  80. template<typename, typename, typename>
  81. struct PlaidMLCallHelper;
  82. template<typename Impl, typename... Ins, typename... Outs>
  83. struct PlaidMLCallHelper<Impl, std::tuple<Ins...>, std::tuple<Outs...> >
  84. {
  85. template<int... IIs, int... OIs>
  86. static void call_impl(GPlaidMLContext &ctx, detail::Seq<IIs...>, detail::Seq<OIs...>)
  87. {
  88. Impl::run(plaidml_get_in<Ins>::get(ctx, IIs)..., plaidml_get_out<Outs>::get(ctx, OIs)...);
  89. }
  90. static void call(GPlaidMLContext& ctx)
  91. {
  92. call_impl(ctx,
  93. typename detail::MkSeq<sizeof...(Ins)>::type(),
  94. typename detail::MkSeq<sizeof...(Outs)>::type());
  95. }
  96. };
  97. } // namespace detail
  98. template<class Impl, class K>
  99. class GPlaidMLKernelImpl: public cv::detail::PlaidMLCallHelper<Impl, typename K::InArgs, typename K::OutArgs>,
  100. public cv::detail::KernelTag
  101. {
  102. using P = detail::PlaidMLCallHelper<Impl, typename K::InArgs, typename K::OutArgs>;
  103. public:
  104. using API = K;
  105. static cv::gapi::GBackend backend() { return cv::gapi::plaidml::backend(); }
  106. static cv::GPlaidMLKernel kernel() { return GPlaidMLKernel(&P::call); }
  107. };
  // Opens the definition of kernel implementation `Name` for kernel interface `API`.
  // The struct body follows the macro invocation; it has to provide the static
  // run() that the call helper invokes.
  #define GAPI_PLAIDML_KERNEL(Name, API) struct Name: public cv::GPlaidMLKernelImpl<Name, API>
  109. } // namespace cv
  110. #endif // OPENCV_GAPI_GPLAIDMLKERNEL_HPP