anchor_cluster.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  # Bootstrap: make the PaddleDetection checkout importable.
# The directory one level above the folder containing this script is put at
# the front of sys.path, so the `ppdet` imports that follow resolve to this
# checkout first. This must run before any `ppdet` import.
parent_path = os.path.abspath(os.path.join(__file__, '..', '..'))
sys.path.insert(0, parent_path)
from ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.anchor_cluster')
  24. from scipy.cluster.vq import kmeans
  25. import numpy as np
  26. from tqdm import tqdm
  27. from ppdet.utils.cli import ArgsParser
  28. from ppdet.utils.check import check_gpu, check_version, check_config
  29. from ppdet.core.workspace import load_config, merge_config
  # Base class shared by the anchor-clustering strategies in this script.
class BaseAnchorCluster(object):
    """Common plumbing for anchor clustering.

    Collects (and optionally caches on disk) the ground-truth box sizes and
    drives the ``get_whs`` -> ``calc_anchors`` -> ``print_result`` pipeline
    from ``__call__``.

    Subclasses must assign ``self.dataset`` before ``get_whs`` needs it (the
    object must provide ``parse_dataset()`` and a ``roidbs`` sequence) and
    must implement ``calc_anchors`` and ``print_result``.
    """

    def __init__(self, n, cache_path, cache, verbose=True):
        """
        Base Anchor Cluster

        Args:
            n (int): number of clusters
            cache_path (str): cache directory path
            cache (bool): whether using cache
            verbose (bool): whether print results
        """
        super(BaseAnchorCluster, self).__init__()
        self.n = n
        self.cache_path = cache_path
        self.cache = cache
        self.verbose = verbose

    def print_result(self, centers):
        """Report the computed anchors; must be implemented by subclasses."""
        raise NotImplementedError('%s.print_result is not available' %
                                  self.__class__.__name__)

    def get_whs(self):
        """Gather every ground-truth box size, normalized by its image size.

        Returns:
            tuple: ``(whs, shapes)``, two float arrays with one row per box.
            ``whs`` holds the box ``[w, h]`` divided by the image ``[w, h]``;
            ``shapes`` holds that image ``[w, h]`` repeated for each box.
            Both are also stored as ``self.whs`` and ``self.shapes``.

        When ``self.cache`` is true, ``whs.npy`` / ``shapes.npy`` under
        ``self.cache_path`` are loaded if both exist; otherwise the freshly
        computed arrays are saved there. The cache files are not tied to a
        particular dataset, so remove them when the dataset changes.
        """
        whs_cache_path = os.path.join(self.cache_path, 'whs.npy')
        shapes_cache_path = os.path.join(self.cache_path, 'shapes.npy')
        if self.cache and os.path.exists(whs_cache_path) and os.path.exists(
                shapes_cache_path):
            self.whs = np.load(whs_cache_path)
            self.shapes = np.load(shapes_cache_path)
            return self.whs, self.shapes

        self.dataset.parse_dataset()
        # Collect per-image chunks and concatenate once at the end. Calling
        # np.vstack on the growing accumulator in the loop re-copies it for
        # every image, which is quadratic in the number of boxes. The leading
        # empty float64 chunks keep the result dtype and the (0, 2) shape for
        # an empty dataset identical to the previous accumulator.
        wh_chunks = [np.zeros((0, 2))]
        shape_chunks = [np.zeros((0, 2))]
        for rec in tqdm(self.dataset.roidbs):
            h, w = rec['h'], rec['w']
            img_wh = np.array([[w, h]])
            bbox = rec['gt_bbox']
            # Presumably [x1, y1, x2, y2] corners; the +1 counts both end
            # pixels (inclusive coordinates).
            wh = (bbox[:, 2:4] - bbox[:, 0:2] + 1) / img_wh
            wh_chunks.append(wh)
            shape_chunks.append(np.ones_like(wh) * img_wh)
        whs = np.concatenate(wh_chunks, axis=0)
        shapes = np.concatenate(shape_chunks, axis=0)

        if self.cache:
            os.makedirs(self.cache_path, exist_ok=True)
            np.save(whs_cache_path, whs)
            np.save(shapes_cache_path, shapes)

        self.whs = whs
        self.shapes = shapes
        return self.whs, self.shapes

    def calc_anchors(self):
        """Compute anchors from ``self.whs``; must be implemented by subclasses."""
        raise NotImplementedError('%s.calc_anchors is not available' %
                                  self.__class__.__name__)

    def __call__(self):
        """Run the pipeline and return whatever ``calc_anchors`` produced.

        ``print_result`` is invoked on the result only when ``self.verbose``.
        """
        self.get_whs()
        centers = self.calc_anchors()
        if self.verbose:
            self.print_result(centers)
        return centers
  # YOLOv2-style anchor clustering: k-means with IoU as the similarity.
