generate_dataset.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. import os
  2. from tqdm import tqdm
  3. import cv2
  4. import numpy as np
  5. import albumentations as A
  6. import random
  7. import shutil
  8. import argparse
  9. # path parameters
  10. parser = argparse.ArgumentParser()
  11. parser.add_argument('--data_dir',
  12. type=str,
  13. help='Raw training data.',
  14. default="raw_data")
  15. transform = A.Compose([
  16. A.OneOf([
  17. A.ISONoise(p=0.4),
  18. A.JpegCompression(quality_lower=50, quality_upper=70,
  19. always_apply=False, p=0.8),
  20. ], p=0.6),
  21. A.OneOf([
  22. A.MotionBlur(blur_limit=10, p=.8),
  23. A.MedianBlur(blur_limit=3, p=0.75),
  24. A.GaussianBlur(blur_limit=7, p=0.75),
  25. ], p=0.8),
  26. A.OneOf([
  27. A.RandomBrightnessContrast(
  28. brightness_limit=0.3, contrast_limit=0.3, p=0.75),
  29. A.RandomShadow(num_shadows_lower=1,
  30. num_shadows_upper=18, shadow_dimension=6, p=0.85),
  31. ], p=0.8),
  32. ])
  33. def getListOfFiles(dirName):
  34. print(dirName)
  35. # create a list of file and sub directories
  36. # names in the given directory
  37. listOfFile = os.listdir(dirName)
  38. allFiles = list()
  39. # Iterate over all the entries
  40. for entry in listOfFile:
  41. allFiles.append(entry)
  42. return allFiles
  43. def ImageResize(image, factor=0.6):
  44. width = int(image.shape[1] * factor)
  45. height = int(image.shape[0] * factor)
  46. dim = (width, height)
  47. # print(image.shape)
  48. resized = cv2.resize(image, dim, interpolation=cv2.INTER_LANCZOS4)
  49. # print(resized.shape)
  50. return resized
  51. def GetOverlappingBlocks(im, M=256, N=256, Part=8):
  52. tiles = []
  53. tile = np.zeros((M, N, 3), dtype=np.uint8)
  54. #tile[:,:,2] = 255
  55. x = 0
  56. y = 0
  57. x_start = 0
  58. y_start = 0
  59. while y < im.shape[0]:
  60. while x < im.shape[1]:
  61. if(x != 0):
  62. x_start = x - int(N/Part)
  63. if(y != 0):
  64. y_start = y - int(M/Part)
  65. if(y_start+M > im.shape[0]):
  66. if(x_start+N > im.shape[1]):
  67. tile[0:im.shape[0]-y_start, 0:im.shape[1]-x_start,
  68. :] = im[y_start:im.shape[0], x_start:im.shape[1], :]
  69. else:
  70. tile[0:im.shape[0]-y_start, 0:N,
  71. :] = im[y_start:im.shape[0], x_start:x_start+N, :]
  72. else:
  73. if(x_start+N > im.shape[1]):
  74. tile[0:M, 0:im.shape[1]-x_start,
  75. :] = im[y_start:y_start+M, x_start:im.shape[1], :]
  76. else:
  77. tile[0:M, 0:N, :] = im[y_start:y_start +
  78. M, x_start:x_start+N, :]
  79. #pre_tile = cv2.cvtColor(PreProcessInput(cv2.cvtColor(tile, cv2.COLOR_RGB2BGR)), cv2.COLOR_BGR2RGB)
  80. # tiles.append(load_tf_img(pre_tile,M))
  81. # tiles.append(load_tf_img(tile,M))
  82. tiles.append(tile)
  83. tile = np.zeros((M, N, 3), dtype=np.uint8)
  84. #tile[:,:,2] = 255
  85. x = x_start + N
  86. y = y_start + M
  87. x = 0
  88. x_start = 0
  89. return tiles
  90. def GenerateTrainingBlocks(data_folder, gt_folder, dataset_path='./dataset', M=256, N=256):
  91. print(data_folder)
  92. print('Generating training blocks!!!')
  93. train_path = dataset_path + '/' + data_folder + '_Trainblocks'
  94. if not os.path.exists(train_path):
  95. os.makedirs(train_path)
  96. train_filenames = train_path + '/train_block_names.txt'
  97. f = open(train_filenames, 'w')
  98. data_path = data_folder
  99. gt_path = gt_folder
  100. # data_path = dataset_path + '/' + data_folder
  101. # gt_path = dataset_path + '/' + gt_folder
  102. print(data_path)
  103. filenames = getListOfFiles(data_path)
  104. cnt = 0
  105. print(filenames)
  106. for name in tqdm(filenames):
  107. print(name)
  108. if name == '.DS_Store':
  109. continue
  110. arr = name.split(".")
  111. gt_filename = gt_path + '/' + arr[0] + "_mask."+arr[1]
  112. in_filename = data_path + '/' + name
  113. print(gt_filename)
  114. print(in_filename)
  115. gt_image_initial = cv2.imread(gt_filename)
  116. in_image_initial = cv2.imread(in_filename)
  117. if gt_image_initial.shape[0] + gt_image_initial.shape[1] > in_image_initial.shape[0]+in_image_initial.shape[1]:
  118. gt_image_initial = cv2.resize(gt_image_initial, (in_image_initial.shape[1], in_image_initial.shape[0]))
  119. else:
  120. in_image_initial = cv2.resize(in_image_initial, (gt_image_initial.shape[1], gt_image_initial.shape[0]))
  121. print(gt_image_initial.shape, in_image_initial.shape)
  122. # cv2.imshow("img", in_image_initial)
  123. # cv2.imshow("gt", gt_image_initial)
  124. # cv2.waitKey(0)
  125. # cv2.destroyAllWindows()
  126. for scale in [0.7, 1.0, 1.4]:
  127. gt_image = ImageResize(gt_image_initial, scale)
  128. in_image = ImageResize(in_image_initial, scale)
  129. h, w, c = in_image.shape
  130. gt_img = GetOverlappingBlocks(gt_image, Part=8)
  131. in_img = GetOverlappingBlocks(in_image, Part=8)
  132. for i in range(len(gt_img)):
  133. train_img_path = train_path + '/block_' + str(cnt) + '.png'
  134. gt_img_path = train_path + '/gtblock_' + str(cnt) + '.png'
  135. cv2.imwrite(train_img_path, in_img[i])
  136. # cv2.imwrite(train_img_path,PreProcessInput(in_img[i]))
  137. cv2.imwrite(gt_img_path, gt_img[i])
  138. t_name = 'block_' + str(cnt) + '.png'
  139. f.write(t_name)
  140. f.write('\n')
  141. cnt += 1
  142. Random_Block_Number_PerImage = int(len(gt_img)/5)
  143. for i in range(Random_Block_Number_PerImage):
  144. if(in_image.shape[0]-M > 1 and in_image.shape[1]-N > 1):
  145. y = random.randint(1, in_image.shape[0]-M)
  146. x = random.randint(1, in_image.shape[1]-N)
  147. in_part_img = in_image[y:y+M, x:x+N, :].copy()
  148. gt_part_img = gt_image[y:y+M, x:x+N, :].copy()
  149. train_img_path = train_path + '/block_' + str(cnt) + '.png'
  150. gt_img_path = train_path + '/gtblock_' + str(cnt) + '.png'
  151. in_part_img = cv2.cvtColor(in_part_img, cv2.COLOR_BGR2RGB)
  152. augmented_image = transform(image=in_part_img)['image']
  153. augmented_image = cv2.cvtColor(
  154. augmented_image, cv2.COLOR_RGB2BGR)
  155. cv2.imwrite(train_img_path, augmented_image)
  156. cv2.imwrite(gt_img_path, gt_part_img)
  157. t_name = 'block_' + str(cnt) + '.png'
  158. f.write(t_name)
  159. f.write('\n')
  160. cnt += 1
  161. else:
  162. break
  163. in_part_img = np.zeros((M, N, 3), dtype=np.uint8)
  164. gt_part_img = np.zeros((M, N, 3), dtype=np.uint8)
  165. in_part_img[:, :, :] = 255
  166. gt_part_img[:, :, :] = 255
  167. if(in_image.shape[0]-M <= 1 and in_image.shape[1]-N > 1):
  168. y = 0
  169. x = random.randint(1, in_image.shape[1]-N)
  170. in_part_img[:h, :, :] = in_image[:, x:x+N, :].copy()
  171. gt_part_img[:h, :, :] = gt_image[:, x:x+N, :].copy()
  172. if(in_image.shape[0]-M > 1 and in_image.shape[1]-N <= 1):
  173. x = 0
  174. y = random.randint(1, in_image.shape[0]-M)
  175. in_part_img[:, :w, :] = in_image[y:y+M, :, :].copy()
  176. gt_part_img[:, :w, :] = gt_image[y:y+M, :, :].copy()
  177. train_img_path = train_path + '/block_' + str(cnt) + '.png'
  178. gt_img_path = train_path + '/gtblock_' + str(cnt) + '.png'
  179. in_part_img = cv2.cvtColor(in_part_img, cv2.COLOR_BGR2RGB)
  180. augmented_image = transform(image=in_part_img)['image']
  181. augmented_image = cv2.cvtColor(
  182. augmented_image, cv2.COLOR_RGB2BGR)
  183. cv2.imwrite(train_img_path, augmented_image)
  184. cv2.imwrite(gt_img_path, gt_part_img)
  185. t_name = 'block_' + str(cnt) + '.png'
  186. f.write(t_name)
  187. f.write('\n')
  188. cnt += 1
  189. # print(cnt)
  190. f.close()
  191. print('Total number of training blocks generated: ', cnt)
  192. return train_path, train_filenames
  193. def CombineToImage(imgs,h,w,ch,Part=8):
  194. Image = np.zeros((h,w,ch),dtype=np.float32)
  195. Image_flag = np.zeros((h,w),dtype=bool)
  196. i = 0
  197. j = 0
  198. i_end = 0
  199. j_end = 0
  200. for k in range(len(imgs)):
  201. #part_img = np.copy(imgs[k,:,:,:])
  202. part_img = np.copy(imgs[k])
  203. hh,ww,cc = part_img.shape
  204. i_end = min(h,i + hh)
  205. j_end = min(w,j + ww)
  206. for m in range(hh):
  207. for n in range(ww):
  208. if(i+m<h):
  209. if(j+n<w):
  210. if(Image_flag[i+m,j+n]):
  211. Image[i+m,j+n,:] = (Image[i+m,j+n,:] + part_img[m,n,:])/2
  212. else:
  213. Image[i+m,j+n,:] = np.copy(part_img[m,n,:])
  214. Image_flag[i:i_end,j:j_end] = True
  215. j = min(w-1,j + ww - int(ww/Part))
  216. #print(i,j,w)
  217. #print(k,len(imgs))
  218. if(j_end==w):
  219. j = 0
  220. i = min(h-1,i + hh - int(hh/Part))
  221. Image = Image*255.0
  222. return Image.astype(np.uint8)
  223. if __name__ == "__main__":
  224. # img = cv2.imread("raw_data/gt/189.jpg")
  225. args = parser.parse_args()
  226. data_folder = f"{args.data_dir}/imgs"
  227. gt_folder = f"{args.data_dir}/gt"
  228. dataset = "dataset"
  229. shutil.rmtree(dataset, ignore_errors=True)
  230. os.mkdir(dataset)
  231. train_path, train_filenames = GenerateTrainingBlocks(
  232. data_folder=data_folder, gt_folder=gt_folder, dataset_path=dataset)