coco.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import copy
  16. try:
  17. from collections.abc import Sequence
  18. except Exception:
  19. from collections import Sequence
  20. import numpy as np
  21. from ppdet.core.workspace import register, serializable
  22. from .dataset import DetDataset
  23. from ppdet.utils.logger import setup_logger
  24. logger = setup_logger(__name__)
  25. __all__ = ['COCODataSet', 'SlicedCOCODataSet', 'SemiCOCODataSet']
  26. @register
  27. @serializable
  28. class COCODataSet(DetDataset):
  29. """
  30. Load dataset with COCO format.
  31. Args:
  32. dataset_dir (str): root directory for dataset.
  33. image_dir (str): directory for images.
  34. anno_path (str): coco annotation file path.
  35. data_fields (list): key name of data dictionary, at least have 'image'.
  36. sample_num (int): number of samples to load, -1 means all.
  37. load_crowd (bool): whether to load crowded ground-truth.
  38. False as default
  39. allow_empty (bool): whether to load empty entry. False as default
  40. empty_ratio (float): the ratio of empty record number to total
  41. record's, if empty_ratio is out of [0. ,1.), do not sample the
  42. records and use all the empty entries. 1. as default
  43. repeat (int): repeat times for dataset, use in benchmark.
  44. """
  45. def __init__(self,
  46. dataset_dir=None,
  47. image_dir=None,
  48. anno_path=None,
  49. data_fields=['image'],
  50. sample_num=-1,
  51. load_crowd=False,
  52. allow_empty=False,
  53. empty_ratio=1.,
  54. repeat=1):
  55. super(COCODataSet, self).__init__(
  56. dataset_dir,
  57. image_dir,
  58. anno_path,
  59. data_fields,
  60. sample_num,
  61. repeat=repeat)
  62. self.load_image_only = False
  63. self.load_semantic = False
  64. self.load_crowd = load_crowd
  65. self.allow_empty = allow_empty
  66. self.empty_ratio = empty_ratio
  67. def _sample_empty(self, records, num):
  68. # if empty_ratio is out of [0. ,1.), do not sample the records
  69. if self.empty_ratio < 0. or self.empty_ratio >= 1.:
  70. return records
  71. import random
  72. sample_num = min(
  73. int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
  74. records = random.sample(records, sample_num)
  75. return records
  76. def parse_dataset(self):
  77. anno_path = os.path.join(self.dataset_dir, self.anno_path)
  78. image_dir = os.path.join(self.dataset_dir, self.image_dir)
  79. assert anno_path.endswith('.json'), \
  80. 'invalid coco annotation file: ' + anno_path
  81. from pycocotools.coco import COCO
  82. coco = COCO(anno_path)
  83. img_ids = coco.getImgIds()
  84. img_ids.sort()
  85. cat_ids = coco.getCatIds()
  86. records = []
  87. empty_records = []
  88. ct = 0
  89. self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
  90. self.cname2cid = dict({
  91. coco.loadCats(catid)[0]['name']: clsid
  92. for catid, clsid in self.catid2clsid.items()
  93. })
  94. if 'annotations' not in coco.dataset:
  95. self.load_image_only = True
  96. logger.warning('Annotation file: {} does not contains ground truth '
  97. 'and load image information only.'.format(anno_path))
  98. for img_id in img_ids:
  99. img_anno = coco.loadImgs([img_id])[0]
  100. im_fname = img_anno['file_name']
  101. im_w = float(img_anno['width'])
  102. im_h = float(img_anno['height'])
  103. im_path = os.path.join(image_dir,
  104. im_fname) if image_dir else im_fname
  105. is_empty = False
  106. if not os.path.exists(im_path):
  107. logger.warning('Illegal image file: {}, and it will be '
  108. 'ignored'.format(im_path))
  109. continue
  110. if im_w < 0 or im_h < 0:
  111. logger.warning('Illegal width: {} or height: {} in annotation, '
  112. 'and im_id: {} will be ignored'.format(
  113. im_w, im_h, img_id))
  114. continue
  115. coco_rec = {
  116. 'im_file': im_path,
  117. 'im_id': np.array([img_id]),
  118. 'h': im_h,
  119. 'w': im_w,
  120. } if 'image' in self.data_fields else {}
  121. if not self.load_image_only:
  122. ins_anno_ids = coco.getAnnIds(
  123. imgIds=[img_id], iscrowd=None if self.load_crowd else False)
  124. instances = coco.loadAnns(ins_anno_ids)
  125. bboxes = []
  126. is_rbox_anno = False
  127. for inst in instances:
  128. # check gt bbox
  129. if inst.get('ignore', False):
  130. continue
  131. if 'bbox' not in inst.keys():
  132. continue
  133. else:
  134. if not any(np.array(inst['bbox'])):
  135. continue
  136. x1, y1, box_w, box_h = inst['bbox']
  137. x2 = x1 + box_w
  138. y2 = y1 + box_h
  139. eps = 1e-5
  140. if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
  141. inst['clean_bbox'] = [
  142. round(float(x), 3) for x in [x1, y1, x2, y2]
  143. ]
  144. bboxes.append(inst)
  145. else:
  146. logger.warning(
  147. 'Found an invalid bbox in annotations: im_id: {}, '
  148. 'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
  149. img_id, float(inst['area']), x1, y1, x2, y2))
  150. num_bbox = len(bboxes)
  151. if num_bbox <= 0 and not self.allow_empty:
  152. continue
  153. elif num_bbox <= 0:
  154. is_empty = True
  155. gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
  156. gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
  157. is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
  158. gt_poly = [None] * num_bbox
  159. gt_track_id = -np.ones((num_bbox, 1), dtype=np.int32)
  160. has_segmentation = False
  161. has_track_id = False
  162. for i, box in enumerate(bboxes):
  163. catid = box['category_id']
  164. gt_class[i][0] = self.catid2clsid[catid]
  165. gt_bbox[i, :] = box['clean_bbox']
  166. is_crowd[i][0] = box['iscrowd']
  167. # check RLE format
  168. if 'segmentation' in box and box['iscrowd'] == 1:
  169. gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
  170. elif 'segmentation' in box and box['segmentation']:
  171. if not np.array(
  172. box['segmentation'],
  173. dtype=object).size > 0 and not self.allow_empty:
  174. bboxes.pop(i)
  175. gt_poly.pop(i)
  176. np.delete(is_crowd, i)
  177. np.delete(gt_class, i)
  178. np.delete(gt_bbox, i)
  179. else:
  180. gt_poly[i] = box['segmentation']
  181. has_segmentation = True
  182. if 'track_id' in box:
  183. gt_track_id[i][0] = box['track_id']
  184. has_track_id = True
  185. if has_segmentation and not any(
  186. gt_poly) and not self.allow_empty:
  187. continue
  188. gt_rec = {
  189. 'is_crowd': is_crowd,
  190. 'gt_class': gt_class,
  191. 'gt_bbox': gt_bbox,
  192. 'gt_poly': gt_poly,
  193. }
  194. if has_track_id:
  195. gt_rec.update({'gt_track_id': gt_track_id})
  196. for k, v in gt_rec.items():
  197. if k in self.data_fields:
  198. coco_rec[k] = v
  199. # TODO: remove load_semantic
  200. if self.load_semantic and 'semantic' in self.data_fields:
  201. seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
  202. 'train2017', im_fname[:-3] + 'png')
  203. coco_rec.update({'semantic': seg_path})
  204. logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
  205. im_path, img_id, im_h, im_w))
  206. if is_empty:
  207. empty_records.append(coco_rec)
  208. else:
  209. records.append(coco_rec)
  210. ct += 1
  211. if self.sample_num > 0 and ct >= self.sample_num:
  212. break
  213. assert ct > 0, 'not found any coco record in %s' % (anno_path)
  214. logger.info('Load [{} samples valid, {} samples invalid] in file {}.'.
  215. format(ct, len(img_ids) - ct, anno_path))
  216. if self.allow_empty and len(empty_records) > 0:
  217. empty_records = self._sample_empty(empty_records, len(records))
  218. records += empty_records
  219. self.roidbs = records
  220. @register
  221. @serializable
  222. class SlicedCOCODataSet(COCODataSet):
  223. """Sliced COCODataSet"""
  224. def __init__(
  225. self,
  226. dataset_dir=None,
  227. image_dir=None,
  228. anno_path=None,
  229. data_fields=['image'],
  230. sample_num=-1,
  231. load_crowd=False,
  232. allow_empty=False,
  233. empty_ratio=1.,
  234. repeat=1,
  235. sliced_size=[640, 640],
  236. overlap_ratio=[0.25, 0.25], ):
  237. super(SlicedCOCODataSet, self).__init__(
  238. dataset_dir=dataset_dir,
  239. image_dir=image_dir,
  240. anno_path=anno_path,
  241. data_fields=data_fields,
  242. sample_num=sample_num,
  243. load_crowd=load_crowd,
  244. allow_empty=allow_empty,
  245. empty_ratio=empty_ratio,
  246. repeat=repeat, )
  247. self.sliced_size = sliced_size
  248. self.overlap_ratio = overlap_ratio
  249. def parse_dataset(self):
  250. anno_path = os.path.join(self.dataset_dir, self.anno_path)
  251. image_dir = os.path.join(self.dataset_dir, self.image_dir)
  252. assert anno_path.endswith('.json'), \
  253. 'invalid coco annotation file: ' + anno_path
  254. from pycocotools.coco import COCO
  255. coco = COCO(anno_path)
  256. img_ids = coco.getImgIds()
  257. img_ids.sort()
  258. cat_ids = coco.getCatIds()
  259. records = []
  260. empty_records = []
  261. ct = 0
  262. ct_sub = 0
  263. self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
  264. self.cname2cid = dict({
  265. coco.loadCats(catid)[0]['name']: clsid
  266. for catid, clsid in self.catid2clsid.items()
  267. })
  268. if 'annotations' not in coco.dataset:
  269. self.load_image_only = True
  270. logger.warning('Annotation file: {} does not contains ground truth '
  271. 'and load image information only.'.format(anno_path))
  272. try:
  273. import sahi
  274. from sahi.slicing import slice_image
  275. except Exception as e:
  276. logger.error(
  277. 'sahi not found, plaese install sahi. '
  278. 'for example: `pip install sahi`, see https://github.com/obss/sahi.'
  279. )
  280. raise e
  281. sub_img_ids = 0
  282. for img_id in img_ids:
  283. img_anno = coco.loadImgs([img_id])[0]
  284. im_fname = img_anno['file_name']
  285. im_w = float(img_anno['width'])
  286. im_h = float(img_anno['height'])
  287. im_path = os.path.join(image_dir,
  288. im_fname) if image_dir else im_fname
  289. is_empty = False
  290. if not os.path.exists(im_path):
  291. logger.warning('Illegal image file: {}, and it will be '
  292. 'ignored'.format(im_path))
  293. continue
  294. if im_w < 0 or im_h < 0:
  295. logger.warning('Illegal width: {} or height: {} in annotation, '
  296. 'and im_id: {} will be ignored'.format(
  297. im_w, im_h, img_id))
  298. continue
  299. slice_image_result = sahi.slicing.slice_image(
  300. image=im_path,
  301. slice_height=self.sliced_size[0],
  302. slice_width=self.sliced_size[1],
  303. overlap_height_ratio=self.overlap_ratio[0],
  304. overlap_width_ratio=self.overlap_ratio[1])
  305. sub_img_num = len(slice_image_result)
  306. for _ind in range(sub_img_num):
  307. im = slice_image_result.images[_ind]
  308. coco_rec = {
  309. 'image': im,
  310. 'im_id': np.array([sub_img_ids + _ind]),
  311. 'h': im.shape[0],
  312. 'w': im.shape[1],
  313. 'ori_im_id': np.array([img_id]),
  314. 'st_pix': np.array(
  315. slice_image_result.starting_pixels[_ind],
  316. dtype=np.float32),
  317. 'is_last': 1 if _ind == sub_img_num - 1 else 0,
  318. } if 'image' in self.data_fields else {}
  319. records.append(coco_rec)
  320. ct_sub += sub_img_num
  321. ct += 1
  322. if self.sample_num > 0 and ct >= self.sample_num:
  323. break
  324. assert ct > 0, 'not found any coco record in %s' % (anno_path)
  325. logger.info('{} samples and slice to {} sub_samples in file {}'.format(
  326. ct, ct_sub, anno_path))
  327. if self.allow_empty and len(empty_records) > 0:
  328. empty_records = self._sample_empty(empty_records, len(records))
  329. records += empty_records
  330. self.roidbs = records
  331. @register
  332. @serializable
  333. class SemiCOCODataSet(COCODataSet):
  334. """Semi-COCODataSet used for supervised and unsupervised dataSet"""
  335. def __init__(self,
  336. dataset_dir=None,
  337. image_dir=None,
  338. anno_path=None,
  339. data_fields=['image'],
  340. sample_num=-1,
  341. load_crowd=False,
  342. allow_empty=False,
  343. empty_ratio=1.,
  344. repeat=1,
  345. supervised=True):
  346. super(SemiCOCODataSet, self).__init__(
  347. dataset_dir, image_dir, anno_path, data_fields, sample_num,
  348. load_crowd, allow_empty, empty_ratio, repeat)
  349. self.supervised = supervised
  350. self.length = -1 # defalut -1 means all
  351. def parse_dataset(self):
  352. anno_path = os.path.join(self.dataset_dir, self.anno_path)
  353. image_dir = os.path.join(self.dataset_dir, self.image_dir)
  354. assert anno_path.endswith('.json'), \
  355. 'invalid coco annotation file: ' + anno_path
  356. from pycocotools.coco import COCO
  357. coco = COCO(anno_path)
  358. img_ids = coco.getImgIds()
  359. img_ids.sort()
  360. cat_ids = coco.getCatIds()
  361. records = []
  362. empty_records = []
  363. ct = 0
  364. self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
  365. self.cname2cid = dict({
  366. coco.loadCats(catid)[0]['name']: clsid
  367. for catid, clsid in self.catid2clsid.items()
  368. })
  369. if 'annotations' not in coco.dataset or self.supervised == False:
  370. self.load_image_only = True
  371. logger.warning('Annotation file: {} does not contains ground truth '
  372. 'and load image information only.'.format(anno_path))
  373. for img_id in img_ids:
  374. img_anno = coco.loadImgs([img_id])[0]
  375. im_fname = img_anno['file_name']
  376. im_w = float(img_anno['width'])
  377. im_h = float(img_anno['height'])
  378. im_path = os.path.join(image_dir,
  379. im_fname) if image_dir else im_fname
  380. is_empty = False
  381. if not os.path.exists(im_path):
  382. logger.warning('Illegal image file: {}, and it will be '
  383. 'ignored'.format(im_path))
  384. continue
  385. if im_w < 0 or im_h < 0:
  386. logger.warning('Illegal width: {} or height: {} in annotation, '
  387. 'and im_id: {} will be ignored'.format(
  388. im_w, im_h, img_id))
  389. continue
  390. coco_rec = {
  391. 'im_file': im_path,
  392. 'im_id': np.array([img_id]),
  393. 'h': im_h,
  394. 'w': im_w,
  395. } if 'image' in self.data_fields else {}
  396. if not self.load_image_only:
  397. ins_anno_ids = coco.getAnnIds(
  398. imgIds=[img_id], iscrowd=None if self.load_crowd else False)
  399. instances = coco.loadAnns(ins_anno_ids)
  400. bboxes = []
  401. is_rbox_anno = False
  402. for inst in instances:
  403. # check gt bbox
  404. if inst.get('ignore', False):
  405. continue
  406. if 'bbox' not in inst.keys():
  407. continue
  408. else:
  409. if not any(np.array(inst['bbox'])):
  410. continue
  411. x1, y1, box_w, box_h = inst['bbox']
  412. x2 = x1 + box_w
  413. y2 = y1 + box_h
  414. eps = 1e-5
  415. if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
  416. inst['clean_bbox'] = [
  417. round(float(x), 3) for x in [x1, y1, x2, y2]
  418. ]
  419. bboxes.append(inst)
  420. else:
  421. logger.warning(
  422. 'Found an invalid bbox in annotations: im_id: {}, '
  423. 'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
  424. img_id, float(inst['area']), x1, y1, x2, y2))
  425. num_bbox = len(bboxes)
  426. if num_bbox <= 0 and not self.allow_empty:
  427. continue
  428. elif num_bbox <= 0:
  429. is_empty = True
  430. gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
  431. gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
  432. is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
  433. gt_poly = [None] * num_bbox
  434. has_segmentation = False
  435. for i, box in enumerate(bboxes):
  436. catid = box['category_id']
  437. gt_class[i][0] = self.catid2clsid[catid]
  438. gt_bbox[i, :] = box['clean_bbox']
  439. is_crowd[i][0] = box['iscrowd']
  440. # check RLE format
  441. if 'segmentation' in box and box['iscrowd'] == 1:
  442. gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
  443. elif 'segmentation' in box and box['segmentation']:
  444. if not np.array(box['segmentation']
  445. ).size > 0 and not self.allow_empty:
  446. bboxes.pop(i)
  447. gt_poly.pop(i)
  448. np.delete(is_crowd, i)
  449. np.delete(gt_class, i)
  450. np.delete(gt_bbox, i)
  451. else:
  452. gt_poly[i] = box['segmentation']
  453. has_segmentation = True
  454. if has_segmentation and not any(
  455. gt_poly) and not self.allow_empty:
  456. continue
  457. gt_rec = {
  458. 'is_crowd': is_crowd,
  459. 'gt_class': gt_class,
  460. 'gt_bbox': gt_bbox,
  461. 'gt_poly': gt_poly,
  462. }
  463. for k, v in gt_rec.items():
  464. if k in self.data_fields:
  465. coco_rec[k] = v
  466. # TODO: remove load_semantic
  467. if self.load_semantic and 'semantic' in self.data_fields:
  468. seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
  469. 'train2017', im_fname[:-3] + 'png')
  470. coco_rec.update({'semantic': seg_path})
  471. logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
  472. im_path, img_id, im_h, im_w))
  473. if is_empty:
  474. empty_records.append(coco_rec)
  475. else:
  476. records.append(coco_rec)
  477. ct += 1
  478. if self.sample_num > 0 and ct >= self.sample_num:
  479. break
  480. assert ct > 0, 'not found any coco record in %s' % (anno_path)
  481. logger.info('Load [{} samples valid, {} samples invalid] in file {}.'.
  482. format(ct, len(img_ids) - ct, anno_path))
  483. if self.allow_empty and len(empty_records) > 0:
  484. empty_records = self._sample_empty(empty_records, len(records))
  485. records += empty_records
  486. self.roidbs = records
  487. if self.supervised:
  488. logger.info(f'Use {len(self.roidbs)} sup_samples data as LABELED')
  489. else:
  490. if self.length > 0: # unsup length will be decide by sup length
  491. all_roidbs = self.roidbs.copy()
  492. selected_idxs = [
  493. np.random.choice(len(all_roidbs))
  494. for _ in range(self.length)
  495. ]
  496. self.roidbs = [all_roidbs[i] for i in selected_idxs]
  497. logger.info(
  498. f'Use {len(self.roidbs)} unsup_samples data as UNLABELED')
  499. def __getitem__(self, idx):
  500. n = len(self.roidbs)
  501. if self.repeat > 1:
  502. idx %= n
  503. # data batch
  504. roidb = copy.deepcopy(self.roidbs[idx])
  505. if self.mixup_epoch == 0 or self._epoch < self.mixup_epoch:
  506. idx = np.random.randint(n)
  507. roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
  508. elif self.cutmix_epoch == 0 or self._epoch < self.cutmix_epoch:
  509. idx = np.random.randint(n)
  510. roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
  511. elif self.mosaic_epoch == 0 or self._epoch < self.mosaic_epoch:
  512. roidb = [roidb, ] + [
  513. copy.deepcopy(self.roidbs[np.random.randint(n)])
  514. for _ in range(4)
  515. ]
  516. if isinstance(roidb, Sequence):
  517. for r in roidb:
  518. r['curr_iter'] = self._curr_iter
  519. else:
  520. roidb['curr_iter'] = self._curr_iter
  521. self._curr_iter += 1
  522. return self.transform(roidb)