# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
import xml.etree.ElementTree as ET

from ppdet.core.workspace import register, serializable
from .dataset import DetDataset

from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)


@register
@serializable
class VOCDataSet(DetDataset):
    """
    Load dataset with PascalVOC format.

    Notes:
        `anno_path` must contain the xml file and image file paths for annotations.

    Args:
        dataset_dir (str): root directory for the dataset.
        image_dir (str): directory for images.
        anno_path (str): voc annotation file path.
        data_fields (list): key names of the data dictionary; must contain at least 'image'.
        sample_num (int): number of samples to load, -1 means all.
        label_list (str): path to a label list file; if given, the mapping between
            category name and class index is loaded from it, otherwise the
            default PascalVOC label mapping is used.
        allow_empty (bool): whether to load empty entries. False by default.
        empty_ratio (float): the ratio of empty records to total records after
            sampling; if empty_ratio is out of [0., 1.), the empty records are
            not sampled and all of them are kept. 1. by default.
        repeat (int): number of times to repeat the dataset, used in benchmarking.
    """
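
    # A minimal usage sketch; the paths below are hypothetical and only
    # illustrate how the arguments fit together:
    #
    #     dataset = VOCDataSet(
    #         dataset_dir='dataset/voc',
    #         image_dir='VOCdevkit/VOC2007',
    #         anno_path='trainval.txt',
    #         label_list='label_list.txt')
    #     dataset.parse_dataset()  # fills dataset.roidbs and dataset.cname2cid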

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=['image'],
                 sample_num=-1,
                 label_list=None,
                 allow_empty=False,
                 empty_ratio=1.,
                 repeat=1):
        super(VOCDataSet, self).__init__(
            dataset_dir=dataset_dir,
            image_dir=image_dir,
            anno_path=anno_path,
            data_fields=data_fields,
            sample_num=sample_num,
            repeat=repeat)
        self.label_list = label_list
        self.allow_empty = allow_empty
        self.empty_ratio = empty_ratio

    def _sample_empty(self, records, num):
        # if empty_ratio is out of [0. ,1.), do not sample the records
        if self.empty_ratio < 0. or self.empty_ratio >= 1.:
            return records
        import random
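        # keeping num * r / (1 - r) empty records makes empty entries roughly a
        # fraction r (= empty_ratio) of the merged dataset, since
        # e / (e + num) = r when e = num * r / (1 - r)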
        sample_num = min(
            int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
        records = random.sample(records, sample_num)
        return records

    def parse_dataset(self):
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        # mapping category name to class id
        # first_class:0, second_class:1, ...
        records = []
        empty_records = []
        ct = 0
        cname2cid = {}
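        # label_list is a plain text file with one category name per line;
        # class ids follow line order (first line -> 0, second line -> 1, ...)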
        if self.label_list:
            label_path = os.path.join(self.dataset_dir, self.label_list)
            if not os.path.exists(label_path):
                raise ValueError("label_list {} does not exist".format(
                    label_path))
            with open(label_path, 'r') as fr:
                label_id = 0
                for line in fr.readlines():
                    cname2cid[line.strip()] = label_id
                    label_id += 1
        else:
            cname2cid = pascalvoc_label()
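
        # each line of the annotation list file holds two whitespace-separated
        # paths relative to image_dir: <image file> <xml annotation file>,
        # e.g. "JPEGImages/000005.jpg Annotations/000005.xml" (example paths only)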
        with open(anno_path, 'r') as fr:
            while True:
                line = fr.readline()
                if not line:
                    break
                img_file, xml_file = [os.path.join(image_dir, x) \
                        for x in line.strip().split()[:2]]
                if not os.path.exists(img_file):
                    logger.warning(
                        'Illegal image file: {}, and it will be ignored'.format(
                            img_file))
                    continue
                if not os.path.isfile(xml_file):
                    logger.warning(
                        'Illegal xml file: {}, and it will be ignored'.format(
                            xml_file))
                    continue
                tree = ET.parse(xml_file)
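                # use the <id> tag as the image id when present, otherwise
                # fall back to the running record counter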
                if tree.find('id') is None:
                    im_id = np.array([ct])
                else:
                    im_id = np.array([int(tree.find('id').text)])

                objs = tree.findall('object')
                im_w = float(tree.find('size').find('width').text)
                im_h = float(tree.find('size').find('height').text)
                if im_w < 0 or im_h < 0:
                    logger.warning(
                        'Illegal width: {} or height: {} in annotation, '
                        'and {} will be ignored'.format(im_w, im_h, xml_file))
                    continue

                num_bbox, i = len(objs), 0
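                # pre-allocate per-object arrays; they are trimmed to the
                # number of valid boxes (i) after the loop below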
                gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
                gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
                gt_score = np.zeros((num_bbox, 1), dtype=np.float32)
                difficult = np.zeros((num_bbox, 1), dtype=np.int32)
                for obj in objs:
                    cname = obj.find('name').text

                    # user dataset may not contain difficult field
                    _difficult = obj.find('difficult')
                    _difficult = int(
                        _difficult.text) if _difficult is not None else 0

                    x1 = float(obj.find('bndbox').find('xmin').text)
                    y1 = float(obj.find('bndbox').find('ymin').text)
                    x2 = float(obj.find('bndbox').find('xmax').text)
                    y2 = float(obj.find('bndbox').find('ymax').text)
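                    # clip the box to the image boundary and keep it only if
                    # it stays non-degenerate (x2 > x1 and y2 > y1)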
                    x1 = max(0, x1)
                    y1 = max(0, y1)
                    x2 = min(im_w - 1, x2)
                    y2 = min(im_h - 1, y2)
                    if x2 > x1 and y2 > y1:
                        gt_bbox[i, :] = [x1, y1, x2, y2]
                        gt_class[i, 0] = cname2cid[cname]
                        gt_score[i, 0] = 1.
                        difficult[i, 0] = _difficult
                        i += 1
                    else:
                        logger.warning(
                            'Found an invalid bbox in annotations: xml_file: {}'
                            ', x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                                xml_file, x1, y1, x2, y2))
                gt_bbox = gt_bbox[:i, :]
                gt_class = gt_class[:i, :]
                gt_score = gt_score[:i, :]
                difficult = difficult[:i, :]

                voc_rec = {
                    'im_file': img_file,
                    'im_id': im_id,
                    'h': im_h,
                    'w': im_w
                } if 'image' in self.data_fields else {}

                gt_rec = {
                    'gt_class': gt_class,
                    'gt_score': gt_score,
                    'gt_bbox': gt_bbox,
                    'difficult': difficult
                }
                for k, v in gt_rec.items():
                    if k in self.data_fields:
                        voc_rec[k] = v

                if len(objs) == 0:
                    empty_records.append(voc_rec)
                else:
                    records.append(voc_rec)

                ct += 1
                if self.sample_num > 0 and ct >= self.sample_num:
                    break
        assert ct > 0, 'no voc record found in %s' % (self.anno_path)
        logger.debug('{} samples in file {}'.format(ct, anno_path))

        if self.allow_empty and len(empty_records) > 0:
            empty_records = self._sample_empty(empty_records, len(records))
            records += empty_records

        self.roidbs, self.cname2cid = records, cname2cid

    def get_label_list(self):
        return os.path.join(self.dataset_dir, self.label_list)


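# Default class name -> index mapping for the 20 standard PASCAL VOC categories.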
def pascalvoc_label():
    labels_map = {
        'aeroplane': 0,
        'bicycle': 1,
        'bird': 2,
        'boat': 3,
        'bottle': 4,
        'bus': 5,
        'car': 6,
        'cat': 7,
        'chair': 8,
        'cow': 9,
        'diningtable': 10,
        'dog': 11,
        'horse': 12,
        'motorbike': 13,
        'person': 14,
        'pottedplant': 15,
        'sheep': 16,
        'sofa': 17,
        'train': 18,
        'tvmonitor': 19
    }
    return labels_map