generate_dataset.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. import os
  2. from tqdm import tqdm
  3. import cv2
  4. import numpy as np
  5. import albumentations as A
  6. import random
  7. import shutil
  8. import argparse
  9. # path parameters
  10. parser = argparse.ArgumentParser()
  11. parser.add_argument('--data_dir',
  12. type=str,
  13. help='Raw training data.',
  14. default="raw_data")
  15. transform = A.Compose([
  16. A.OneOf([
  17. A.ISONoise(p=0.4),
  18. A.JpegCompression(quality_lower=50, quality_upper=70,
  19. always_apply=False, p=0.8),
  20. ], p=0.6),
  21. A.OneOf([
  22. A.MotionBlur(blur_limit=11, p=.8),
  23. A.MedianBlur(blur_limit=3, p=0.75),
  24. A.GaussianBlur(blur_limit=7, p=0.75),
  25. ], p=0.8),
  26. A.OneOf([
  27. A.RandomBrightnessContrast(
  28. brightness_limit=0.3, contrast_limit=0.3, p=0.75),
  29. A.RandomShadow(num_shadows_lower=1,
  30. num_shadows_upper=18, shadow_dimension=6, p=0.85),
  31. ], p=0.8),
  32. ])
  33. def getListOfFiles(dirName):
  34. print(dirName)
  35. # create a list of file and sub directories
  36. # names in the given directory
  37. listOfFile = os.listdir(dirName)
  38. allFiles = list()
  39. # Iterate over all the entries
  40. for entry in listOfFile:
  41. allFiles.append(entry)
  42. return allFiles
  43. def ImageResize(image, factor=0.6):
  44. width = int(image.shape[1] * factor)
  45. height = int(image.shape[0] * factor)
  46. dim = (width, height)
  47. # print(image.shape)
  48. resized = cv2.resize(image, dim, interpolation=cv2.INTER_LANCZOS4)
  49. # print(resized.shape)
  50. return resized
  51. def GetOverlappingBlocks(im, M=256, N=256, Part=8):
  52. tiles = []
  53. tile = np.zeros((M, N, 3), dtype=np.uint8)
  54. #tile[:,:,2] = 255
  55. x = 0
  56. y = 0
  57. x_start = 0
  58. y_start = 0
  59. while y < im.shape[0]:
  60. while x < im.shape[1]:
  61. if(x != 0):
  62. x_start = x - int(N/Part)
  63. if(y != 0):
  64. y_start = y - int(M/Part)
  65. if(y_start+M > im.shape[0]):
  66. if(x_start+N > im.shape[1]):
  67. tile[0:im.shape[0]-y_start, 0:im.shape[1]-x_start, :] = im[y_start:im.shape[0], x_start:im.shape[1], :]
  68. else:
  69. tile[0:im.shape[0]-y_start, 0:N, :] = im[y_start:im.shape[0], x_start:x_start+N, :]
  70. else:
  71. if(x_start+N > im.shape[1]):
  72. tile[0:M, 0:im.shape[1]-x_start, :] = im[y_start:y_start+M, x_start:im.shape[1], :]
  73. else:
  74. tile[0:M, 0:N, :] = im[y_start:y_start + M, x_start:x_start+N, :]
  75. #pre_tile = cv2.cvtColor(PreProcessInput(cv2.cvtColor(tile, cv2.COLOR_RGB2BGR)), cv2.COLOR_BGR2RGB)
  76. # tiles.append(load_tf_img(pre_tile,M))
  77. # tiles.append(load_tf_img(tile,M))
  78. tiles.append(tile)
  79. tile = np.zeros((M, N, 3), dtype=np.uint8)
  80. #tile[:,:,2] = 255
  81. x = x_start + N
  82. y = y_start + M
  83. x = 0
  84. x_start = 0
  85. return tiles
  86. def GenerateTrainingBlocks(data_folder, gt_folder, dataset_path='./dataset', M=256, N=256):
  87. print(data_folder)
  88. print('Generating training blocks!!!')
  89. train_path = dataset_path + '/' + data_folder + '_Trainblocks'
  90. if not os.path.exists(train_path):
  91. os.makedirs(train_path)
  92. train_filenames = train_path + '/train_block_names.txt'
  93. f = open(train_filenames, 'w')
  94. data_path = data_folder
  95. gt_path = gt_folder
  96. # data_path = dataset_path + '/' + data_folder
  97. # gt_path = dataset_path + '/' + gt_folder
  98. print(data_path)
  99. filenames = getListOfFiles(data_path)
  100. cnt = 0
  101. print(filenames)
  102. for name in tqdm(filenames):
  103. print(name)
  104. if name == '.DS_Store':
  105. continue
  106. arr = name.split(".")
  107. gt_filename = gt_path + '/' + arr[0] + "_mask."+arr[1]
  108. in_filename = data_path + '/' + name
  109. print(gt_filename)
  110. print(in_filename)
  111. gt_image_initial = cv2.imread(gt_filename)
  112. in_image_initial = cv2.imread(in_filename)
  113. if gt_image_initial.shape[0] + gt_image_initial.shape[1] > in_image_initial.shape[0]+in_image_initial.shape[1]:
  114. gt_image_initial = cv2.resize(gt_image_initial, (in_image_initial.shape[1], in_image_initial.shape[0]))
  115. else:
  116. in_image_initial = cv2.resize(in_image_initial, (gt_image_initial.shape[1], gt_image_initial.shape[0]))
  117. print(gt_image_initial.shape, in_image_initial.shape)
  118. # cv2.imshow("img", in_image_initial)
  119. # cv2.imshow("gt", gt_image_initial)
  120. # cv2.waitKey(0)
  121. # cv2.destroyAllWindows()
  122. for scale in [0.7, 1.0, 1.4]:
  123. gt_image = ImageResize(gt_image_initial, scale)
  124. in_image = ImageResize(in_image_initial, scale)
  125. h, w, c = in_image.shape
  126. gt_img = GetOverlappingBlocks(gt_image, Part=8)
  127. in_img = GetOverlappingBlocks(in_image, Part=8)
  128. for i in range(len(gt_img)):
  129. train_img_path = train_path + '/block_' + str(cnt) + '.png'
  130. gt_img_path = train_path + '/gtblock_' + str(cnt) + '.png'
  131. cv2.imwrite(train_img_path, in_img[i])
  132. # cv2.imwrite(train_img_path,PreProcessInput(in_img[i]))
  133. cv2.imwrite(gt_img_path, gt_img[i])
  134. t_name = 'block_' + str(cnt) + '.png'
  135. f.write(t_name)
  136. f.write('\n')
  137. cnt += 1
  138. Random_Block_Number_PerImage = int(len(gt_img)/5)
  139. for i in range(Random_Block_Number_PerImage):
  140. if(in_image.shape[0]-M > 1 and in_image.shape[1]-N > 1):
  141. y = random.randint(1, in_image.shape[0]-M)
  142. x = random.randint(1, in_image.shape[1]-N)
  143. in_part_img = in_image[y:y+M, x:x+N, :].copy()
  144. gt_part_img = gt_image[y:y+M, x:x+N, :].copy()
  145. train_img_path = train_path + '/block_' + str(cnt) + '.png'
  146. gt_img_path = train_path + '/gtblock_' + str(cnt) + '.png'
  147. in_part_img = cv2.cvtColor(in_part_img, cv2.COLOR_BGR2RGB)
  148. augmented_image = transform(image=in_part_img)['image']
  149. augmented_image = cv2.cvtColor(
  150. augmented_image, cv2.COLOR_RGB2BGR)
  151. cv2.imwrite(train_img_path, augmented_image)
  152. cv2.imwrite(gt_img_path, gt_part_img)
  153. t_name = 'block_' + str(cnt) + '.png'
  154. f.write(t_name)
  155. f.write('\n')
  156. cnt += 1
  157. else:
  158. break
  159. in_part_img = np.zeros((M, N, 3), dtype=np.uint8)
  160. gt_part_img = np.zeros((M, N, 3), dtype=np.uint8)
  161. in_part_img[:, :, :] = 255
  162. gt_part_img[:, :, :] = 255
  163. if(in_image.shape[0]-M <= 1 and in_image.shape[1]-N > 1):
  164. y = 0
  165. x = random.randint(1, in_image.shape[1]-N)
  166. in_part_img[:h, :, :] = in_image[:, x:x+N, :].copy()
  167. gt_part_img[:h, :, :] = gt_image[:, x:x+N, :].copy()
  168. if(in_image.shape[0]-M > 1 and in_image.shape[1]-N <= 1):
  169. x = 0
  170. y = random.randint(1, in_image.shape[0]-M)
  171. in_part_img[:, :w, :] = in_image[y:y+M, :, :].copy()
  172. gt_part_img[:, :w, :] = gt_image[y:y+M, :, :].copy()
  173. train_img_path = train_path + '/block_' + str(cnt) + '.png'
  174. gt_img_path = train_path + '/gtblock_' + str(cnt) + '.png'
  175. in_part_img = cv2.cvtColor(in_part_img, cv2.COLOR_BGR2RGB)
  176. augmented_image = transform(image=in_part_img)['image']
  177. augmented_image = cv2.cvtColor(
  178. augmented_image, cv2.COLOR_RGB2BGR)
  179. cv2.imwrite(train_img_path, augmented_image)
  180. cv2.imwrite(gt_img_path, gt_part_img)
  181. t_name = 'block_' + str(cnt) + '.png'
  182. f.write(t_name)
  183. f.write('\n')
  184. cnt += 1
  185. # print(cnt)
  186. f.close()
  187. print('Total number of training blocks generated: ', cnt)
  188. return train_path, train_filenames
  189. def CombineToImage(imgs,h,w,ch,Part=8):
  190. Image = np.zeros((h,w,ch),dtype=np.float32)
  191. Image_flag = np.zeros((h,w),dtype=bool)
  192. i = 0
  193. j = 0
  194. i_end = 0
  195. j_end = 0
  196. for k in range(len(imgs)):
  197. #part_img = np.copy(imgs[k,:,:,:])
  198. part_img = np.copy(imgs[k])
  199. hh,ww,cc = part_img.shape
  200. i_end = min(h,i + hh)
  201. j_end = min(w,j + ww)
  202. for m in range(hh):
  203. for n in range(ww):
  204. if(i+m<h):
  205. if(j+n<w):
  206. if(Image_flag[i+m,j+n]):
  207. Image[i+m,j+n,:] = (Image[i+m,j+n,:] + part_img[m,n,:])/2
  208. else:
  209. Image[i+m,j+n,:] = np.copy(part_img[m,n,:])
  210. Image_flag[i:i_end,j:j_end] = True
  211. j = min(w-1,j + ww - int(ww/Part))
  212. #print(i,j,w)
  213. #print(k,len(imgs))
  214. if(j_end==w):
  215. j = 0
  216. i = min(h-1,i + hh - int(hh/Part))
  217. Image = Image*255.0
  218. return Image.astype(np.uint8)
  219. if __name__ == "__main__":
  220. # img = cv2.imread("raw_data/gt/189.jpg")
  221. args = parser.parse_args()
  222. data_folder = f"{args.data_dir}/imgs"
  223. gt_folder = f"{args.data_dir}/gt"
  224. dataset = "dataset"
  225. shutil.rmtree(dataset, ignore_errors=True)
  226. os.mkdir(dataset)
  227. train_path, train_filenames = GenerateTrainingBlocks(
  228. data_folder=data_folder, gt_folder=gt_folder, dataset_path=dataset)