evaluate.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. ''' Document Localization using Recursive CNN
  2. Maintainer : Khurram Javed
  3. Email : kjaved@ualberta.ca '''
  4. import argparse
  5. import time
  6. import numpy as np
  7. import torch
  8. from PIL import Image
  9. import dataprocessor
  10. import evaluation
  11. from utils import utils
  12. parser = argparse.ArgumentParser(description='iCarl2.0')
  13. parser.add_argument("-i", "--data-dir", default="/Users/khurramjaved96/bg5",
  14. help="input Directory of test data")
  15. args = parser.parse_args()
  16. args.cuda = torch.cuda.is_available()
  17. if __name__ == '__main__':
  18. corners_extractor = evaluation.corner_extractor.GetCorners("../documentModelWell")
  19. corner_refiner = evaluation.corner_refiner.corner_finder("../cornerModelWell")
  20. test_set_dir = args.data_dir
  21. iou_results = []
  22. my_results = []
  23. dataset_test = dataprocessor.dataset.SmartDocDirectories(test_set_dir)
  24. for data_elem in dataset_test.myData:
  25. img_path = data_elem[0]
  26. # print(img_path)
  27. target = data_elem[1].reshape((4, 2))
  28. img_array = np.array(Image.open(img_path))
  29. computation_start_time = time.clock()
  30. extracted_corners = corners_extractor.get(img_array)
  31. temp_time = time.clock()
  32. corner_address = []
  33. # Refine the detected corners using corner refiner
  34. counter=0
  35. for corner in extracted_corners:
  36. counter+=1
  37. corner_img = corner[0]
  38. refined_corner = np.array(corner_refiner.get_location(corner_img, 0.85))
  39. # Converting from local co-ordinate to global co-ordinate of the image
  40. refined_corner[0] += corner[1]
  41. refined_corner[1] += corner[2]
  42. # Final results
  43. corner_address.append(refined_corner)
  44. computation_end_time = time.clock()
  45. print("TOTAL TIME : ", computation_end_time - computation_start_time)
  46. r2 = utils.intersection_with_corection_smart_doc_implementation(target, np.array(corner_address), img_array)
  47. r3 = utils.intersection_with_corection(target, np.array(corner_address), img_array)
  48. if r3 - r2 > 0.1:
  49. print ("Image Name", img_path)
  50. print ("Prediction", np.array(corner_address), target)
  51. 0/0
  52. assert (r2 > 0 and r2 < 1)
  53. iou_results.append(r2)
  54. my_results.append(r3)
  55. print("MEAN CORRECTED JI: ", np.mean(np.array(iou_results)))
  56. print("MEAN CORRECTED MY: ", np.mean(np.array(my_results)))
  57. print(np.mean(np.array(iou_results)))