# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Reference: https://github.com/CAPTAIN-WHU/DOTA_devkit

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import json

import cv2
from tqdm import tqdm
from multiprocessing import Pool


def load_dota_info(image_dir, anno_dir, file_name, ext=None):
    """Parse a single DOTA image and its ``labelTxt`` annotation file.

    Each annotation line is expected in the DOTA format:
    ``x1 y1 x2 y2 x3 y3 x4 y4 category [difficult]``.
    Returns None when the image extension does not match ``ext``.
    """
    base_name, extension = os.path.splitext(file_name)
    if ext and (extension != ext and extension not in ext):
        return None

    info = {'image_file': os.path.join(image_dir, file_name), 'annotation': []}
    anno_file = os.path.join(anno_dir, base_name + '.txt')
    if not os.path.exists(anno_file):
        return info

    with open(anno_file, 'r') as f:
        for line in f:
            items = line.strip().split()
            if len(items) < 9:
                continue
            anno = {
                'poly': list(map(float, items[:8])),
                'name': items[8],
                'difficult': '0' if len(items) == 9 else items[9],
            }
            info['annotation'].append(anno)
    return info


def load_dota_infos(root_dir, num_process=8, ext=None):
    """Collect info dicts for every image under ``root_dir/images``,
    reading annotations from ``root_dir/labelTxt``.

    Uses a process pool when ``num_process > 1``.
    """
    image_dir = os.path.join(root_dir, 'images')
    anno_dir = os.path.join(root_dir, 'labelTxt')
    data_infos = []
    if num_process > 1:
        pool = Pool(num_process)
        results = []
        for file_name in os.listdir(image_dir):
            results.append(
                pool.apply_async(load_dota_info,
                                 (image_dir, anno_dir, file_name, ext)))
        pool.close()
        pool.join()
        for result in results:
            info = result.get()
            if info:
                data_infos.append(info)
    else:
        for file_name in os.listdir(image_dir):
            info = load_dota_info(image_dir, anno_dir, file_name, ext)
            if info:
                data_infos.append(info)
    return data_infos


def process_single_sample(info, image_id, class_names):
    """Convert one image info dict into a COCO image record plus its
    COCO-style object annotations."""
    image_file = info['image_file']
    single_image = dict()
    single_image['file_name'] = os.path.split(image_file)[-1]
    single_image['id'] = image_id
    image = cv2.imread(image_file)
    height, width, _ = image.shape
    single_image['width'] = width
    single_image['height'] = height

    # process annotation field
    single_objs = []
    objects = info['annotation']
    for obj in objects:
        poly, name, difficult = obj['poly'], obj['name'], obj['difficult']
        # skip objects marked with difficulty level 2
        if difficult == '2':
            continue
        single_obj = dict()
        single_obj['category_id'] = class_names.index(name) + 1
        single_obj['segmentation'] = [poly]
        single_obj['iscrowd'] = 0
        # axis-aligned bounding box of the quadrilateral, in COCO xywh format
        xmin, ymin = min(poly[0::2]), min(poly[1::2])
        xmax, ymax = max(poly[0::2]), max(poly[1::2])
        width, height = xmax - xmin, ymax - ymin
        single_obj['bbox'] = [xmin, ymin, width, height]
        single_obj['area'] = height * width
        single_obj['image_id'] = image_id
        single_objs.append(single_obj)
    return single_image, single_objs


def data_to_coco(infos, output_path, class_names, num_process):
    """Dump the collected info dicts to ``output_path`` as a COCO-format
    JSON file with ``categories``, ``images`` and ``annotations`` fields."""
    data_dict = dict()
    data_dict['categories'] = []
    for i, name in enumerate(class_names):
        data_dict['categories'].append({
            'id': i + 1,
            'name': name,
            'supercategory': name
        })

    pbar = tqdm(total=len(infos), desc='data to coco')
    images, annotations = [], []
    if num_process > 1:
        pool = Pool(num_process)
        results = []
        for i, info in enumerate(infos):
            image_id = i + 1
            results.append(
                pool.apply_async(
                    process_single_sample, (info, image_id, class_names),
                    callback=lambda x: pbar.update()))
        pool.close()
        pool.join()
        for result in results:
            single_image, single_anno = result.get()
            images.append(single_image)
            annotations += single_anno
    else:
        for i, info in enumerate(infos):
            image_id = i + 1
            single_image, single_anno = process_single_sample(info, image_id,
                                                              class_names)
            images.append(single_image)
            annotations += single_anno
            pbar.update()
    pbar.close()

    # assign consecutive annotation ids once all samples have been processed
    for i, anno in enumerate(annotations):
        anno['id'] = i + 1

    data_dict['images'] = images
    data_dict['annotations'] = annotations
    with open(output_path, 'w') as f:
        json.dump(data_dict, f)
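

# A minimal usage sketch, not part of the converter above: it assumes a
# DOTA-style dataset rooted at `dota_root` with `images/` and `labelTxt/`
# subdirectories, PNG images, and the standard DOTA-v1.0 category list.
# The paths and output file name are placeholders; adjust them to your setup
# (the upstream tool may also wire these functions to a CLI instead).
if __name__ == '__main__':
    dota_class_names = [
        'plane', 'baseball-diamond', 'bridge', 'ground-track-field',
        'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
        'basketball-court', 'storage-tank', 'soccer-ball-field',
        'roundabout', 'harbor', 'swimming-pool', 'helicopter'
    ]
    dota_root = 'dota/train'  # expects dota/train/images and dota/train/labelTxt
    infos = load_dota_infos(dota_root, num_process=8, ext='.png')
    data_to_coco(infos, 'dota_train_coco.json', dota_class_names, num_process=8)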