demo.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. ''' Document Localization using Recursive CNN
  2. Maintainer : Khurram Javed
  3. Email : kjaved@ualberta.ca '''
  4. import cv2
  5. import numpy as np
  6. import glob
  7. import evaluation
  8. import os
  9. import shutil
  10. import time
  11. def args_processor():
  12. import argparse
  13. parser = argparse.ArgumentParser()
  14. parser.add_argument("-i", "--images", default="example_imgs", help="Document image folder")
  15. parser.add_argument('--model-type', default="resnet",
  16. help='model type to be used. Example : resnet32, resnet20, densenet, test')
  17. parser.add_argument("-o", "--output", default="example_imgs/output", help="The folder to store results")
  18. parser.add_argument("-rf", "--retainFactor", help="Floating point in range (0,1) specifying retain factor",
  19. default="0.85", type=float)
  20. parser.add_argument("-cm", "--cornerModel", help="Model for corner point refinement",
  21. default="../cornerModelWell")
  22. parser.add_argument("-dm", "--documentModel", help="Model for document corners detection",
  23. default="../documentModelWell")
  24. return parser.parse_args()
  25. if __name__ == "__main__":
  26. args = args_processor()
  27. corners_extractor = evaluation.corner_extractor.GetCorners(args.documentModel, args.model_type)
  28. corner_refiner = evaluation.corner_refiner.corner_finder(args.cornerModel, args.model_type)
  29. now_date = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime(time.time()))
  30. output_dir = f"{args.output}_{now_date}"
  31. shutil.rmtree(output_dir, ignore_errors=True)
  32. os.makedirs(output_dir)
  33. imgPaths = glob.glob(f"{args.images}/*.jpg")
  34. for imgPath in imgPaths:
  35. img = cv2.imread(imgPath)
  36. oImg = img
  37. e1 = cv2.getTickCount()
  38. extracted_corners = corners_extractor.get(oImg)
  39. corner_address = []
  40. # Refine the detected corners using corner refiner
  41. image_name = 0
  42. for corner in extracted_corners:
  43. image_name += 1
  44. corner_img = corner[0]
  45. refined_corner = np.array(corner_refiner.get_location(corner_img, args.retainFactor))
  46. # Converting from local co-ordinate to global co-ordinates of the image
  47. refined_corner[0] += corner[1]
  48. refined_corner[1] += corner[2]
  49. # Final results
  50. corner_address.append(refined_corner)
  51. e2 = cv2.getTickCount()
  52. print(f"Took time:{(e2 - e1)/ cv2.getTickFrequency()}")
  53. for a in range(0, len(extracted_corners)):
  54. cv2.line(oImg, tuple(corner_address[a % 4]), tuple(corner_address[(a + 1) % 4]), (255, 0, 0), 4)
  55. filename = os.path.basename(imgPath)
  56. cv2.imwrite(f"{output_dir}/{filename}", oImg)