class YOLOv2AnchorCluster(BaseAnchorCluster):
    """YOLOv2 anchor cluster.

    Runs k-means over ground-truth box sizes (scaled to ``size`` pixels),
    assigning each box to the center with the highest IoU when both are
    placed at the same origin.

    The code is based on
    https://github.com/AlexeyAB/darknet/blob/master/scripts/gen_anchors.py
    """

    def __init__(self,
                 n,
                 dataset,
                 size,
                 cache_path,
                 cache,
                 iters=1000,
                 verbose=True):
        """
        YOLOv2 Anchor Cluster

        Args:
            n (int): number of clusters
            dataset (DataSet): DataSet instance, VOC or COCO
            size (list): [w, h]
            cache_path (str): cache directory path
            cache (bool): whether using cache
            iters (int): kmeans algorithm iters
            verbose (bool): whether print results
        """
        super(YOLOv2AnchorCluster, self).__init__(
            n, cache_path, cache, verbose=verbose)
        self.dataset = dataset
        self.size = size
        self.iters = iters

    def print_result(self, centers):
        """Log every anchor as a rounded integer ``[w, h]`` pair."""
        logger.info('%d anchor cluster result: [w, h]' % self.n)
        for w, h in centers:
            logger.info('[%d, %d]' % (round(w), round(h)))

    def metric(self, whs, centers):
        """Pairwise IoU of boxes and centers sharing a common top-left corner.

        Args:
            whs (np.ndarray): ``(N, 2)`` box sizes.
            centers (np.ndarray): ``(K, 2)`` cluster sizes.

        Returns:
            np.ndarray: ``(N, K)`` IoU matrix.
        """
        wh1 = whs[:, None]
        wh2 = centers[None]
        inter = np.minimum(wh1, wh2).prod(2)
        return inter / (wh1.prod(2) + wh2.prod(2) - inter)

    def kmeans_expectation(self, whs, centers, assignments):
        """E step: assign each box to the center with the highest IoU.

        Returns:
            tuple: ``(converged, new_assignments)``; ``converged`` is True
            when no box changed cluster compared to ``assignments``.
        """
        dist = self.metric(whs, centers)
        new_assignments = dist.argmax(1)
        converged = (new_assignments == assignments).all()
        return converged, new_assignments

    def kmeans_maximizations(self, whs, centers, assignments):
        """M step: move each center to the mean size of its assigned boxes.

        A center with no assigned boxes keeps its previous value. Resetting
        it to ``[0, 0]`` would give it zero IoU with every box, so it could
        never win boxes back and would be reported as a degenerate anchor.
        """
        new_centers = np.array(centers, dtype=float)
        for i in range(new_centers.shape[0]):
            mask = (assignments == i)
            if mask.any():
                new_centers[i, :] = whs[mask].mean(0)
        return new_centers

    def calc_anchors(self):
        """Cluster the box sizes, in pixels, into ``self.n`` anchors.

        ``self.whs`` (normalized by image size, see ``get_whs``) is scaled
        by ``self.size`` into a local array; ``self.whs`` itself is left
        untouched so repeated calls see the same input.

        Returns:
            For ``n == 1`` a ``(1, 2)`` array with the mean box size;
            otherwise a list of ``n`` ``[w, h]`` arrays sorted by area.

        Raises:
            ValueError: if ``n`` is not between 1 and the number of boxes.
        """
        whs = self.whs * np.array([self.size])
        n, iters = self.n, self.iters
        logger.info('Running kmeans for %d anchors on %d points...' %
                    (n, len(whs)))
        if not 1 <= n <= len(whs):
            raise ValueError('cannot cluster %d anchors from %d boxes' %
                             (n, len(whs)))
        # Randomly pick n distinct boxes as the initial centers.
        idx = np.random.choice(whs.shape[0], size=n, replace=False)
        centers = whs[idx]
        if n == 1:
            # A single cluster is simply the mean of all boxes.
            return self.kmeans_maximizations(
                whs, centers, np.zeros(whs.shape[0], dtype=np.int64))
        # -1 matches no cluster index, so the first E step can never be
        # mistaken for convergence (np.zeros(...) * -1 would be all zeros).
        assignments = np.full(whs.shape[0], -1, dtype=np.int64)
        pbar = tqdm(range(iters), desc='Cluster anchors with k-means algorithm')
        for _ in pbar:
            # E step
            converged, assignments = self.kmeans_expectation(
                whs, centers, assignments)
            if converged:
                logger.info('kmeans algorithm has converged')
                break
            # M step
            centers = self.kmeans_maximizations(whs, centers, assignments)
            ious = self.metric(whs, centers)
            # set_description refreshes the bar; assigning .desc does not.
            pbar.set_description('avg_iou: %.4f' % (ious.max(1).mean()))
        centers = sorted(centers, key=lambda x: x[0] * x[1])
        return centers
  # Command-line entry point and its argument helpers.
