test_base.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import print_function
  15. import unittest
  16. import contextlib
  17. import paddle
  18. from paddle.static import Program
  19. class LayerTest(unittest.TestCase):
  20. @classmethod
  21. def setUpClass(cls):
  22. cls.seed = 111
  23. @classmethod
  24. def tearDownClass(cls):
  25. pass
  26. def _get_place(self, force_to_use_cpu=False):
  27. # this option for ops that only have cpu kernel
  28. if force_to_use_cpu:
  29. return 'cpu'
  30. else:
  31. return paddle.device.get_device()
  32. @contextlib.contextmanager
  33. def static_graph(self):
  34. paddle.enable_static()
  35. scope = paddle.static.Scope()
  36. program = Program()
  37. with paddle.static.scope_guard(scope):
  38. with paddle.static.program_guard(program):
  39. paddle.seed(self.seed)
  40. paddle.framework.random._manual_program_seed(self.seed)
  41. yield
  42. def get_static_graph_result(self,
  43. feed,
  44. fetch_list,
  45. with_lod=False,
  46. force_to_use_cpu=False):
  47. exe = paddle.static.Executor(self._get_place(force_to_use_cpu))
  48. exe.run(paddle.static.default_startup_program())
  49. return exe.run(paddle.static.default_main_program(),
  50. feed=feed,
  51. fetch_list=fetch_list,
  52. return_numpy=(not with_lod))
  53. @contextlib.contextmanager
  54. def dynamic_graph(self, force_to_use_cpu=False):
  55. paddle.disable_static()
  56. place = self._get_place(force_to_use_cpu=force_to_use_cpu)
  57. paddle.device.set_device(place)
  58. paddle.seed(self.seed)
  59. paddle.framework.random._manual_program_seed(self.seed)
  60. yield