labelme2voc.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
#!/usr/bin/env python

from __future__ import print_function

import argparse
import glob
import io
import os
import os.path as osp
import sys

import imgviz
import labelme

try:
    import lxml.builder
    import lxml.etree
except ImportError:
    print("Please install lxml:\n\n pip install lxml\n")
    sys.exit(1)
  16. def main():
  17. parser = argparse.ArgumentParser(
  18. formatter_class=argparse.ArgumentDefaultsHelpFormatter
  19. )
  20. parser.add_argument("--input_dir", help="input annotated directory")
  21. parser.add_argument("--output_dir", help="output dataset directory")
  22. parser.add_argument("--labels", help="labels file", required=True)
  23. parser.add_argument(
  24. "--noviz", help="no visualization", action="store_true"
  25. )
  26. args = parser.parse_args()
  27. if not osp.exists(args.output_dir):
  28. os.makedirs(args.output_dir)
  29. # os.makedirs(args.output_dir)
  30. # os.makedirs(osp.join(args.output_dir, "JPEGImages"))
  31. # os.makedirs(osp.join(args.output_dir, "Annotations"))
  32. if not args.noviz:
  33. os.makedirs(osp.join(args.output_dir, "AnnotationsVisualization"))
  34. print("Creating dataset:", args.output_dir)
  35. class_names = []
  36. class_name_to_id = {}
  37. for i, line in enumerate(open(args.labels).readlines()):
  38. class_id = i
  39. class_name = line.strip()
  40. class_name_to_id[class_name] = class_id
  41. class_names.append(class_name)
  42. class_names = tuple(class_names)
  43. print("class_names:", class_names)
  44. out_class_names_file = osp.join(args.output_dir, "class_names.txt")
  45. # with open(out_class_names_file, "w") as f:
  46. # f.writelines("\n".join(class_names))
  47. # print("Saved class_names:", out_class_names_file)
  48. out_name = os.path.basename(args.output_dir) + ".xml"
  49. out_xml_file = osp.join(args.output_dir, out_name)
  50. root_maker = lxml.builder.ElementMaker()
  51. root = root_maker.data()
  52. for filename in glob.glob(osp.join(args.input_dir, "*.json")):
  53. print("Generating dataset from:", filename)
  54. label_file = labelme.LabelFile(filename=filename)
  55. base = osp.splitext(osp.basename(filename))[0]
  56. out_img_file = osp.join(args.output_dir, "JPEGImages", base + ".jpg")
  57. if not args.noviz:
  58. out_viz_file = osp.join(
  59. args.output_dir, "AnnotationsVisualization", base + ".jpg"
  60. )
  61. img = labelme.utils.img_data_to_arr(label_file.imageData)
  62. # imgviz.io.imsave(out_img_file, img)
  63. maker = lxml.builder.ElementMaker()
  64. xml = maker.frame(
  65. maker.filename(base + ".jpg"),
  66. )
  67. points = []
  68. labels = []
  69. for shape in label_file.shapes:
  70. if shape["shape_type"] != "point":
  71. print(
  72. "Skipping shape: label={label}, "
  73. "shape_type={shape_type}".format(**shape)
  74. )
  75. continue
  76. class_name = shape["label"]
  77. class_id = class_names.index(class_name)
  78. point = shape["points"]
  79. points.append(point[0])
  80. labels.append(class_id)
  81. xml.append(
  82. maker.point(
  83. name=shape["label"],
  84. x=str(point[0][0]),
  85. y=str(point[0][1])
  86. )
  87. )
  88. root.append(xml)
  89. if not args.noviz:
  90. captions = [class_names[label] for label in labels]
  91. viz = imgviz.instances2rgb(
  92. image=img,
  93. labels=labels,
  94. bboxes=points,
  95. captions=captions,
  96. font_size=15,
  97. )
  98. imgviz.io.imsave(out_viz_file, viz)
  99. with open(out_xml_file, "wb") as f:
  100. f.write(lxml.etree.tostring(root, pretty_print=True))
if __name__ == "__main__":
    # Script entry point: the conversion runs only when executed directly,
    # not when this module is imported.
    main()