def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``'False'`` and ``'0'``) into True, so common spellings are parsed
    explicitly. ``ValueError`` makes argparse report an invalid value.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise ValueError('expected a boolean value, got %r' % (value, ))


def main():
    """Cluster anchor sizes for the ``TrainDataset`` of a config file.

    The target image size comes from ``--size`` ("w,h", or one value for a
    square size) or, failing that, from ``TestReader.inputs_def.image_shape``.

    Returns:
        The anchors produced by the selected cluster method.

    Raises:
        ValueError: for a malformed, missing or non-positive size, or an
            unsupported ``--method``.
    """
    parser = ArgsParser()
    parser.add_argument(
        '--n', '-n', default=9, type=int, help='num of clusters')
    parser.add_argument(
        '--iters',
        '-i',
        default=1000,
        type=int,
        help='num of iterations for kmeans')
    parser.add_argument(
        '--verbose',
        '-v',
        default=True,
        type=_str2bool,
        help='whether print result')
    parser.add_argument(
        '--size',
        '-s',
        default=None,
        type=str,
        help='image size: w,h, using comma as delimiter')
    parser.add_argument(
        '--method',
        '-m',
        default='v2',
        type=str,
        help='cluster method, v2 is only supported now')
    parser.add_argument(
        '--cache_path', default='cache', type=str, help='cache path')
    parser.add_argument(
        '--cache', action='store_true', help='whether use cache')
    FLAGS = parser.parse_args()

    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    check_config(cfg)
    # check if set use_gpu=True in paddlepaddle cpu version
    if 'use_gpu' not in cfg:
        cfg.use_gpu = False
    check_gpu(cfg.use_gpu)
    # check if paddlepaddle version is satisfied
    check_version('develop')

    # get dataset
    dataset = cfg['TrainDataset']
    if FLAGS.size:
        if ',' in FLAGS.size:
            size = list(map(int, FLAGS.size.split(',')))
            # Explicit check rather than `assert`, which `python -O` strips.
            if len(size) != 2:
                raise ValueError("the format of size is incorrect")
        else:
            size = [int(FLAGS.size)] * 2
    elif 'inputs_def' in cfg['TestReader'] and 'image_shape' in cfg[
            'TestReader']['inputs_def']:
        image_shape = cfg['TestReader']['inputs_def']['image_shape']
        # NOTE(review): image_shape is presumably channel-first, [C, H, W]
        # or [N, C, H, W] -- confirm against the reader config. The clusterer
        # expects [w, h], so take the trailing spatial dims and swap them.
        # The old `image_shape[1:]` gave [H, W] (wrong for non-square inputs).
        h, w = image_shape[-2:]
        size = [w, h]
    else:
        raise ValueError('size is not specified')
    if any(s <= 0 for s in size):
        # e.g. dynamic (-1) dims in image_shape carry no usable size.
        raise ValueError('size must be positive, got %s' % (size, ))

    if FLAGS.method == 'v2':
        cluster = YOLOv2AnchorCluster(FLAGS.n, dataset, size, FLAGS.cache_path,
                                      FLAGS.cache, FLAGS.iters, FLAGS.verbose)
    else:
        raise ValueError('cluster method: %s is not supported' % FLAGS.method)
    return cluster()


if __name__ == "__main__":
    main()