self_collected_dataset_preprocess.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. from genericpath import exists
  2. import glob
  3. import cv2
  4. import os
  5. import shutil
  6. import csv
  7. import random
  8. import numpy as np
  class bcolors:
      """ANSI escape sequences used to colour and style terminal output.

      Wrap a message as ``f"{bcolors.WARNING}text{bcolors.ENDC}"``; ``ENDC``
      resets the terminal back to its default rendition.
      """

      # Bright foreground colours (SGR codes 91-96).
      HEADER = "\033[95m"     # bright magenta
      OKBLUE = "\033[94m"     # bright blue
      OKCYAN = "\033[96m"     # bright cyan
      OKGREEN = "\033[92m"    # bright green
      WARNING = "\033[93m"    # bright yellow
      FAIL = "\033[91m"       # bright red

      # Reset and text attributes.
      ENDC = "\033[0m"        # reset all attributes
      BOLD = "\033[1m"
      UNDERLINE = "\033[4m"
  def args_processor(argv=None):
      """Parse the command-line options of the preprocessing script.

      Args:
          argv: Optional list of argument strings to parse. ``None`` (the
              default) makes argparse read ``sys.argv[1:]``, which is the
              behaviour existing callers rely on.

      Returns:
          argparse.Namespace with the attributes ``input_dir`` and
          ``output_dir``.

      Raises:
          SystemExit: Raised by argparse when an option is missing or unknown.
      """
      import argparse

      parser = argparse.ArgumentParser()
      # Both options are mandatory: a missing one used to come back as None
      # and was then formatted into paths such as "None/*.jpg".
      parser.add_argument(
          "-i",
          "--input-dir",
          required=True,
          help="Directory containing the images (*.jpg, *.png) and their .csv annotations",
      )
      parser.add_argument(
          "-o",
          "--output-dir",
          required=True,
          help="Directory to store results (deleted and recreated on every run)",
      )
      return parser.parse_args(argv)
  def orderPoints(pts, centerPt):
      """Order four corner points as top-left, top-right, bottom-right, bottom-left.

      Each point is classified by the quadrant it occupies relative to
      ``centerPt`` (image coordinates, y grows downwards):

          0 = top-left, 1 = top-right, 2 = bottom-right, 3 = bottom-left

      Placement happens in two passes so that a point with an unambiguous,
      still-free quadrant always keeps it:

      1. every point whose quadrant is free is placed there;
      2. the remaining points (lying exactly on a centre axis, or sharing a
         quadrant with an earlier point) fill the leftover slots in input
         order.

      Args:
          pts: Sequence of four points, each indexable as ``pt[0]`` (x) and
              ``pt[1]`` (y).
          centerPt: Reference point ``[x, y]``, normally the centroid of ``pts``.

      Returns:
          List with the four original point objects in the order
          top-left, top-right, bottom-right, bottom-left.

      Raises:
          AssertionError: If ``pts`` does not contain exactly four points.
      """
      cx, cy = centerPt[0], centerPt[1]
      slots = {}      # quadrant index -> point
      deferred = []   # points that could not claim a quadrant of their own

      for pt in pts:
          x, y = pt[0], pt[1]
          if x < cx and y < cy:
              quadrant = 0
          elif x > cx and y < cy:
              quadrant = 1
          elif x > cx and y > cy:
              quadrant = 2
          elif x < cx and y > cy:
              quadrant = 3
          else:
              # Exactly on a centre axis: no quadrant can be decided.
              quadrant = None

          if quadrant is None or quadrant in slots:
              deferred.append(pt)
          else:
              slots[quadrant] = pt

      assert len(slots) + len(deferred) == 4, "orderPoints expects exactly 4 points"

      # With exactly four points there are as many free slots as deferred points.
      free_slots = [q for q in range(4) if q not in slots]
      for quadrant, pt in zip(free_slots, deferred):
          slots[quadrant] = pt

      return [slots[q] for q in range(4)]
  def isAvaibleImg(pts, img, centerPt, minGap=3):
      """Decide whether an annotated image is usable as a training sample.

      The sample is rejected (``False``) when any of the following holds:

      * ``img`` is ``None`` (the image could not be loaded);
      * a point lies outside ``[1, w - 1] x [1, h - 1]``;
      * a point shares its x or its y coordinate with ``centerPt`` (its
        quadrant would be undecidable);
      * two points have x coordinates, or y coordinates, that differ by
        ``minGap`` pixels or less.

      NOTE: the last rule also rejects perfectly axis-aligned rectangles,
      because their corners share x and y coordinates pairwise.

      Args:
          pts: Sequence of points, each indexable as ``pt[0]`` (x), ``pt[1]`` (y).
          img: Image array exposing ``shape`` as ``(height, width, ...)``, or
              ``None``.
          centerPt: Reference point ``[x, y]``.
          minGap: Minimum required per-axis distance between any two points,
              exclusive. Defaults to 3, the previously hard-coded value.

      Returns:
          ``True`` if the sample passes every check, ``False`` otherwise.
      """
      if img is None:
          return False

      h, w = img.shape[:2]
      for i, pt in enumerate(pts):
          x, y = pt[0], pt[1]
          if x > (w - 1) or x < 1:
              return False
          if y > (h - 1) or y < 1:
              return False
          if x == centerPt[0] or y == centerPt[1]:
              return False
          # The distance test is symmetric, so each pair is compared once.
          for other in pts[i + 1:]:
              if abs(x - other[0]) <= minGap:
                  return False
              if abs(y - other[1]) <= minGap:
                  return False
      return True
  def getCenterPt(pts):
      """Return the centroid ``[x, y]`` of ``pts``.

      Each coordinate is divided by the number of points before it is added
      to the running total. An empty ``pts`` yields ``[0, 0]``.
      """
      count = len(pts)
      cx, cy = 0, 0
      for point in pts:
          cx += point[0] / count
          cy += point[1] / count
      return [cx, cy]
  def process(imgpaths, out):
      """Filter annotated images and write them, with ordered corners, to ``out``.

      For every image path the annotation file with the same name but a
      ``.csv`` extension is read; it must hold four lines of the form
      ``"<x> <y>"``. Images without an annotation file, unreadable images and
      samples rejected by ``isAvaibleImg`` are skipped. Accepted samples are
      written as ``<out>/<name>.jpg`` together with ``<out>/<name>.jpg.csv``,
      the latter listing the corners in the order returned by ``orderPoints``.

      Args:
          imgpaths: Iterable of image file paths.
          out: Existing directory that receives the results.

      Raises:
          AssertionError: If an annotation file does not contain four points.
      """
      for imgpath in imgpaths:
          # splitext only strips the final extension; splitting on the first
          # "." broke paths such as "./data/img.jpg" or "shot.v2.jpg".
          csv_path = os.path.splitext(imgpath)[0] + ".csv"
          if not os.path.isfile(csv_path):
              continue

          pts = []
          with open(csv_path, "r", newline="") as f:
              for row in csv.reader(f, delimiter="\t"):
                  if not row or not row[0].strip():
                      continue  # tolerate blank lines
                  fields = row[0].split()
                  pts.append([float(fields[0]), float(fields[1])])
          assert len(pts) == 4, f"{csv_path}: expected 4 points, found {len(pts)}"

          img = cv2.imread(imgpath)
          if img is None:
              # cv2.imread signals an unreadable file by returning None.
              print(f"{bcolors.WARNING}{imgpath} unreadable, skipped{bcolors.ENDC}")
              continue

          centerPt = getCenterPt(pts)
          if not isAvaibleImg(pts, img, centerPt):
              continue
          orderedPts = orderPoints(pts, centerPt)

          fileName = os.path.splitext(os.path.basename(imgpath))[0]
          # Output is always re-encoded as JPEG, whatever the input format.
          out_imgpath = f"{out}/{fileName}.jpg"
          with open(f"{out_imgpath}.csv", "w") as csv_out:
              csv_out.writelines(f"{pt[0]} {pt[1]}\n" for pt in orderedPts)
          cv2.imwrite(out_imgpath, img)
  if __name__ == "__main__":
      # Script entry point: split the annotated images of --input-dir 80/20
      # into <output-dir>/train and <output-dir>/test.
      args = args_processor()
      if not args.input_dir or not args.output_dir:
          raise SystemExit("both --input-dir and --output-dir must be given")

      # The output directory is wiped below. Refuse to run when it is the
      # input directory or one of its parents, which would delete the dataset.
      input_root = os.path.realpath(args.input_dir)
      output_root = os.path.realpath(args.output_dir)
      try:
          wipes_input = os.path.commonpath([input_root, output_root]) == output_root
      except ValueError:
          # Raised for paths on different drives; they cannot contain each other.
          wipes_input = False
      if wipes_input:
          raise SystemExit(
              f"output dir {args.output_dir!r} contains input dir "
              f"{args.input_dir!r}; refusing to delete it"
          )

      # glob returns files in arbitrary order; sorting keeps the split
      # reproducible between runs and machines.
      imgpaths = sorted(
          glob.glob(f"{args.input_dir}/*.jpg") + glob.glob(f"{args.input_dir}/*.png")
      )

      train_dataset_out = f"{args.output_dir}/train"
      test_dataset_out = f"{args.output_dir}/test"

      # Start from a clean output tree; makedirs also creates missing parents.
      shutil.rmtree(args.output_dir, ignore_errors=True)
      os.makedirs(train_dataset_out)
      os.makedirs(test_dataset_out)

      # First 20 % of the paths become the test set, the rest the training set.
      # NOTE(review): the split is by position, not shuffled — confirm intended.
      imgpaths_num = len(imgpaths)
      test_num = int(imgpaths_num * 0.2)
      test_imgpaths = imgpaths[:test_num]
      train_imgpaths = imgpaths[test_num:]

      process(train_imgpaths, train_dataset_out)
      process(test_imgpaths, test_dataset_out)