# gen_semi_coco.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. import json
  16. import argparse
  17. import numpy as np
  18. def save_json(path, images, annotations, categories):
  19. new_json = {
  20. 'images': images,
  21. 'annotations': annotations,
  22. 'categories': categories,
  23. }
  24. with open(path, 'w') as f:
  25. json.dump(new_json, f)
  26. print('{} saved, with {} images and {} annotations.'.format(
  27. path, len(images), len(annotations)))
  28. def gen_semi_data(data_dir,
  29. json_file,
  30. percent=10.0,
  31. seed=1,
  32. seed_offset=0,
  33. txt_file=None):
  34. json_name = json_file.split('/')[-1].split('.')[0]
  35. json_file = os.path.join(data_dir, json_file)
  36. anno = json.load(open(json_file, 'r'))
  37. categories = anno['categories']
  38. all_images = anno['images']
  39. all_anns = anno['annotations']
  40. print(
  41. 'Totally {} images and {} annotations, about {} gts per image.'.format(
  42. len(all_images), len(all_anns), len(all_anns) / len(all_images)))
  43. if txt_file:
  44. print('Using percent {} and seed {}.'.format(percent, seed))
  45. txt_file = os.path.join(data_dir, txt_file)
  46. sup_idx = json.load(open(txt_file, 'r'))[str(percent)][str(seed)]
  47. # max(sup_idx) = 117262 # 10%, sup_idx is not image_id
  48. else:
  49. np.random.seed(seed + seed_offset)
  50. sup_len = int(percent / 100.0 * len(all_images))
  51. sup_idx = np.random.choice(
  52. range(len(all_images)), size=sup_len, replace=False)
  53. labeled_images, labeled_anns = [], []
  54. labeled_im_ids = []
  55. unlabeled_images, unlabeled_anns = [], []
  56. for i in range(len(all_images)):
  57. if i in sup_idx:
  58. labeled_im_ids.append(all_images[i]['id'])
  59. labeled_images.append(all_images[i])
  60. else:
  61. unlabeled_images.append(all_images[i])
  62. for an in all_anns:
  63. im_id = an['image_id']
  64. if im_id in labeled_im_ids:
  65. labeled_anns.append(an)
  66. else:
  67. continue
  68. save_path = '{}/{}'.format(data_dir, 'semi_annotations')
  69. if not os.path.exists(save_path):
  70. os.mkdir(save_path)
  71. sup_name = '{}.{}@{}.json'.format(json_name, seed, int(percent))
  72. sup_path = os.path.join(save_path, sup_name)
  73. save_json(sup_path, labeled_images, labeled_anns, categories)
  74. unsup_name = '{}.{}@{}-unlabeled.json'.format(json_name, seed, int(percent))
  75. unsup_path = os.path.join(save_path, unsup_name)
  76. save_json(unsup_path, unlabeled_images, unlabeled_anns, categories)
  77. if __name__ == '__main__':
  78. parser = argparse.ArgumentParser()
  79. parser.add_argument('--data_dir', type=str, default='./dataset/coco')
  80. parser.add_argument(
  81. '--json_file', type=str, default='annotations/instances_train2017.json')
  82. parser.add_argument('--percent', type=float, default=10.0)
  83. parser.add_argument('--seed', type=int, default=1)
  84. parser.add_argument('--seed_offset', type=int, default=0)
  85. parser.add_argument('--txt_file', type=str, default='COCO_supervision.txt')
  86. args = parser.parse_args()
  87. print(args)
  88. gen_semi_data(args.data_dir, args.json_file, args.percent, args.seed,
  89. args.seed_offset, args.txt_file)