
General - init git

WangChao · 2 years ago
Current commit
39934f60df
75 files changed, with 4569 additions and 0 deletions
  1. .gitignore (+6 -0)
  2. Recursive-CNNs/.vscode/launch.json (+26 -0)
  3. Recursive-CNNs/LICENSE (+201 -0)
  4. Recursive-CNNs/README.md (+89 -0)
  5. Recursive-CNNs/corner_data_generator.py (+71 -0)
  6. Recursive-CNNs/data_augmentor/augmentData.py (+67 -0)
  7. Recursive-CNNs/data_augmentor/cornerData.py (+63 -0)
  8. Recursive-CNNs/data_augmentor/label.py (+32 -0)
  9. Recursive-CNNs/data_augmentor/testData.py (+21 -0)
  10. Recursive-CNNs/dataprocessor/__init__.py (+4 -0)
  11. Recursive-CNNs/dataprocessor/__pycache__/__init__.cpython-38.pyc (binary)
  12. Recursive-CNNs/dataprocessor/__pycache__/dataloaders.cpython-38.pyc (binary)
  13. Recursive-CNNs/dataprocessor/__pycache__/dataset.cpython-38.pyc (binary)
  14. Recursive-CNNs/dataprocessor/__pycache__/datasetfactory.cpython-38.pyc (binary)
  15. Recursive-CNNs/dataprocessor/__pycache__/loaderfactory.cpython-38.pyc (binary)
  16. Recursive-CNNs/dataprocessor/dataloaders.py (+119 -0)
  17. Recursive-CNNs/dataprocessor/dataset.py (+327 -0)
  18. Recursive-CNNs/dataprocessor/datasetfactory.py (+19 -0)
  19. Recursive-CNNs/dataprocessor/loaderfactory.py (+21 -0)
  20. Recursive-CNNs/demo.py (+65 -0)
  21. Recursive-CNNs/document_data_generator.py (+87 -0)
  22. Recursive-CNNs/evaluate.py (+69 -0)
  23. Recursive-CNNs/evaluation/__init__.py (+6 -0)
  24. Recursive-CNNs/evaluation/corner_extractor.py (+72 -0)
  25. Recursive-CNNs/evaluation/corner_refiner.py (+84 -0)
  26. Recursive-CNNs/experiment/__init__.py (+1 -0)
  27. Recursive-CNNs/experiment/__pycache__/__init__.cpython-38.pyc (binary)
  28. Recursive-CNNs/experiment/__pycache__/experiment.cpython-38.pyc (binary)
  29. Recursive-CNNs/experiment/experiment.py (+53 -0)
  30. Recursive-CNNs/mobile_model_converter.py (+61 -0)
  31. Recursive-CNNs/model/__init__.py (+1 -0)
  32. Recursive-CNNs/model/__pycache__/__init__.cpython-38.pyc (binary)
  33. Recursive-CNNs/model/__pycache__/cornerModel.cpython-38.pyc (binary)
  34. Recursive-CNNs/model/__pycache__/modelfactory.cpython-38.pyc (binary)
  35. Recursive-CNNs/model/__pycache__/res_utils.cpython-38.pyc (binary)
  36. Recursive-CNNs/model/__pycache__/resnet32.cpython-38.pyc (binary)
  37. Recursive-CNNs/model/cornerModel.py (+78 -0)
  38. Recursive-CNNs/model/modelfactory.py (+37 -0)
  39. Recursive-CNNs/model/res_utils.py (+37 -0)
  40. Recursive-CNNs/model/resnet32.py (+184 -0)
  41. Recursive-CNNs/plotter/__init__.py (+1 -0)
  42. Recursive-CNNs/plotter/plotter.py (+88 -0)
  43. Recursive-CNNs/requirements.txt (+54 -0)
  44. Recursive-CNNs/results/qualitativeResults.jpg (binary)
  45. Recursive-CNNs/self_collected_dataset_preprocess.py (+142 -0)
  46. Recursive-CNNs/smartdoc_data_processor/video_to_image.py (+40 -0)
  47. Recursive-CNNs/sythetic_doc.py (+101 -0)
  48. Recursive-CNNs/train_model.py (+146 -0)
  49. Recursive-CNNs/train_model.sh (+35 -0)
  50. Recursive-CNNs/train_seg_model.py (+149 -0)
  51. Recursive-CNNs/trainer/__init__.py (+2 -0)
  52. Recursive-CNNs/trainer/__pycache__/__init__.cpython-38.pyc (binary)
  53. Recursive-CNNs/trainer/__pycache__/evaluator.cpython-38.pyc (binary)
  54. Recursive-CNNs/trainer/__pycache__/trainer.cpython-38.pyc (binary)
  55. Recursive-CNNs/trainer/evaluator.py (+63 -0)
  56. Recursive-CNNs/trainer/trainer.py (+110 -0)
  57. Recursive-CNNs/utils/__init__.py (+2 -0)
  58. Recursive-CNNs/utils/__pycache__/__init__.cpython-38.pyc (binary)
  59. Recursive-CNNs/utils/__pycache__/colorer.cpython-38.pyc (binary)
  60. Recursive-CNNs/utils/__pycache__/utils.cpython-38.pyc (binary)
  61. Recursive-CNNs/utils/colorer.py (+114 -0)
  62. Recursive-CNNs/utils/utils.py (+296 -0)
  63. doc_clean_up/.vscode/launch.json (+54 -0)
  64. doc_clean_up/MS_SSIM_L1_loss.py (+96 -0)
  65. doc_clean_up/convert_model_to_tflite.py (+75 -0)
  66. doc_clean_up/dataset.py (+80 -0)
  67. doc_clean_up/generate_dataset.py (+265 -0)
  68. doc_clean_up/infer.py (+114 -0)
  69. doc_clean_up/loss.py (+65 -0)
  70. doc_clean_up/model.py (+124 -0)
  71. doc_clean_up/requirements.txt (+66 -0)
  72. doc_clean_up/tflite_infer.py (+22 -0)
  73. doc_clean_up/train.py (+295 -0)
  74. doc_clean_up/vgg19.py (+69 -0)
  75. document/交接文档.docx (binary)

+ 6 - 0
.gitignore

@@ -0,0 +1,6 @@
+doc_clean_up/runs
+doc_clean_up/__pycache__
+doc_clean_up/output
+doc_clean_up/raw_data
+doc_clean_up/dataset
+Recursive-CNNs/dataset

+ 26 - 0
Recursive-CNNs/.vscode/launch.json

@@ -0,0 +1,26 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Train Doc Model",
+            "type": "python",
+            "request": "launch",
+            "program": "${workspaceFolder}/train_model.py",
+            "args": [
+                "--data-dirs", "dataset/selfCollectedData_DocCyclic", "dataset/smartdocData_DocTrainC",
+                "--validation-dirs", "dataset/smartDocData_DocTestC",
+                "--name", "DocModel",
+                "--lr", "0.5",
+                "--batch-size", "16",
+                "--schedule", "20", "30", "35",
+                "--model-type", "resnet",
+                "--loader", "ram",
+            ],
+            "console": "integratedTerminal",
+            "justMyCode": true
+        }
+    ]
+}

+ 201 - 0
Recursive-CNNs/LICENSE

@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "{}"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright {yyyy} {name of copyright owner}
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.

File diff suppressed because it is too large
+ 89 - 0
Recursive-CNNs/README.md


+ 71 - 0
Recursive-CNNs/corner_data_generator.py

@@ -0,0 +1,71 @@
+''' Document Localization using Recursive CNN
+ Maintainer : Khurram Javed
+ Email : kjaved@ualberta.ca '''
+
+import os
+import cv2
+import numpy as np
+
+import dataprocessor
+from utils import utils
+
+
+def args_processor():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--input-dir", help="Path to data files (Extract images using video_to_image.py first")
+    parser.add_argument("-o", "--output-dir", help="Directory to store results")
+    parser.add_argument("-v", "--visualize", help="Draw the point on the corner", default=False, type=bool)
+    parser.add_argument("--dataset", default="smartdoc", help="'smartdoc' or 'selfcollected' dataset")
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = args_processor()
+    input_directory = args.input_dir
+    if not os.path.isdir(args.output_dir):
+        os.mkdir(args.output_dir)
+    import csv
+
+    # Dataset iterator
+    if args.dataset=="smartdoc":
+        dataset_test = dataprocessor.dataset.SmartDocDirectories(input_directory)
+    elif args.dataset=="selfcollected":
+        dataset_test = dataprocessor.dataset.SelfCollectedDataset(input_directory)
+    else:
+        print ("Incorrect dataset type; please choose between smartdoc or selfcollected")
+        assert(False)
+    with open(os.path.join(args.output_dir, 'gt.csv'), 'a') as csvfile:
+        spamwriter = csv.writer(csvfile, delimiter=',',
+                                quotechar='|', quoting=csv.QUOTE_MINIMAL)
+        # Counter for file naming
+        counter = 0
+        for data_elem in dataset_test.myData:
+
+            img_path = data_elem[0]
+            target = data_elem[1].reshape((4, 2))
+            img = cv2.imread(img_path)
+
+            if args.dataset=="selfcollected":
+                target = target / (img.shape[1], img.shape[0])
+                target = target * (1920, 1920)
+                img = cv2.resize(img, (1920, 1920))
+
+            corner_cords = target
+
+            for angle in range(0, 1, 90):
+                img_rotate, gt_rotate = utils.rotate(img, corner_cords, angle)
+                for random_crop in range(0, 1):
+                    img_list, gt_list = utils.get_corners(img_rotate, gt_rotate)
+                    for a in range(0, 4):
+                        counter += 1
+                        f_name = str(counter).zfill(8)
+                        print(gt_list[a])
+                        gt_store = list(np.array(gt_list[a]) / (300, 300))
+                        img_store = cv2.resize(img_list[a], (64, 64))
+                        if args.visualize:
+                            cv2.circle(img_store, tuple(list((np.array(gt_store)*64).astype(int))), 2, (255, 0, 0), 2)
+
+                        cv2.imwrite(os.path.join(args.output_dir, f_name + ".jpg"),
+                                    img_store, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
+                        spamwriter.writerow((f_name + ".jpg", tuple(gt_store)))

+ 67 - 0
Recursive-CNNs/data_augmentor/augmentData.py

@@ -0,0 +1,67 @@
+import os
+
+import cv2
+import numpy as np
+
+import utils
+
+
+def argsProcessor():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--dataPath", help="DataPath")
+    parser.add_argument("-o", "--outputFiles", help="outputFiles", default="bar")
+    return parser.parse_args()
+
+args = argsProcessor()
+
+output_dir = args.outputFiles
+if (not os.path.isdir(output_dir)):
+    os.mkdir(output_dir)
+
+dir = args.dataPath
+import csv
+
+with open(output_dir+"/gt.csv", 'a') as csvfile:
+    spamwriter_1 = csv.writer(csvfile, delimiter=',',
+                                quotechar='|', quoting=csv.QUOTE_MINIMAL)
+    for image in os.listdir(dir):
+        if image.endswith("jpg") or image.endswith("JPG"):
+            if os.path.isfile(dir+"/"+image+".csv"):
+                with open(dir+"/"+image+ ".csv", 'r') as csvfile:
+                    spamwriter = csv.reader(csvfile, delimiter=' ',
+                                            quotechar='|', quoting=csv.QUOTE_MINIMAL)
+                    img = cv2.imread(dir +"/"+ image)
+                    print (image)
+                    gt= []
+                    for row in spamwriter:
+                        gt.append(row)
+                        # img = cv2.circle(img, (int(float(row[0])), int(float(row[1]))), 2,(255,0,0),90)
+                    gt =np.array(gt).astype(np.float32)
+                    gt = gt / (img.shape[1], img.shape[0])
+                    gt = gt * (1080, 1080)
+                    img = cv2.resize(img, (1080, 1080))
+
+
+                    print (gt)
+
+                    for angle in range(0,271,90):
+                        img_rotate, gt_rotate = utils.rotate(img, gt, angle)
+                        for random_crop in range(0,16):
+                            img_crop, gt_crop = utils.random_crop(img_rotate, gt_rotate)
+                            mah_size = img_crop.shape
+                            img_crop = cv2.resize(img_crop, (64, 64))
+                            gt_crop = np.array(gt_crop)
+
+                            # gt_crop = gt_crop*(1.0 / mah_size[1],1.0 / mah_size[0])
+
+                            # for a in range(0,4):
+                            # no=0
+                            # for a in range(0,4):
+                            #     no+=1
+                            #     cv2.circle(img_crop, tuple(((gt_crop[a]*64).astype(int))), 2,(255-no*60,no*60,0),9)
+                            # # # cv2.imwrite("asda.jpg", img)
+
+                            cv2.imwrite(output_dir + "/" +str(angle)+str(random_crop)+ image, img_crop)
+                            spamwriter_1.writerow((str(angle)+str(random_crop)+ image, tuple(list(gt_crop))))
+

+ 63 - 0
Recursive-CNNs/data_augmentor/cornerData.py

@@ -0,0 +1,63 @@
+import os
+
+import cv2
+import numpy as np
+
+import utils
+
+def argsProcessor():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--dataPath", help="DataPath")
+    parser.add_argument("-o", "--outputFiles", help="outputFiles", default="bar")
+    return parser.parse_args()
+
+args = argsProcessor()
+
+output_dir = args.outputFiles
+if (not os.path.isdir(output_dir)):
+    os.mkdir(output_dir)
+
+dir = args.dataPath
+import csv
+
+with open(output_dir+"/gt.csv", 'a') as csvfile:
+    spamwriter_1 = csv.writer(csvfile, delimiter=',',
+                                quotechar='|', quoting=csv.QUOTE_MINIMAL)
+    for image in os.listdir(dir):
+        if image.endswith("jpg"):
+            if os.path.isfile(dir+"/"+image+".csv"):
+                with open(dir+"/"+image+ ".csv", 'r') as csvfile:
+                    spamwriter = csv.reader(csvfile, delimiter=' ',
+                                            quotechar='|', quoting=csv.QUOTE_MINIMAL)
+                    img = cv2.imread(dir +"/"+ image)
+                    print (image)
+                    gt= []
+                    for row in spamwriter:
+                        gt.append(row)
+                        # img = cv2.circle(img, (int(float(row[0])), int(float(row[1]))), 2,(255,0,0),90)
+                    gt =np.array(gt).astype(np.float32)
+
+
+                    # print gt
+                    gt = gt / (img.shape[1], img.shape[0])
+
+                    gt = gt * (1080, 1080)
+
+                    img = cv2.resize(img, ( 1080,1080))
+                    # for a in range(0,4):
+                    #     img = cv2.circle(img, tuple((gt[a].astype(int))), 2, (255, 0, 0), 9)
+                    # cv2.imwrite("asda.jpg", img)
+                    # 0/0
+                    for angle in range(0,271,90):
+                        img_rotate, gt_rotate = utils.rotate(img, gt, angle)
+                        for random_crop in range(0,16):
+                            img_list, gt_list = utils.getCorners(img_rotate, gt_rotate)
+                            for a in range(0,4):
+                                print (gt_list[a])
+                                gt_store = list(np.array(gt_list[a])/(300,300))
+                                img_store = cv2.resize(img_list[a], (64,64))
+                                print (tuple(list(np.array(gt_store)*64)))
+                                # cv2.circle(img_store, tuple(list((np.array(gt_store)*64).astype(int))), 2, (255, 0, 0), 2)
+                                cv2.imwrite( output_dir+"/"+image + str(angle) +str(random_crop) + str(a) +".jpg", img_store)
+                                spamwriter_1.writerow(( image + str(angle) +str(random_crop) + str(a) +".jpg", tuple(gt_store)))

+ 32 - 0
Recursive-CNNs/data_augmentor/label.py

@@ -0,0 +1,32 @@
+import os
+
+import matplotlib.image as mpimg
+import matplotlib.pyplot as plt
+
+current_file = None
+
+
+def onclick(event):
+    if event.dblclick:
+        print('button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
+              (event.button, event.x, event.y, event.xdata, event.ydata))
+        import csv
+        with open(current_file + ".csv", 'a') as csvfile:
+            spamwriter = csv.writer(csvfile, delimiter=' ',
+                                    quotechar='|', quoting=csv.QUOTE_MINIMAL)
+            spamwriter.writerow([str(event.xdata), str(event.ydata)])
+
+
+dir = "../data1/"
+for image in os.listdir(dir):
+    if image.endswith("jpg") or image.endswith("JPG"):
+        if os.path.isfile(dir + image + ".csv"):
+            pass
+        else:
+            fig = plt.figure()
+            cid = fig.canvas.mpl_connect('button_press_event', onclick)
+            print(dir + image)
+            current_file = dir + image
+            img = mpimg.imread(dir + image)
+            plt.imshow(img)
+            plt.show()

+ 21 - 0
Recursive-CNNs/data_augmentor/testData.py

@@ -0,0 +1,21 @@
+import os
+import numpy as np
+import cv2
+import csv
+
+dir = "../data1/"
+for image in os.listdir(dir):
+    if image.endswith("jpg") or image.endswith("JPG"):
+        if os.path.isfile(dir+image+".csv"):
+            with open(dir+image+ ".csv", 'r') as csvfile:
+                spamwriter = csv.reader(csvfile, delimiter=' ',
+                                        quotechar='|', quoting=csv.QUOTE_MINIMAL)
+                img = cv2.imread(dir + image)
+                no = 0
+                for row in spamwriter:
+                    no+=1
+                    print (row)
+                    img = cv2.circle(img, (int(float(row[0])), int(float(row[1]))), 2,(255-no*60,no*60,0),90)
+                img = cv2.resize(img, (300,300))
+                cv2.imshow("a",img)
+                cv2.waitKey(0)

+ 4 - 0
Recursive-CNNs/dataprocessor/__init__.py

@@ -0,0 +1,4 @@
+from dataprocessor.datasetfactory import *
+from dataprocessor.dataloaders import *
+from dataprocessor.loaderfactory import *
+from dataprocessor.dataset import *

Binary
Recursive-CNNs/dataprocessor/__pycache__/__init__.cpython-38.pyc


Binary
Recursive-CNNs/dataprocessor/__pycache__/dataloaders.cpython-38.pyc


Binary
Recursive-CNNs/dataprocessor/__pycache__/dataset.cpython-38.pyc


Binary
Recursive-CNNs/dataprocessor/__pycache__/datasetfactory.cpython-38.pyc


Binary
Recursive-CNNs/dataprocessor/__pycache__/loaderfactory.cpython-38.pyc


+ 119 - 0
Recursive-CNNs/dataprocessor/dataloaders.py

@@ -0,0 +1,119 @@
+''' Document Localization using Recursive CNN
+ Maintainer : Khurram Javed
+ Email : kjaved@ualberta.ca '''
+
+import logging
+
+import PIL
+import torch.utils.data as td
+import tqdm
+from PIL import Image
+import cv2
+
+logger = logging.getLogger('iCARL')
+
+
+class HddLoader(td.Dataset):
+    def __init__(self, data, transform=None, cuda=False):
+        self.data = data
+
+        self.transform = transform
+        self.cuda = cuda
+        self.len = len(data[0])
+
+    def __len__(self):
+        return self.len
+
+    def __getitem__(self, index):
+        '''
+        Return the (image, target) pair at the given index.
+        :param index: 
+        :return: 
+        '''
+        assert (index < len(self.data[0]))
+        assert (index < self.len)
+        img = Image.open(self.data[0][index])
+        target = self.data[1][index]
+        if self.transform is not None:
+            img = self.transform(img)
+
+        return img, target
+
+class RamLoader(td.Dataset):
+    def __init__(self, data, transform=None, cuda=False):
+        self.data = data
+
+        self.transform = transform
+        self.cuda = cuda
+        self.len = len(data[0])
+        self.loadInRam()
+
+    def loadInRam(self):
+        self.loaded_data = []
+        logger.info("Loading data in RAM")
+        for i in tqdm.tqdm(self.data[0]):
+            # img = Image.open(i)
+            img = cv2.imread(i)
+            if self.transform is not None:
+                img = self.transform(img)
+            self.loaded_data.append(img)
+
+    def __len__(self):
+        return self.len
+
+    def __getitem__(self, index):
+        '''
+        Return the (image, target) pair at the given index.
+        :param index: 
+        :return: 
+        '''
+        assert (index < len(self.data[0]))
+        assert (index < self.len)
+        target = self.data[1][index]
+        img = self.loaded_data[index]
+        return img, target
+
+
+
+class SingleFolderLoaderResized(td.Dataset):
+    '''
+    This loader resizes every image on disk to 32x32 up front, so the resize cost is not paid during training.
+    '''
+
+    def __init__(self, data, transform=None, cuda=False):
+
+        self.data = data
+
+        self.transform = transform
+        self.cuda = cuda
+        self.len = len(data)
+        self.decodeImages()
+
+    def decodeImages(self):
+        self.loaded_data = []
+        logger.info("Resizing Images")
+        for i in tqdm.tqdm(self.data):
+            i = i[0]
+            img = Image.open(i)
+            img = img.resize((32, 32), PIL.Image.ANTIALIAS)
+            img.save(i)
+
+    def __len__(self):
+        return self.len
+
+    def __getitem__(self, index):
+        '''
+        Return the (image, target) pair at the given index.
+        :param index: 
+        :return: 
+        '''
+        assert (index < len(self.data))
+        assert (index < self.len)
+
+        img = Image.open(self.data[index][0])
+        target = self.data[index][1]
+        if self.transform is not None:
+            img = self.transform(img)
+
+        return img, target
+

+ 327 - 0
Recursive-CNNs/dataprocessor/dataset.py

@@ -0,0 +1,327 @@
+""" Document Localization using Recursive CNN
+ Maintainer : Khurram Javed
+ Email : kjaved@ualberta.ca """
+
+import imgaug.augmenters as iaa
+import csv
+import logging
+import os
+import xml.etree.ElementTree as ET
+
+import numpy as np
+from torchvision import transforms
+
+import utils.utils as utils
+
+# To include a new Dataset, inherit from Dataset and add all the Dataset-specific parameters here.
+# Goal : Remove any data specific parameters from the rest of the code
+
+logger = logging.getLogger("iCARL")
+
+
+class Dataset:
+    """
+    Base class to represent a Dataset
+    """
+
+    def __init__(self, name):
+        self.name = name
+        self.data = []
+        self.labels = []
+
+
+def getTransformsByImgaug():
+    return iaa.Sequential(
+        [
+            iaa.Resize(32),
+            # Add blur
+            iaa.Sometimes(
+                0.05,
+                iaa.OneOf(
+                    [
+                        iaa.GaussianBlur(
+                            (0, 3.0)
+                        ),  # blur images with a sigma between 0 and 3.0
+                        iaa.AverageBlur(
+                            k=(2, 11)
+                        ),  # blur image using local means with kernel sizes between 2 and 11
+                        iaa.MedianBlur(
+                            k=(3, 11)
+                        ),  # blur image using local medians with kernel sizes between 3 and 11
+                        iaa.MotionBlur(k=15, angle=[-45, 45]),
+                    ]
+                ),
+            ),
+            # Add color
+            iaa.Sometimes(
+                0.05,
+                iaa.OneOf(
+                    [
+                        iaa.WithHueAndSaturation(iaa.WithChannels(0, iaa.Add((0, 50)))),
+                        iaa.AddToBrightness((-30, 30)),
+                        iaa.MultiplyBrightness((0.5, 1.5)),
+                        iaa.AddToHueAndSaturation((-50, 50), per_channel=True),
+                        iaa.Grayscale(alpha=(0.0, 1.0)),
+                        iaa.ChangeColorTemperature((1100, 10000)),
+                        iaa.KMeansColorQuantization(),
+                    ]
+                ),
+            ),
+            # Add weather
+            iaa.Sometimes(
+                0.05,
+                iaa.OneOf(
+                    [
+                        iaa.Clouds(),
+                        iaa.Fog(),
+                        iaa.Snowflakes(flake_size=(0.1, 0.4), speed=(0.01, 0.05)),
+                        iaa.Rain(speed=(0.1, 0.3)),
+                    ]
+                ),
+            ),
+            # Add contrast
+            iaa.Sometimes(
+                0.05,
+                iaa.OneOf(
+                    [
+                        iaa.GammaContrast((0.5, 2.0)),
+                        iaa.GammaContrast((0.5, 2.0), per_channel=True),
+                        iaa.SigmoidContrast(gain=(3, 10), cutoff=(0.4, 0.6)),
+                        iaa.SigmoidContrast(
+                            gain=(3, 10), cutoff=(0.4, 0.6), per_channel=True
+                        ),
+                        iaa.LogContrast(gain=(0.6, 1.4)),
+                        iaa.LogContrast(gain=(0.6, 1.4), per_channel=True),
+                        iaa.LinearContrast((0.4, 1.6)),
+                        iaa.LinearContrast((0.4, 1.6), per_channel=True),
+                        iaa.AllChannelsCLAHE(),
+                        iaa.AllChannelsCLAHE(clip_limit=(1, 10)),
+                        iaa.AllChannelsCLAHE(clip_limit=(1, 10), per_channel=True),
+                        iaa.Alpha((0.0, 1.0), iaa.HistogramEqualization()),
+                        iaa.Alpha((0.0, 1.0), iaa.AllChannelsHistogramEqualization()),
+                        iaa.AllChannelsHistogramEqualization(),
+                    ]
+                ),
+            )
+        ]
+    ).augment_image
+
+
+class SmartDoc(Dataset):
+    """
+    Loads SmartDoc-style training data (images plus a gt.csv of labels) from one or more directories
+    """
+
+    def __init__(self, directory="data"):
+        super().__init__("smartdoc")
+        self.data = []
+        self.labels = []
+        for d in directory:
+            self.directory = d
+            self.train_transform = transforms.Compose(
+                [
+                    getTransformsByImgaug(),
+                    #     transforms.Resize([32, 32]),
+                    #    transforms.ColorJitter(1.5, 1.5, 0.9, 0.5),
+                    transforms.ToTensor(),
+                ]
+            )
+
+            self.test_transform = transforms.Compose(
+                [
+                    iaa.Sequential(
+                        [
+                            iaa.Resize(32),
+                        ]
+                    ).augment_image,
+                    transforms.ToTensor(),
+                ]
+            )
+
+            logger.info("Pass train/test data paths here")
+
+            self.classes_list = {}
+
+            file_names = []
+            print(self.directory, "gt.csv")
+            with open(os.path.join(self.directory, "gt.csv"), "r") as csvfile:
+                spamreader = csv.reader(
+                    csvfile, delimiter=",", quotechar="|", quoting=csv.QUOTE_MINIMAL
+                )
+                import ast
+
+                for row in spamreader:
+                    file_names.append(row[0])
+                    self.data.append(os.path.join(self.directory, row[0]))
+                    test = row[1].replace("array", "")
+                    self.labels.append((ast.literal_eval(test)))
+        self.labels = np.array(self.labels)
+
+        self.labels = np.reshape(self.labels, (-1, 8))
+        logger.debug("Ground Truth Shape: %s", str(self.labels.shape))
+        logger.debug("Data shape %s", str(len(self.data)))
+
+        self.myData = [self.data, self.labels]
+
+
+class SmartDocDirectories(Dataset):
+    """
+    Parses the original SmartDoc directory layout (per-video folders with .gt XML ground truth)
+    """
+
+    def __init__(self, directory="data"):
+        super().__init__("smartdoc")
+        self.data = []
+        self.labels = []
+
+        for folder in os.listdir(directory):
+            if os.path.isdir(directory + "/" + folder):
+                for file in os.listdir(directory + "/" + folder):
+                    images_dir = directory + "/" + folder + "/" + file
+                    if os.path.isdir(images_dir):
+
+                        list_gt = []
+                        tree = ET.parse(images_dir + "/" + file + ".gt")
+                        root = tree.getroot()
+                        for a in root.iter("frame"):
+                            list_gt.append(a)
+
+                        im_no = 0
+                        for image in os.listdir(images_dir):
+                            if image.endswith(".jpg"):
+                                # print(im_no)
+                                im_no += 1
+
+                                # Now we have opened the file and GT. Write code to create multiple files and scale gt
+                                list_of_points = {}
+
+                                # img = cv2.imread(images_dir + "/" + image)
+                                self.data.append(os.path.join(images_dir, image))
+
+                                for point in list_gt[int(float(image[0:-4])) - 1].iter(
+                                    "point"
+                                ):
+                                    myDict = point.attrib
+
+                                    list_of_points[myDict["name"]] = (
+                                        int(float(myDict["x"])),
+                                        int(float(myDict["y"])),
+                                    )
+
+                                ground_truth = np.asarray(
+                                    (
+                                        list_of_points["tl"],
+                                        list_of_points["tr"],
+                                        list_of_points["br"],
+                                        list_of_points["bl"],
+                                    )
+                                )
+                                ground_truth = utils.sort_gt(ground_truth)
+                                self.labels.append(ground_truth)
+
+        self.labels = np.array(self.labels)
+
+        self.labels = np.reshape(self.labels, (-1, 8))
+        logger.debug("Ground Truth Shape: %s", str(self.labels.shape))
+        logger.debug("Data shape %s", str(len(self.data)))
+
+        self.myData = []
+        for a in range(len(self.data)):
+            self.myData.append([self.data[a], self.labels[a]])
+
+
+class SelfCollectedDataset(Dataset):
+    """
+    Loads the self-collected dataset: images with per-image .csv files of corner coordinates
+    """
+
+    def __init__(self, directory="data"):
+        super().__init__("smartdoc")
+        self.data = []
+        self.labels = []
+
+        for image in os.listdir(directory):
+            # print (image)
+            if image.endswith("jpg") or image.endswith("JPG"):
+                if os.path.isfile(os.path.join(directory, image + ".csv")):
+                    with open(os.path.join(directory, image + ".csv"), "r") as csvfile:
+                        spamwriter = csv.reader(
+                            csvfile,
+                            delimiter=" ",
+                            quotechar="|",
+                            quoting=csv.QUOTE_MINIMAL,
+                        )
+
+                        img_path = os.path.join(directory, image)
+
+                        gt = []
+                        for row in spamwriter:
+                            gt.append(row)
+                        gt = np.array(gt).astype(np.float32)
+                        ground_truth = utils.sort_gt(gt)
+                        self.labels.append(ground_truth)
+                        self.data.append(img_path)
+
+        self.labels = np.array(self.labels)
+
+        self.labels = np.reshape(self.labels, (-1, 8))
+        logger.debug("Ground Truth Shape: %s", str(self.labels.shape))
+        logger.debug("Data shape %s", str(len(self.data)))
+
+        self.myData = []
+        for a in range(len(self.data)):
+            self.myData.append([self.data[a], self.labels[a]])
+
+
+class SmartDocCorner(Dataset):
+    """
+    Loads corner-patch training data (images plus a gt.csv with one corner location per image)
+    """
+
+    def __init__(self, directory="data"):
+        super().__init__("smartdoc")
+        self.data = []
+        self.labels = []
+        for d in directory:
+            self.directory = d
+            self.train_transform = transforms.Compose(
+                [
+                    getTransformsByImgaug(),
+                    transforms.ToTensor(),
+                ]
+            )
+
+            self.test_transform = transforms.Compose(
+                [
+                    iaa.Sequential(
+                        [
+                            iaa.Resize(32),
+                        ]
+                    ).augment_image,
+                    transforms.ToTensor(),
+                ]
+            )
+
+            logger.info("Pass train/test data paths here")
+
+            self.classes_list = {}
+
+            file_names = []
+            with open(os.path.join(self.directory, "gt.csv"), "r") as csvfile:
+                spamreader = csv.reader(
+                    csvfile, delimiter=",", quotechar="|", quoting=csv.QUOTE_MINIMAL
+                )
+                import ast
+
+                for row in spamreader:
+                    file_names.append(row[0])
+                    self.data.append(os.path.join(self.directory, row[0]))
+                    test = row[1].replace("array", "")
+                    self.labels.append((ast.literal_eval(test)))
+        self.labels = np.array(self.labels)
+
+        self.labels = np.reshape(self.labels, (-1, 2))
+        logger.debug("Ground Truth Shape: %s", str(self.labels.shape))
+        logger.debug("Data shape %s", str(len(self.data)))
+
+        self.myData = [self.data, self.labels]
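
For reference, the rows in gt.csv that SmartDoc and SmartDocCorner parse are the ones written by the generator scripts earlier in this commit; below is a minimal sketch of that round trip (the file name and coordinate values are illustrative, not taken from any real dataset):

    import ast

    # A row as written by document_data_generator.py: (file name, tuple of numpy corner arrays)
    row = ["00000001.jpg",
           "(array([0.12, 0.34]), array([0.56, 0.34]), array([0.56, 0.78]), array([0.12, 0.78]))"]
    # SmartDoc strips the "array" tokens, literal-evals the rest, and later reshapes to (-1, 8)
    label = ast.literal_eval(row[1].replace("array", ""))
    print(label)  # ([0.12, 0.34], [0.56, 0.34], [0.56, 0.78], [0.12, 0.78])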

+ 19 - 0
Recursive-CNNs/dataprocessor/datasetfactory.py

@@ -0,0 +1,19 @@
+''' Document Localization using Recursive CNN
+ Maintainer : Khurram Javed
+ Email : kjaved@ualberta.ca '''
+
+import dataprocessor.dataset as data
+import torchvision
+
+class DatasetFactory:
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def get_dataset(directory, type="document"):
+        if type=="document":
+            return data.SmartDoc(directory)
+        elif type =="corner":
+            return data.SmartDocCorner(directory)
+        elif type=="CIFAR":
+            return torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
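
A minimal usage sketch of the factory above. The directory name is an assumption borrowed from the launch configuration near the top of this commit; any folder produced by document_data_generator.py (images plus gt.csv) should work:

    import dataprocessor

    # SmartDoc expects a list of directories, each holding images and a gt.csv
    dataset = dataprocessor.DatasetFactory.get_dataset(
        ["dataset/smartdocData_DocTrainC"], type="document")
    print(len(dataset.myData[0]), "images,", dataset.myData[1].shape, "labels")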

+ 21 - 0
Recursive-CNNs/dataprocessor/loaderfactory.py

@@ -0,0 +1,21 @@
+''' Document Localization using Recursive CNN
+ Maintainer : Khurram Javed
+ Email : kjaved@ualberta.ca '''
+
+
+
+import dataprocessor.dataloaders as loader
+
+
+class LoaderFactory:
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def get_loader(type, data, transform=None, cuda=False):
+        if type=="hdd":
+            return loader.HddLoader(data, transform=transform,
+                                    cuda=cuda)
+        elif type =="ram":
+            return loader.RamLoader(data, transform=transform,
+                                    cuda=cuda)
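
Putting the two factories together, here is a minimal sketch of how the training-data pipeline appears to be wired; the directory path and batch size are assumptions taken from the launch configuration, not a definitive invocation:

    import torch.utils.data as td

    import dataprocessor

    dataset = dataprocessor.DatasetFactory.get_dataset(
        ["dataset/smartdocData_DocTrainC"], type="document")  # assumed path
    # "ram" decodes and transforms every image up front; "hdd" reads lazily in __getitem__
    train_set = dataprocessor.LoaderFactory.get_loader(
        "ram", dataset.myData, transform=dataset.train_transform, cuda=False)
    train_loader = td.DataLoader(train_set, batch_size=16, shuffle=True)
    images, targets = next(iter(train_loader))
    print(images.shape, targets.shape)  # roughly [16, 3, 32, 32] and [16, 8]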

+ 65 - 0
Recursive-CNNs/demo.py

@@ -0,0 +1,65 @@
+''' Document Localization using Recursive CNN
+ Maintainer : Khurram Javed
+ Email : kjaved@ualberta.ca '''
+
+import cv2
+import numpy as np
+import glob
+import evaluation
+import os
+import shutil
+import time
+
+
+def args_processor():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--images", default="example_imgs", help="Document image folder")
+    parser.add_argument('--model-type', default="resnet",
+                    help='model type to be used. Example : resnet32, resnet20, densenet, test')
+    parser.add_argument("-o", "--output", default="example_imgs/output", help="The folder to store results")
+    parser.add_argument("-rf", "--retainFactor", help="Floating point in range (0,1) specifying retain factor",
+                        default="0.85", type=float)
+    parser.add_argument("-cm", "--cornerModel", help="Model for corner point refinement",
+                        default="../cornerModelWell")
+    parser.add_argument("-dm", "--documentModel", help="Model for document corners detection",
+                        default="../documentModelWell")
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = args_processor()
+
+    corners_extractor = evaluation.corner_extractor.GetCorners(args.documentModel, args.model_type)
+    corner_refiner = evaluation.corner_refiner.corner_finder(args.cornerModel, args.model_type)
+    now_date = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime(time.time()))
+    output_dir = f"{args.output}_{now_date}"
+    shutil.rmtree(output_dir, ignore_errors=True)
+    os.makedirs(output_dir)
+    imgPaths = glob.glob(f"{args.images}/*.jpg")
+    for imgPath in imgPaths:
+        img = cv2.imread(imgPath)
+        oImg = img
+        e1 = cv2.getTickCount()
+        extracted_corners = corners_extractor.get(oImg)
+        corner_address = []
+        # Refine the detected corners using corner refiner
+        image_name = 0
+        for corner in extracted_corners:
+            image_name += 1
+            corner_img = corner[0]
+            refined_corner = np.array(corner_refiner.get_location(corner_img, args.retainFactor))
+
+            # Converting from local co-ordinate to global co-ordinates of the image
+            refined_corner[0] += corner[1]
+            refined_corner[1] += corner[2]
+
+            # Final results
+            corner_address.append(refined_corner)
+        e2 = cv2.getTickCount()
+        print(f"Took time:{(e2 - e1)/ cv2.getTickFrequency()}")
+
+        for a in range(0, len(extracted_corners)):
+            cv2.line(oImg, tuple(corner_address[a % 4]), tuple(corner_address[(a + 1) % 4]), (255, 0, 0), 4)
+        filename = os.path.basename(imgPath)
+        cv2.imwrite(f"{output_dir}/{filename}", oImg)

+ 87 - 0
Recursive-CNNs/document_data_generator.py

@@ -0,0 +1,87 @@
+import os
+from tqdm import tqdm
+
+import cv2
+import numpy as np
+import utils
+import dataprocessor
+
+import argparse
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+
+def args_processor():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--input-dir", help="Path to data files (Extract images using video_to_image.py first")
+    parser.add_argument("-o", "--output-dir", help="Directory to store results")
+    parser.add_argument("-v", "--visualize", help="Draw the point on the corner", default=False, type=bool)
+    parser.add_argument("-a", "--augment", type=str2bool, nargs='?',
+                        const=True, default=True,
+                        help="Augment image dataset")
+    parser.add_argument("--dataset", default="smartdoc", help="'smartdoc' or 'selfcollected' dataset")
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+        args = args_processor()
+        input_directory = args.input_dir
+        if not os.path.isdir(args.output_dir):
+            os.mkdir(args.output_dir)
+        import csv
+
+
+        # Dataset iterator
+        if args.dataset == "smartdoc":
+            dataset_test = dataprocessor.dataset.SmartDocDirectories(input_directory)
+        elif args.dataset == "selfcollected":
+            dataset_test = dataprocessor.dataset.SelfCollectedDataset(input_directory)
+        else:
+            print("Incorrect dataset type; please choose between smartdoc or selfcollected")
+            assert (False)
+        with open(os.path.join(args.output_dir, 'gt.csv'), 'a') as csvfile:
+            spamwriter = csv.writer(csvfile, delimiter=',',
+                                    quotechar='|', quoting=csv.QUOTE_MINIMAL)
+            # Counter for file naming
+            counter = 0
+            for data_elem in tqdm(dataset_test.myData):
+
+                img_path = data_elem[0]
+                target = data_elem[1].reshape((4, 2))
+                img = cv2.imread(img_path)
+
+                if args.dataset == "selfcollected":
+                    target = target / (img.shape[1], img.shape[0])
+                    target = target * (1920, 1920)
+                    img = cv2.resize(img, (1920, 1920))
+
+                corner_cords = target
+                angles = [0, 271, 90] if args.augment else [0]
+                random_crops = [0, 16] if args.augment else [0]
+                for angle in angles:
+                    img_rotate, gt_rotate = utils.utils.rotate(img, corner_cords, angle)
+                    for random_crop in random_crops:
+                        counter += 1
+                        f_name = str(counter).zfill(8)
+
+                        img_crop, gt_crop = utils.utils.random_crop(img_rotate, gt_rotate)
+                        mah_size = img_crop.shape
+                        img_crop = cv2.resize(img_crop, (64, 64))
+                        gt_crop = np.array(gt_crop)
+
+                        if (args.visualize):
+                            no=0
+                            for a in range(0,4):
+                                no+=1
+                                cv2.circle(img_crop, tuple(((gt_crop[a]*64).astype(int))), 2,(255-no*60,no*60,0),9)
+                        # # cv2.imwrite("asda.jpg", img)
+
+                        cv2.imwrite(os.path.join(args.output_dir, f_name+".jpg"), img_crop)
+                        spamwriter.writerow((f_name+".jpg", tuple(list(gt_crop))))

+ 69 - 0
Recursive-CNNs/evaluate.py

@@ -0,0 +1,69 @@
+''' Document Localization using Recursive CNN
+ Maintainer : Khurram Javed
+ Email : kjaved@ualberta.ca '''
+
+import argparse
+import time
+
+import numpy as np
+import torch
+from PIL import Image
+
+import dataprocessor
+import evaluation
+
+from utils import utils
+
+parser = argparse.ArgumentParser(description='iCarl2.0')
+
+parser.add_argument("-i", "--data-dir", default="/Users/khurramjaved96/bg5",
+                    help="input Directory of test data")
+
+args = parser.parse_args()
+args.cuda = torch.cuda.is_available()
+if __name__ == '__main__':
+    corners_extractor = evaluation.corner_extractor.GetCorners("../documentModelWell")
+    corner_refiner = evaluation.corner_refiner.corner_finder("../cornerModelWell")
+    test_set_dir = args.data_dir
+    iou_results = []
+    my_results = []
+    dataset_test = dataprocessor.dataset.SmartDocDirectories(test_set_dir)
+    for data_elem in dataset_test.myData:
+
+        img_path = data_elem[0]
+        # print(img_path)
+        target = data_elem[1].reshape((4, 2))
+        img_array = np.array(Image.open(img_path))
+        computation_start_time = time.perf_counter()
+        extracted_corners = corners_extractor.get(img_array)
+        temp_time = time.perf_counter()
+        corner_address = []
+        # Refine the detected corners using corner refiner
+        counter=0
+        for corner in extracted_corners:
+            counter+=1
+            corner_img = corner[0]
+            refined_corner = np.array(corner_refiner.get_location(corner_img, 0.85))
+
+            # Converting from local co-ordinate to global co-ordinate of the image
+            refined_corner[0] += corner[1]
+            refined_corner[1] += corner[2]
+
+            # Final results
+            corner_address.append(refined_corner)
+        computation_end_time = time.perf_counter()
+        print("TOTAL TIME : ", computation_end_time - computation_start_time)
+        r2 = utils.intersection_with_corection_smart_doc_implementation(target, np.array(corner_address), img_array)
+        r3 = utils.intersection_with_corection(target, np.array(corner_address), img_array)
+
+        if r3 - r2 > 0.1:
+            print ("Image Name", img_path)
+            print ("Prediction", np.array(corner_address), target)
+            0/0
+        assert (r2 > 0 and r2 < 1)
+        iou_results.append(r2)
+        my_results.append(r3)
+        print("MEAN CORRECTED JI: ", np.mean(np.array(iou_results)))
+        print("MEAN CORRECTED MY: ", np.mean(np.array(my_results)))
+
+    print(np.mean(np.array(iou_results)))

+ 6 - 0
Recursive-CNNs/evaluation/__init__.py

@@ -0,0 +1,6 @@
+''' Document Localization using Recursive CNN
+ Maintainer : Khurram Javed
+ Email : kjaved@ualberta.ca '''
+
+import evaluation.corner_extractor
+import evaluation.corner_refiner

+ 72 - 0
Recursive-CNNs/evaluation/corner_extractor.py

@@ -0,0 +1,72 @@
+''' Document Localization using Recursive CNN
+ Maintainer : Khurram Javed
+ Email : kjaved@ualberta.ca '''
+
+import numpy as np
+import torch
+from PIL import Image
+from torchvision import transforms
+
+import model
+
+
+class GetCorners:
+    def __init__(self, checkpoint_dir, model_type = "resnet"):
+        self.model = model.ModelFactory.get_model(model_type, 'document')
+        self.model.load_state_dict(torch.load(checkpoint_dir, map_location='cpu'))
+        if torch.cuda.is_available():
+            self.model.cuda()
+        self.model.eval()
+
+    def get(self, pil_image):
+        with torch.no_grad():
+            image_array = np.copy(pil_image)
+            pil_image = Image.fromarray(pil_image)
+            test_transform = transforms.Compose([transforms.Resize([32, 32]),
+                                                 transforms.ToTensor()])
+            img_temp = test_transform(pil_image)
+
+            img_temp = img_temp.unsqueeze(0)
+            if torch.cuda.is_available():
+                img_temp = img_temp.cuda()
+
+            model_prediction = self.model(img_temp).cpu().data.numpy()[0]
+
+            model_prediction = np.array(model_prediction)
+
+            x_cords = model_prediction[[0, 2, 4, 6]]
+            y_cords = model_prediction[[1, 3, 5, 7]]
+
+            x_cords = x_cords * image_array.shape[1]
+            y_cords = y_cords * image_array.shape[0]
+
+            # Extract the four corners of the image. Read "Region Extractor" in Section III of the paper for an explanation.
+
+            top_left = image_array[
+                       max(0, int(2 * y_cords[0] - (y_cords[3] + y_cords[0]) / 2)):int((y_cords[3] + y_cords[0]) / 2),
+                       max(0, int(2 * x_cords[0] - (x_cords[1] + x_cords[0]) / 2)):int((x_cords[1] + x_cords[0]) / 2)]
+
+            top_right = image_array[
+                        max(0, int(2 * y_cords[1] - (y_cords[1] + y_cords[2]) / 2)):int((y_cords[1] + y_cords[2]) / 2),
+                        int((x_cords[1] + x_cords[0]) / 2):min(image_array.shape[1] - 1,
+                                                               int(x_cords[1] + (x_cords[1] - x_cords[0]) / 2))]
+
+            bottom_right = image_array[int((y_cords[1] + y_cords[2]) / 2):min(image_array.shape[0] - 1, int(
+                y_cords[2] + (y_cords[2] - y_cords[1]) / 2)),
+                           int((x_cords[2] + x_cords[3]) / 2):min(image_array.shape[1] - 1,
+                                                                  int(x_cords[2] + (x_cords[2] - x_cords[3]) / 2))]
+
+            bottom_left = image_array[int((y_cords[0] + y_cords[3]) / 2):min(image_array.shape[0] - 1, int(
+                y_cords[3] + (y_cords[3] - y_cords[0]) / 2)),
+                          max(0, int(2 * x_cords[3] - (x_cords[2] + x_cords[3]) / 2)):int(
+                              (x_cords[3] + x_cords[2]) / 2)]
+
+            top_left = (top_left, max(0, int(2 * x_cords[0] - (x_cords[1] + x_cords[0]) / 2)),
+                        max(0, int(2 * y_cords[0] - (y_cords[3] + y_cords[0]) / 2)))
+            top_right = (
+            top_right, int((x_cords[1] + x_cords[0]) / 2), max(0, int(2 * y_cords[1] - (y_cords[1] + y_cords[2]) / 2)))
+            bottom_right = (bottom_right, int((x_cords[2] + x_cords[3]) / 2), int((y_cords[1] + y_cords[2]) / 2))
+            bottom_left = (bottom_left, max(0, int(2 * x_cords[3] - (x_cords[2] + x_cords[3]) / 2)),
+                           int((y_cords[0] + y_cords[3]) / 2))
+
+            return top_left, top_right, bottom_right, bottom_left

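The crop bounds computed in get() above simplify to a window centred on each predicted corner that extends halfway toward the neighbouring corners before being clamped to the image border. A small numeric illustration with assumed coordinates (not taken from the repository):

    # Assumed predictions, in pixels, for the top edge of a document:
    x0, x1 = 100.0, 500.0                        # top-left and top-right x coordinates
    left = max(0, int(2 * x0 - (x1 + x0) / 2))   # = x0 - (x1 - x0) / 2, clamped at 0  -> 0
    right = int((x1 + x0) / 2)                   # midpoint between the two corners     -> 300
    print(left, right)                           # the top-left crop spans [0, 300) horizontally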
+ 84 - 0
Recursive-CNNs/evaluation/corner_refiner.py

@@ -0,0 +1,84 @@
+''' Document Localization using Recursive CNN
+ Maintainer : Khurram Javed
+ Email : kjaved@ualberta.ca '''
+
+import numpy as np
+import torch
+from PIL import Image
+from torchvision import transforms
+
+import model
+
+
+class corner_finder():
+    def __init__(self, CHECKPOINT_DIR, model_type = "resnet"):
+
+        self.model = model.ModelFactory.get_model(model_type, "corner")
+        self.model.load_state_dict(torch.load(CHECKPOINT_DIR, map_location='cpu'))
+        if torch.cuda.is_available():
+            self.model.cuda()
+        self.model.eval()
+
+    def get_location(self, img, retainFactor=0.85):
+        with torch.no_grad():
+            ans_x = 0.0
+            ans_y = 0.0
+
+            o_img = np.copy(img)
+
+            y = [0, 0]
+            x_start = 0
+            y_start = 0
+            up_scale_factor = (img.shape[1], img.shape[0])
+
+            myImage = np.copy(o_img)
+
+            test_transform = transforms.Compose([transforms.Resize([32, 32]),
+                                                 transforms.ToTensor()])
+
+            CROP_FRAC = retainFactor
+            while (myImage.shape[0] > 10 and myImage.shape[1] > 10):
+
+                img_temp = Image.fromarray(myImage)
+                img_temp = test_transform(img_temp)
+                img_temp = img_temp.unsqueeze(0)
+
+                if torch.cuda.is_available():
+                    img_temp = img_temp.cuda()
+                response = self.model(img_temp).cpu().data.numpy()
+                response = response[0]
+
+                response_up = response
+
+                response_up = response_up * up_scale_factor
+                y = response_up + (x_start, y_start)
+                x_loc = int(y[0])
+                y_loc = int(y[1])
+
+                if x_loc > myImage.shape[1] / 2:
+                    start_x = min(x_loc + int(round(myImage.shape[1] * CROP_FRAC / 2)), myImage.shape[1]) - int(round(
+                        myImage.shape[1] * CROP_FRAC))
+                else:
+                    start_x = max(x_loc - int(myImage.shape[1] * CROP_FRAC / 2), 0)
+                if y_loc > myImage.shape[0] / 2:
+                    start_y = min(y_loc + int(myImage.shape[0] * CROP_FRAC / 2), myImage.shape[0]) - int(
+                        myImage.shape[0] * CROP_FRAC)
+                else:
+                    start_y = max(y_loc - int(myImage.shape[0] * CROP_FRAC / 2), 0)
+
+                ans_x += start_x
+                ans_y += start_y
+
+                myImage = myImage[start_y:start_y + int(myImage.shape[0] * CROP_FRAC),
+                          start_x:start_x + int(myImage.shape[1] * CROP_FRAC)]
+                img = img[start_y:start_y + int(img.shape[0] * CROP_FRAC),
+                      start_x:start_x + int(img.shape[1] * CROP_FRAC)]
+                up_scale_factor = (img.shape[1], img.shape[0])
+
+            ans_x += y[0]
+            ans_y += y[1]
+            return (int(round(ans_x)), int(round(ans_y)))
+
+
+if __name__ == "__main__":
+    pass

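Each pass of the loop above predicts the corner inside the current crop, keeps a retainFactor-sized window around that prediction, and accumulates the window origins in ans_x/ans_y; the loop stops once the window is 10 px or smaller. A rough count of how many shrink steps that takes, assuming a 640-px crop and the default retainFactor of 0.85:

    import math

    width, frac = 640, 0.85
    steps = math.ceil(math.log(10 / width) / math.log(frac))
    print(steps)  # about 26 shrink steps before the window drops below 10 px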
+ 1 - 0
Recursive-CNNs/experiment/__init__.py

@@ -0,0 +1 @@
+from experiment.experiment import *

Binary
Recursive-CNNs/experiment/__pycache__/__init__.cpython-38.pyc


Binary
Recursive-CNNs/experiment/__pycache__/experiment.cpython-38.pyc


+ 53 - 0
Recursive-CNNs/experiment/experiment.py

@@ -0,0 +1,53 @@
+''' Incremental-Classifier Learning 
+ Authors : Khurram Javed, Muhammad Talha Paracha
+ Maintainer : Khurram Javed
+ Lab : TUKL-SEECS R&D Lab
+ Email : 14besekjaved@seecs.edu.pk '''
+
+import json
+import os
+import subprocess
+
+
+class experiment:
+    '''
+    Class to store results of any experiment 
+    '''
+
+    def __init__(self, name, args, output_dir="../"):
+        self.gitHash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode("utf-8")
+        print(self.gitHash)
+        if args is not None:
+            self.name = name
+            self.params = vars(args)
+            self.results = {}
+            self.dir = output_dir
+
+            import datetime
+            now = datetime.datetime.now()
+            rootFolder = str(now.day) + str(now.month) + str(now.year)
+            if not os.path.exists(output_dir + rootFolder):
+                os.makedirs(output_dir + rootFolder)
+            self.name = rootFolder + "/" + self.name
+            ver = 0
+
+            while os.path.exists(output_dir + self.name + "_" + str(ver)):
+                ver += 1
+
+            os.makedirs(output_dir + self.name + "_" + str(ver))
+            self.path = output_dir + self.name + "_" + str(ver) + "/" + name
+
+            self.results["Temp Results"] = [[1, 2, 3, 4], [5, 6, 2, 6]]
+
+    def store_json(self):
+        with open(self.path + "JSONDump.txt", 'w') as outfile:
+            json.dump(json.dumps(self.__dict__), outfile)
+
+
+import argparse
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='iCarl2.0')
+    args = parser.parse_args()
+    e = experiment("TestExperiment", args)
+    e.store_json()

+ 61 - 0
Recursive-CNNs/mobile_model_converter.py

@@ -0,0 +1,61 @@
+import argparse
+import shutil
+from pathlib import Path
+import os
+import torch
+from tinynn.converter import TFLiteConverter
+import model
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-cm", "--cornerModel", help="Model for corner point refinement",
+                    default="../cornerModelWell")
+parser.add_argument("-dm", "--documentModel", help="Model for document corners detection",
+                    default="../documentModelWell")
+
+def load_doc_model(checkpoint_dir, dataset):
+    _model = model.ModelFactory.get_model("resnet", dataset)
+    _model.load_state_dict(torch.load(checkpoint_dir, map_location="cpu"))
+    return _model
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    models = [
+        {
+            "name": "corner_model",
+            "model": load_doc_model(
+                args.cornerModel,
+                "corner",
+            ),
+        },
+        {
+            "name": "doc_model",
+            "model": load_doc_model(
+                args.documentModel,
+                "document",
+            ),
+        },
+    ]
+
+    out_dir = "output_tflite"
+    shutil.rmtree(out_dir, ignore_errors=True)
+    os.mkdir(out_dir)
+    for item in models:
+        _model = item["model"]
+        _model.eval()
+
+        dummy_input = torch.rand((1, 3, 32, 32))
+        modelPath = f'{out_dir}/{item["name"]}.tflite'
+        converter = TFLiteConverter(_model, dummy_input, modelPath)
+        converter.convert()
+        # scripted = torch.jit.script(_model)
+
+        # optimized_model = optimize_for_mobile(scripted, backend='metal')
+        # print(torch.jit.export_opnames(optimized_model))
+        # optimized_model._save_for_lite_interpreter(f'{output}/{item["name"]}_metal.ptl')
+        
+        # scripted_model = torch.jit.script(_model)
+        # optimized_model = optimize_for_mobile(scripted_model, backend='metal')
+        # print(torch.jit.export_opnames(optimized_model))
+        # optimized_model._save_for_lite_interpreter(f'{output}/{item["name"]}_metal.pt')
+
+        # torch.save(_model, f'{output}/{item["name"]}.pth')

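A converted .tflite file can be sanity-checked by running it on a dummy input and comparing against the PyTorch model. A sketch assuming TensorFlow is installed (it is not listed in this commit's requirements.txt); the input shape is read from the converted model so the layout chosen by the converter does not matter:

    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path="output_tflite/doc_model.tflite")
    interpreter.allocate_tensors()
    inp = interpreter.get_input_details()[0]
    out = interpreter.get_output_details()[0]
    dummy = np.random.rand(*inp["shape"]).astype(np.float32)  # shape reported by the converted model
    interpreter.set_tensor(inp["index"], dummy)
    interpreter.invoke()
    print(interpreter.get_tensor(out["index"]).shape)         # expected (1, 8) for the document model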
+ 1 - 0
Recursive-CNNs/model/__init__.py

@@ -0,0 +1 @@
+from model.modelfactory import *

Binary
Recursive-CNNs/model/__pycache__/__init__.cpython-38.pyc


Binary
Recursive-CNNs/model/__pycache__/cornerModel.cpython-38.pyc


Binary
Recursive-CNNs/model/__pycache__/modelfactory.cpython-38.pyc


Binary
Recursive-CNNs/model/__pycache__/res_utils.cpython-38.pyc


Binary
Recursive-CNNs/model/__pycache__/resnet32.cpython-38.pyc


+ 78 - 0
Recursive-CNNs/model/cornerModel.py

@@ -0,0 +1,78 @@
+# Reference : Taken from https://github.com/kuangliu/pytorch-cifar
+
+# License
+# MIT License
+#
+# Copyright (c) 2017 liukuang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+'''MobileNet in PyTorch.
+See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
+for more details.
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Block(nn.Module):
+    '''Depthwise conv + Pointwise conv'''
+    def __init__(self, in_planes, out_planes, stride=1):
+        super(Block, self).__init__()
+        self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False)
+        self.bn1 = nn.BatchNorm2d(in_planes)
+        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
+        self.bn2 = nn.BatchNorm2d(out_planes)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = F.relu(self.bn2(self.conv2(out)))
+        return out
+
+
+class MobileNet(nn.Module):
+    # (128,2) means conv planes=128, conv stride=2, by default conv stride=1
+    cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024]
+
+    def __init__(self, num_classes=10):
+        super(MobileNet, self).__init__()
+        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(32)
+        self.layers = self._make_layers(in_planes=32)
+        self.linear = nn.Linear(1024, num_classes)
+
+    def _make_layers(self, in_planes):
+        layers = []
+        for x in self.cfg:
+            out_planes = x if isinstance(x, int) else x[0]
+            stride = 1 if isinstance(x, int) else x[1]
+            layers.append(Block(in_planes, out_planes, stride))
+            in_planes = out_planes
+        return nn.Sequential(*layers)
+
+    def forward(self, x, pretrain=False):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.layers(out)
+        out = F.avg_pool2d(out, 2)
+        out = out.view(out.size(0), -1)
+        out = self.linear(out)
+        return out
+

+ 37 - 0
Recursive-CNNs/model/modelfactory.py

@@ -0,0 +1,37 @@
+''' Document Localization using Recursive CNN
+ Maintainer : Khurram Javed
+ Email : kjaved@ualberta.ca '''
+
+import model.resnet32 as resnet
+import model.cornerModel as tm
+import torchvision.models as models
+
+class ModelFactory():
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def get_model(model_type, dataset):
+        if model_type == "resnet":
+            if dataset == 'document':
+                return resnet.resnet20(8)
+            elif dataset == 'corner':
+                return resnet.resnet20(2)
+        if model_type == "resnet8":
+            if dataset == 'document':
+                return resnet.resnet8(8)
+            elif dataset == 'corner':
+                return resnet.resnet8(2)
+        elif model_type == 'shallow':
+            if dataset == 'document':
+                return tm.MobileNet(8)
+            elif dataset == 'corner':
+                return tm.MobileNet(2)
+        elif model_type =="squeeze":
+            if dataset == 'document':
+                return models.squeezenet1_1(True)
+            elif dataset == 'corner':
+                return models.squeezenet1_1(True)
+        else:
+            print("Unsupported model; either implement the model in model/ModelFactory or choose a different model")
+            assert (False)

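The two heads produced by the factory differ only in output width: eight values (four x, y corner pairs) for "document" and two values (one x, y point) for "corner". A quick shape check, assuming the 3x32x32 input size used throughout the repository:

    import torch
    import model

    doc_net = model.ModelFactory.get_model("resnet", "document")
    corner_net = model.ModelFactory.get_model("resnet", "corner")
    x = torch.rand(1, 3, 32, 32)
    print(doc_net(x).shape, corner_net(x).shape)  # torch.Size([1, 8]) torch.Size([1, 2])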
+ 37 - 0
Recursive-CNNs/model/res_utils.py

@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+
+
+class DownsampleA(nn.Module):
+    def __init__(self, nIn, nOut, stride):
+        super(DownsampleA, self).__init__()
+        assert stride == 2
+        self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)
+
+    def forward(self, x):
+        x = self.avg(x)
+        return torch.cat((x, x.mul(0)), 1)
+
+
+class DownsampleC(nn.Module):
+    def __init__(self, nIn, nOut, stride):
+        super(DownsampleC, self).__init__()
+        assert stride != 1 or nIn != nOut
+        self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
+class DownsampleD(nn.Module):
+    def __init__(self, nIn, nOut, stride):
+        super(DownsampleD, self).__init__()
+        assert stride == 2
+        self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False)
+        self.bn = nn.BatchNorm2d(nOut)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        return x

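DownsampleA is the parameter-free shortcut: it halves the spatial resolution with a stride-2 pool and doubles the channel count by concatenating a zero tensor. A quick shape check (illustration only):

    import torch
    from model.res_utils import DownsampleA

    x = torch.rand(1, 16, 32, 32)
    print(DownsampleA(16, 32, 2)(x).shape)  # torch.Size([1, 32, 16, 16])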
+ 184 - 0
Recursive-CNNs/model/resnet32.py

@@ -0,0 +1,184 @@
+# This is someone else's implementation of resnet optimized for CIFAR; I can't seem to find the repository again to reference the work.
+# I will keep on looking.
+import math
+
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+
+from .res_utils import DownsampleA
+
+
+class ResNetBasicblock(nn.Module):
+    expansion = 1
+    """
+    ResNet basicblock (https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua)
+    """
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(ResNetBasicblock, self).__init__()
+
+        self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn_a = nn.BatchNorm2d(planes)
+
+        self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn_b = nn.BatchNorm2d(planes)
+
+        self.downsample = downsample
+        self.featureSize = 64
+
+    def forward(self, x):
+        residual = x
+
+        basicblock = self.conv_a(x)
+        basicblock = self.bn_a(basicblock)
+        basicblock = F.relu(basicblock, inplace=True)
+
+        basicblock = self.conv_b(basicblock)
+        basicblock = self.bn_b(basicblock)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        return F.relu(residual + basicblock, inplace=True)
+
+
+class CifarResNet(nn.Module):
+    """
+    ResNet optimized for the Cifar Dataset, as specified in
+    https://arxiv.org/abs/1512.03385.pdf
+    """
+
+    def __init__(self, block, depth, num_classes, channels=3):
+        """ Constructor
+        Args:
+          depth: number of layers.
+          num_classes: number of classes
+          channels: number of input channels
+        """
+        super(CifarResNet, self).__init__()
+
+        self.featureSize = 64
+        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
+        assert (depth - 2) % 6 == 0, 'depth should be one of 8, 20, 32, 44, 56, 110'
+        layer_blocks = (depth - 2) // 6
+
+        self.num_classes = num_classes
+
+        self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn_1 = nn.BatchNorm2d(16)
+
+        self.inplanes = 16
+        self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
+        self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
+        self.stage_3 = self._make_layer(block, 64, layer_blocks, 2)
+        self.avgpool = nn.AvgPool2d(8)
+        self.fc = nn.Linear(64 * block.expansion, num_classes)
+        self.fc2 = nn.Linear(64 * block.expansion, 100)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+                # m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                init.kaiming_normal_(m.weight)
+                m.bias.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x, pretrain:bool=False):
+
+        x = self.conv_1_3x3(x)
+        x = F.relu(self.bn_1(x), inplace=True)
+        x = self.stage_1(x)
+        x = self.stage_2(x)
+        x = self.stage_3(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        if pretrain:
+            return self.fc2(x)
+        x = self.fc(x)
+        return x
+
+
+def resnet20(num_classes=10):
+    """Constructs a ResNet-20 model for CIFAR-10 (by default)
+    Args:
+      num_classes (uint): number of classes
+    """
+    model = CifarResNet(ResNetBasicblock, 20, num_classes)
+    return model
+
+
+def resnet8(num_classes=10):
+    """Constructs a ResNet-20 model for CIFAR-10 (by default)
+    Args:
+      num_classes (uint): number of classes
+    """
+    model = CifarResNet(ResNetBasicblock, 8, num_classes, 3)
+    return model
+
+
+def resnet20mnist(num_classes=10):
+    """Constructs a ResNet-20 model for CIFAR-10 (by default)
+    Args:
+      num_classes (uint): number of classes
+    """
+    model = CifarResNet(ResNetBasicblock, 20, num_classes, 1)
+    return model
+
+
+def resnet32mnist(num_classes=10, channels=1):
+    model = CifarResNet(ResNetBasicblock, 32, num_classes, channels)
+    return model
+
+
+def resnet32(num_classes=10):
+    """Constructs a ResNet-32 model for CIFAR-10 (by default)
+    Args:
+      num_classes (uint): number of classes
+    """
+    model = CifarResNet(ResNetBasicblock, 32, num_classes)
+    return model
+
+
+def resnet44(num_classes=10):
+    """Constructs a ResNet-44 model for CIFAR-10 (by default)
+    Args:
+      num_classes (uint): number of classes
+    """
+    model = CifarResNet(ResNetBasicblock, 44, num_classes)
+    return model
+
+
+def resnet56(num_classes=10):
+    """Constructs a ResNet-56 model for CIFAR-10 (by default)
+    Args:
+      num_classes (uint): number of classes
+    """
+    model = CifarResNet(ResNetBasicblock, 56, num_classes)
+    return model
+
+
+def resnet110(num_classes=10):
+    """Constructs a ResNet-110 model for CIFAR-10 (by default)
+    Args:
+      num_classes (uint): number of classes
+    """
+    model = CifarResNet(ResNetBasicblock, 110, num_classes)
+    return model

+ 1 - 0
Recursive-CNNs/plotter/__init__.py

@@ -0,0 +1 @@
+from plotter.plotter import *

+ 88 - 0
Recursive-CNNs/plotter/plotter.py

@@ -0,0 +1,88 @@
+''' Incremental-Classifier Learning 
+ Authors : Khurram Javed, Muhammad Talha Paracha
+ Maintainer : Khurram Javed
+ Lab : TUKL-SEECS R&D Lab
+ Email : 14besekjaved@seecs.edu.pk '''
+
+import matplotlib
+import matplotlib.pyplot as plt
+
+plt.switch_backend('agg')
+
+MEDIUM_SIZE = 18
+
+font = {'family': 'sans-serif',
+        'weight': 'bold'}
+
+matplotlib.rc('xtick', labelsize=MEDIUM_SIZE)
+matplotlib.rc('ytick', labelsize=MEDIUM_SIZE)
+plt.rc('axes', labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels
+
+# matplotlib.rc('font', **font)
+from matplotlib import rcParams
+
+rcParams.update({'figure.autolayout': True})
+
+
+class Plotter():
+    def __init__(self):
+        import itertools
+        # plt.figure(figsize=(12, 9))
+        self.marker = itertools.cycle(('o', '+', "v", "^", "8", '.', '*'))
+        self.handles = []
+        self.lines = itertools.cycle(('--', '-.', '-', ':'))
+
+    def plot(self, x, y, xLabel="Number of Classes", yLabel="Accuracy %", legend="none", title=None, error=None):
+        self.x = x
+        self.y = y
+        plt.grid(color='0.89', linestyle='--', linewidth=1.0)
+        if error is None:
+            l, = plt.plot(x, y, linestyle=next(self.lines), marker=next(self.marker), label=legend, linewidth=3.0)
+        else:
+            l = plt.errorbar(x, y, yerr=error, capsize=4.0, capthick=2.0, linestyle=next(self.lines),
+                             marker=next(self.marker), label=legend, linewidth=3.0)
+
+        self.handles.append(l)
+        self.x_label = xLabel
+        self.y_label = yLabel
+        if title is not None:
+            plt.title(title)
+
+    def save_fig(self, path, xticks=105, title=None, yStart=0, xRange=0, yRange=10):
+        if title is not None:
+            plt.title(title)
+        plt.legend(handles=self.handles)
+        plt.ylim((yStart, 100 + 0.2))
+        plt.xlim((0, xticks + .2))
+        plt.ylabel(self.y_label)
+        plt.xlabel(self.x_label)
+        plt.yticks(list(range(yStart, 101, yRange)))
+        print(list(range(yStart, 105, yRange)))
+        plt.xticks(list(range(0, xticks + 1, xRange + int(xticks / 10))))
+        plt.savefig(path + ".eps", format='eps')
+        plt.gcf().clear()
+
+    def save_fig2(self, path, xticks=105):
+        plt.legend(handles=self.handles)
+        plt.xlabel("Memory Budget")
+        plt.ylabel("Average Incremental Accuracy")
+        plt.savefig(path + ".jpg")
+        plt.gcf().clear()
+
+    def plotMatrix(self, epoch, path, img):
+
+        plt.imshow(img, cmap='plasma', interpolation='nearest')
+        plt.colorbar()
+        plt.savefig(path + str(epoch) + ".svg", format='svg')
+        plt.gcf().clear()
+
+    def saveImage(self, img, path, epoch):
+        from PIL import Image
+        im = Image.fromarray(img)
+        im.save(path + str(epoch) + ".jpg")
+
+
+if __name__ == "__main__":
+    pl = Plotter()
+    pl.plot([1, 2, 3, 4], [2, 3, 6, 2])
+    pl.save_fig("test.jpg")

+ 54 - 0
Recursive-CNNs/requirements.txt

@@ -0,0 +1,54 @@
+anyio==3.4.0
+black==21.11b1
+certifi==2021.10.8
+charset-normalizer==2.0.7
+click==8.0.3
+cycler==0.11.0
+fonttools==4.28.2
+h11==0.12.0
+httpcore==0.14.3
+httpx==0.21.1
+idna==3.3
+imageio==2.13.1
+imgaug==0.4.0
+jsonpatch==1.32
+jsonpointer==2.2
+kiwisolver==1.3.2
+matplotlib==3.5.0
+mypy-extensions==0.4.3
+networkx==2.6.3
+numpy==1.21.4
+opencv-python==4.5.4.60
+packaging==21.3
+pathspec==0.9.0
+Pillow==8.4.0
+platformdirs==2.4.0
+Polygon3==3.0.9.1
+pyparsing==3.0.6
+PySocks==1.7.1
+python-dateutil==2.8.2
+PyWavelets==1.2.0
+pyzmq==22.3.0
+regex==2021.11.10
+requests==2.26.0
+rfc3986==1.5.0
+scikit-image==0.19.0
+scipy==1.7.2
+setuptools-scm==6.3.2
+Shapely==1.8.0
+six==1.16.0
+sniffio==1.2.0
+tifffile==2021.11.2
+tomli==1.2.2
+torch==1.10.0
+torchaudio==0.10.0
+torchfile==0.1.0
+torchnet==0.0.4
+torchvision==0.11.1
+tornado==6.1
+tqdm==4.62.3
+typing_extensions==4.0.0
+urllib3==1.26.7
+visdom==0.1.8.9
+websocket-client==1.2.1
+websockets==10.1

Binary
Recursive-CNNs/results/qualitativeResults.jpg


+ 142 - 0
Recursive-CNNs/self_collected_dataset_preprocess.py

@@ -0,0 +1,142 @@
+from genericpath import exists
+import glob
+import cv2
+import os
+import shutil
+import csv
+import random
+import numpy as np
+
+class bcolors:
+    HEADER = '\033[95m'
+    OKBLUE = '\033[94m'
+    OKCYAN = '\033[96m'
+    OKGREEN = '\033[92m'
+    WARNING = '\033[93m'
+    FAIL = '\033[91m'
+    ENDC = '\033[0m'
+    BOLD = '\033[1m'
+    UNDERLINE = '\033[4m'
+
+def args_processor():
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--input-dir", help="dataput")
+    parser.add_argument("-o", "--output-dir", help="Directory to store results")
+    return parser.parse_args()
+
+def orderPoints(pts, centerPt):
+    # size = len(pts)
+    # centerPt = [0, 0]
+    # for pt in pts:
+    #     centerPt[0] += pt[0] / size
+    #     centerPt[1] += pt[1] / size
+    # cv2.circle(img, tuple(list((np.array(centerPt)).astype(int))), 2, (255, 0, 0), 2)
+    # cv2.imshow("img", img)
+    # cv2.waitKey()
+    # cv2.destroyAllWindows()
+    orderedDict = {}
+    for pt in pts:
+        index = -1
+        if pt[0] < centerPt[0] and pt[1] < centerPt[1]:
+            index = 0
+        elif pt[0] > centerPt[0] and pt[1] < centerPt[1]:
+            index = 1
+        elif pt[0] < centerPt[0] and pt[1] > centerPt[1]: 
+            index = 3
+        elif pt[0] > centerPt[0] and pt[1] > centerPt[1]:
+            index = 2
+        if index in orderedDict:
+            targetKeys = [0, 1, 2, 3]
+            for i in range(4):
+                exists = False
+                for key in orderedDict.keys():
+                    if key == targetKeys[i]:
+                        exists = True
+                        break
+                if exists is False:
+                    index = targetKeys[i]
+                    break
+        orderedDict[index] = pt
+    orderedPts = list(dict(sorted(orderedDict.items())).values())
+    assert len(orderedPts) == 4
+    return orderedPts
+
+def isAvaibleImg(pts, img, centerPt):
+    h, w = img.shape[:2]
+    for i, pt in enumerate(pts):
+        if pt[0] > (w - 1) or pt[0] < 1:
+            return False
+        if pt[1] > (h - 1) or pt[1] < 1:
+            return False
+        if pt[0] == centerPt[0] or pt[1] == centerPt[1]:
+            return False
+        for _i, _pt in enumerate(pts):
+            if i == _i:
+                continue
+            if abs(pt[0] - _pt[0]) <= 3:
+                return False
+            if abs(pt[1] - _pt[1]) <= 3:
+                return False
+    return True
+
+def getCenterPt(pts):
+    size = len(pts)
+    centerPt = [0, 0]
+    for pt in pts:
+        centerPt[0] += pt[0] / size
+        centerPt[1] += pt[1] / size
+    return centerPt
+
+def process(imgpaths, out):
+    for imgpath in imgpaths:
+        csv_path = imgpath.split(".")[0] + ".csv"
+        if os.path.isfile(csv_path) == False:
+            continue
+        with open(csv_path, "r") as f:
+            reader = csv.reader(f, delimiter="\t")
+            pts = []
+            for i, line in enumerate(reader):
+                split = line[0].split(" ")
+                pt = [float(split[0]), float(split[1])]
+                pts.append(pt)
+        assert len(pts) == 4
+        img = cv2.imread(imgpath)
+        centerPt = getCenterPt(pts)
+        if isAvaibleImg(pts, img, centerPt) is False:
+            # print(f"{bcolors.WARNING}{imgpath} discard {bcolors.ENDC}")
+            continue
+        orderedPts = orderPoints(pts, centerPt)
+        # for count, pt in enumerate(orderedPts):
+        #     cv2.putText(img, f'{count}', (int(pt[0]), int(pt[1])), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
+        # cv2.imshow('img',img)
+        # cv2.waitKey()
+        # cv2.destroyAllWindows()
+        fileName = os.path.basename(imgpath).split(".")[0]
+        out_imgpath = f"{out}/{fileName}.jpg"
+        with open(f"{out_imgpath}.csv", "w") as csv_out:
+            for pt in orderedPts:
+                csv_out.write(f"{pt[0]} {pt[1]}")
+                csv_out.write('\n')
+        cv2.imwrite(out_imgpath, img)
+
+
+if __name__ == "__main__":
+    args = args_processor()
+    imgpaths = glob.glob(f"{args.input_dir}/*.jpg") + glob.glob(
+        f"{args.input_dir}/*.png"
+    )
+    train_dataset_out = f"{args.output_dir}/train"
+    test_dataset_out = f"{args.output_dir}/test"
+    shutil.rmtree(args.output_dir, ignore_errors=True)
+    os.mkdir(args.output_dir)
+    os.mkdir(train_dataset_out)
+    os.mkdir(test_dataset_out)
+
+    imgpaths_num = len(imgpaths)
+    test_num = int(imgpaths_num * 0.2)
+    test_imgpaths = imgpaths[0:test_num]
+    train_imgpaths = imgpaths[test_num:imgpaths_num]
+    process(train_imgpaths, train_dataset_out)
+    process(test_imgpaths, test_dataset_out)

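orderPoints sorts the four annotated points into top-left, top-right, bottom-right, bottom-left order by comparing each point against the centroid returned by getCenterPt. A small check with made-up coordinates, assuming the helpers above are importable from this script:

    from self_collected_dataset_preprocess import getCenterPt, orderPoints

    pts = [[90, 200], [10, 10], [100, 15], [5, 180]]
    center = getCenterPt(pts)        # roughly [51.25, 101.25]
    print(orderPoints(pts, center))  # [[10, 10], [100, 15], [90, 200], [5, 180]]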
+ 40 - 0
Recursive-CNNs/smartdoc_data_processor/video_to_image.py

@@ -0,0 +1,40 @@
+import os
+
+
+def argsProcessor():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-d", "--dataPath", help="path to main data folder")
+    parser.add_argument("-o", "--outputPath", help="output data")
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = argsProcessor()
+    dir = args.dataPath
+    output = args.outputPath
+    if (not os.path.isdir(output)):
+        os.mkdir(output)
+
+    for folder in os.listdir(dir):
+        if os.path.isdir(dir + "/" + folder):
+            dir_temp = dir + folder + "/"
+            for file in os.listdir(dir_temp):
+                print(file)
+                from subprocess import call
+
+                if (file.endswith(".avi")):
+                    call("mkdir " + output + folder, shell=True)
+                    if (os.path.isdir(output + folder + "/" + file)):
+                        print("Folder already exist")
+                    else:
+                        call("cd " + output + folder + " && mkdir " + file, shell=True)
+                        call("ls", shell=True)
+
+                        location = dir + folder + "/" + file
+                        gt_address = "cp " + location[
+                                             0:-4] + ".gt.xml " + output + folder + "/" + file + "/" + file + ".gt"
+                        call(gt_address, shell=True)
+                        command = "ffmpeg -i " + location + " " + output + folder + "/" + file + "/%3d.jpg"
+                        print(command)
+                        call(command, shell=True)

+ 101 - 0
Recursive-CNNs/sythetic_doc.py

@@ -0,0 +1,101 @@
+import shutil
+import os
+import glob
+import cv2
+import numpy as np
+import random
+import imgaug.augmenters as iaa
+
+
+def visualizeImg(list=[]):
+    for item in list:
+        cv2.imshow(item[0], item[1])
+    cv2.waitKey(0)
+    cv2.destroyAllWindows()
+
+
+def transformation(src):
+    height, width = src.shape[:2]
+    srcPts = np.array([[0, 0], [width, 0], [width, height], [0, height]]).astype(
+        np.float32
+    )
+    float_random_num = random.uniform(0.0, 0.3)
+    float_random_num2 = random.uniform(0.0, 0.3)
+    float_random_num3 = random.uniform(0.7, 1)
+    float_random_num4 = random.uniform(0.0, 0.3)
+    float_random_num5 = random.uniform(0.7, 1)
+    float_random_num6 = random.uniform(0.7, 1)
+    float_random_num7 = random.uniform(0.0, 0.3)
+    float_random_num8 = random.uniform(0.7, 1)
+    dstPts = np.array(
+        [
+            [width * float_random_num, height * float_random_num2],
+            [width * float_random_num3, height * float_random_num4],
+            [width * float_random_num5, height * float_random_num6],
+            [width * float_random_num7, height * float_random_num8],
+        ]
+    ).astype(np.float32)
+    M = cv2.getPerspectiveTransform(srcPts, dstPts)
+    # warp_dst = cv2.warpPerspective(src, M, (src.shape[1], src.shape[0]), cv2.INTER_LINEAR, cv2.BORDER_CONSTANT, 255)
+    warp_dst = cv2.warpPerspective(
+        src,
+        M,
+        (width, height),
+        flags = cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT, borderValue = [0, 0, 0, 0]
+    )
+    # for pt in dstPts:
+    #     warp_dst = cv2.circle(warp_dst, (int(pt[0]), int(pt[1])), radius=4, color=(0, 0, 255), thickness=-1)
+    # visualizeImg([("warp_dst", warp_dst)])
+    return warp_dst
+
+
+def blending(img1, img2):
+    # I want to put logo on top-left corner, So I create a ROI
+    rows, cols, channels = img2.shape
+    roi = img1[0:rows, 0:cols]
+    # Now create a mask of logo and create its inverse mask also
+    img2gray = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
+    ret, mask = cv2.threshold(img2gray, 0, 255, cv2.THRESH_BINARY)
+    mask_inv = cv2.bitwise_not(mask)
+    # Now black-out the area of logo in ROI
+    img1_bg = cv2.bitwise_and(roi, roi, mask=mask_inv)
+    # Take only region of logo from logo image.
+    img2_fg = cv2.bitwise_and(img2, img2, mask=mask)
+    # Put logo in ROI and modify the main image
+    dst = cv2.add(img1_bg, img2_fg)
+    img1[0:rows, 0:cols] = dst
+    return img1
+
+
+def smoothEdge(blended_img):
+    up_sample_img = cv2.pyrUp(blended_img)
+    blur_img = up_sample_img.copy()
+    for i in range(4):
+        blur_img = cv2.medianBlur(blur_img, 21)
+    down_sample_img = cv2.pyrDown(blur_img)
+    return down_sample_img
+
+
+if __name__ == "__main__":
+    dataDir = (
+        "/Users/imac-1/workspace/hed-tutorial-for-document-scanning/sample_images/"
+    )
+    bk_imgs_folder = "background_images"
+    rect_folder = "rect_images"
+    bk_img_paths = glob.glob(f"{dataDir+bk_imgs_folder}/*.jpg")
+    rect_img_paths = glob.glob(f"{dataDir+rect_folder}/*.jpg")
+    outputDir = f"{dataDir}output"
+    shutil.rmtree(outputDir, ignore_errors=True)
+    os.makedirs(outputDir)
+    for bk_img_path in bk_img_paths:
+        bk_img = cv2.imread(bk_img_path)
+        for rect_img_path in rect_img_paths:
+            rect_img = cv2.imread(rect_img_path)
+            warpedImg = transformation(rect_img)
+            resized_img = cv2.resize(warpedImg, bk_img.shape[1::-1])
+            blended_img = blending(bk_img.copy(), resized_img)
+            final_img = smoothEdge(blended_img)
+            rectImgName = os.path.basename(rect_img_path).split(".")
+            bkImgName = os.path.basename(bk_img_path).split(".")
+            outputFile = f"{outputDir}/{bkImgName[0]}_{rectImgName[0]}.jpg"
+            cv2.imwrite(outputFile, final_img)

+ 146 - 0
Recursive-CNNs/train_model.py

@@ -0,0 +1,146 @@
+''' Document Localization using Recursive CNN
+ Maintainer : Khurram Javed
+ Email : kjaved@ualberta.ca '''
+
+from __future__ import print_function
+
+import argparse
+
+import torch
+import torch.utils.data as td
+
+import dataprocessor
+import experiment as ex
+import model
+import trainer
+import utils
+
+parser = argparse.ArgumentParser(description='Recursive-CNNs')
+parser.add_argument('--batch-size', type=int, default=32, metavar='N',
+                    help='input batch size for training (default: 32)')
+parser.add_argument('--lr', type=float, default=0.005, metavar='LR',
+                    help='learning rate (default: 0.005)')
+parser.add_argument('--schedule', type=int, nargs='+', default=[10, 20, 30],
+                    help='Decrease learning rate at these epochs.')
+parser.add_argument('--gammas', type=float, nargs='+', default=[0.2, 0.2, 0.2],
+                    help='LR is multiplied by gamma[k] on schedule[k], number of gammas should be equal to schedule')
+parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                    help='SGD momentum (default: 0.9)')
+parser.add_argument('--no-cuda', action='store_true', default=False,
+                    help='disables CUDA training')
+parser.add_argument('--pretrain', action='store_true', default=False,
+                    help='Pretrain the model on CIFAR dataset?')
+parser.add_argument('--load-ram', action='store_true', default=False,
+                    help='Load data in ram: TODO : Remove this')
+parser.add_argument('--debug', action='store_true', default=True,
+                    help='Debug messages')
+parser.add_argument('--seed', type=int, default=2323,
+                    help='Seeds values to be used')
+parser.add_argument('--log-interval', type=int, default=5, metavar='N',
+                    help='how many batches to wait before logging training status')
+parser.add_argument('--model-type', default="resnet",
+                    help='model type to be used. Example : resnet32, resnet20, densenet, test')
+parser.add_argument('--name', default="noname",
+                    help='Name of the experiment')
+parser.add_argument('--output-dir', default="output/",
+                    help='Directory to store the results; a new folder "DDMMYYYY" will be created '
+                         'in the specified directory to save the results.')
+parser.add_argument('--decay', type=float, default=0.00001, help='Weight decay (L2 penalty).')
+parser.add_argument('--epochs', type=int, default=40, help='Number of epochs for training')
+parser.add_argument('--dataset', default="document", help='Dataset to be used; example document, corner')
+parser.add_argument('--loader', default="hdd",
+                    help='Loader to use; "hdd" reads data from disk, "ram" loads all data into memory')
+parser.add_argument("-i", "--data-dirs", nargs='+', default="/Users/khurramjaved96/documentTest64",
+                    help="input Directory of train data")
+parser.add_argument("-v", "--validation-dirs", nargs='+', default="/Users/khurramjaved96/documentTest64",
+                    help="input Directory of val data")
+
+args = parser.parse_args()
+
+# Define an experiment.
+my_experiment = ex.experiment(args.name, args, args.output_dir)
+
+# Add logging support
+logger = utils.utils.setup_logger(my_experiment.path)
+
+args.cuda = not args.no_cuda and torch.cuda.is_available()
+
+dataset = dataprocessor.DatasetFactory.get_dataset(args.data_dirs, args.dataset)
+
+dataset_val = dataprocessor.DatasetFactory.get_dataset(args.validation_dirs, args.dataset)
+
+# Fix the seed.
+seed = args.seed
+torch.manual_seed(seed)
+if args.cuda:
+    torch.cuda.manual_seed(seed)
+
+train_dataset_loader = dataprocessor.LoaderFactory.get_loader(args.loader, dataset.myData,
+                                                              transform=dataset.train_transform,
+                                                              cuda=args.cuda)
+# Loader used for training data
+val_dataset_loader = dataprocessor.LoaderFactory.get_loader(args.loader, dataset_val.myData,
+                                                            transform=dataset.test_transform,
+                                                            cuda=args.cuda)
+kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
+
+# Iterator to iterate over training data.
+train_iterator = torch.utils.data.DataLoader(train_dataset_loader,
+                                             batch_size=args.batch_size, shuffle=True, **kwargs)
+# Iterator to iterate over training data.
+val_iterator = torch.utils.data.DataLoader(val_dataset_loader,
+                                           batch_size=args.batch_size, shuffle=True, **kwargs)
+
+# Get the required model
+myModel = model.ModelFactory.get_model(args.model_type, args.dataset)
+if args.cuda:
+    myModel.cuda()
+
+# Should I pretrain the model on CIFAR?
+if args.pretrain:
+    trainset = dataprocessor.DatasetFactory.get_dataset(None, "CIFAR")
+    train_iterator_cifar = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
+
+    # Define the optimizer used in the experiment
+    cifar_optimizer = torch.optim.SGD(myModel.parameters(), args.lr, momentum=args.momentum,
+                                      weight_decay=args.decay, nesterov=True)
+
+    # Trainer object used for training
+    cifar_trainer = trainer.CIFARTrainer(train_iterator_cifar, myModel, args.cuda, cifar_optimizer)
+
+    for epoch in range(0, 70):
+        logger.info("Epoch : %d", epoch)
+        cifar_trainer.update_lr(epoch, [30, 45, 60], args.gammas)
+        cifar_trainer.train(epoch)
+
+    # Freeze the model
+    counter = 0
+    for name, param in myModel.named_parameters():
+        # Getting the length of total layers so I can freeze x% of layers
+        gen_len = sum(1 for _ in myModel.parameters())
+        if counter < int(gen_len * 0.5):
+            param.requires_grad = False
+            logger.warning(name)
+        else:
+            logger.info(name)
+        counter += 1
+
+# Define the optimizer used in the experiment
+optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, myModel.parameters()), args.lr,
+                            momentum=args.momentum,
+                            weight_decay=args.decay, nesterov=True)
+
+# Trainer object used for training
+my_trainer = trainer.Trainer(train_iterator, myModel, args.cuda, optimizer)
+
+# Evaluator
+my_eval = trainer.EvaluatorFactory.get_evaluator("rmse", args.cuda)
+# Running epochs_class epochs
+for epoch in range(0, args.epochs):
+    logger.info("Epoch : %d", epoch)
+    my_trainer.update_lr(epoch, args.schedule, args.gammas)
+    my_trainer.train(epoch)
+    my_eval.evaluate(my_trainer.model, val_iterator)
+
+torch.save(myModel.state_dict(), my_experiment.path + args.dataset + "_" + args.model_type+ ".pb")
+my_experiment.store_json()

+ 35 - 0
Recursive-CNNs/train_model.sh

@@ -0,0 +1,35 @@
+#!/bin/bash
+
+set -e # exit when any command fails
+
+if [ "$1" == "dev" ]; then
+    doc_dataset_train="dataset/selfCollectedData_DocCyclic"
+    doc_dataset_test="dataset/my_doc_test"
+
+    corner_dataset_train="dataset/my_corner_train"
+    corner_dataset_test="dataset/my_corner_test"
+else
+    doc_dataset_train="dataset/selfCollectedData_DocCyclic dataset/smartdocData_DocTrainC dataset/my_doc_train dataset/sythetic_doc_train"
+    doc_dataset_test="dataset/smartDocData_DocTestC dataset/my_doc_test dataset/sythetic_doc_test"
+
+    corner_dataset_train="dataset/cornerTrain64 dataset/my_corner_train dataset/sythetic_corner_train"
+    corner_dataset_test="dataset/selfCollectedData_CornDetec dataset/my_corner_test dataset/sythetic_corner_test"
+fi
+
+# echo "doc_dataset_train=$doc_dataset_train corner_dataset_test=$corner_dataset_test"
+
+# 1. resnet model
+
+python train_model.py --name DocModel -i $doc_dataset_train \
+--lr 0.5 --schedule 20 30 35 -v $doc_dataset_test --batch-size 8 --model-type resnet --loader ram
+
+python train_model.py --name CornerModel -i $corner_dataset_train \
+--lr 0.5 --schedule 20 30 35 -v $corner_dataset_test --batch-size 8 --model-type resnet --loader ram --dataset corner
+
+# 2. mobile_net model
+
+# python train_model.py --name DocModel -i $doc_dataset_train \
+# --lr 0.5 --schedule 20 30 35 -v $doc_dataset_test --batch-size 16 --model-type shallow --loader ram
+
+# python train_model.py --name CornerModel -i $corner_dataset_train \
+# --lr 0.5 --schedule 20 30 35 -v $corner_dataset_test --batch-size 16 --model-type shallow --loader ram --dataset corner

+ 149 - 0
Recursive-CNNs/train_seg_model.py

@@ -0,0 +1,149 @@
+## Document Localization using Recursive CNN
+## Maintainer : Khurram Javed
+## Email : kjaved@ualberta.ca
+
+''' Document Localization using Recursive CNN
+ Maintainer : Khurram Javed
+ Email : kjaved@ualberta.ca '''
+
+from __future__ import print_function
+
+import argparse
+
+import torch
+import torch.utils.data as td
+
+import dataprocessor
+import experiment as ex
+import model
+import trainer
+import utils
+
+parser = argparse.ArgumentParser(description='iCarl2.0')
+parser.add_argument('--batch-size', type=int, default=32, metavar='N',
+                    help='input batch size for training (default: 32)')
+parser.add_argument('--lr', type=float, default=0.005, metavar='LR',
+                    help='learning rate (default: 0.005)')
+parser.add_argument('--schedule', type=int, nargs='+', default=[10, 20, 30],
+                    help='Decrease learning rate at these epochs.')
+parser.add_argument('--gammas', type=float, nargs='+', default=[0.2, 0.2, 0.2],
+                    help='LR is multiplied by gamma on schedule, number of gammas should be equal to schedule')
+parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                    help='SGD momentum (default: 0.9)')
+parser.add_argument('--no-cuda', action='store_true', default=False,
+                    help='disables CUDA training')
+parser.add_argument('--pretrain', action='store_true', default=False,
+                    help='Pretrain the model on CIFAR dataset?')
+parser.add_argument('--load-ram', action='store_true', default=False,
+                    help='Load data in ram')
+parser.add_argument('--debug', action='store_true', default=True,
+                    help='Debug messages')
+parser.add_argument('--seed', type=int, default=2323,
+                    help='Seeds values to be used')
+parser.add_argument('--log-interval', type=int, default=5, metavar='N',
+                    help='how many batches to wait before logging training status')
+parser.add_argument('--model-type', default="resnet",
+                    help='model type to be used. Example : resnet32, resnet20, densenet, test')
+parser.add_argument('--name', default="noname",
+                    help='Name of the experiment')
+parser.add_argument('--output-dir', default="../",
+                    help='Directory to store the results; a new folder "DDMMYYYY" will be created '
+                         'in the specified directory to save the results.')
+parser.add_argument('--decay', type=float, default=0.00001, help='Weight decay (L2 penalty).')
+parser.add_argument('--epochs', type=int, default=40, help='Number of epochs for each increment')
+parser.add_argument('--dataset', default="document", help='Dataset to be used; example document, corner')
+parser.add_argument('--loader', default="hdd", help='Loader to use; "hdd" reads data from disk, "ram" loads all data into memory')
+parser.add_argument("-i", "--data-dirs", nargs='+', default="/Users/khurramjaved96/documentTest64",
+                    help="input Directory of train data")
+parser.add_argument("-v", "--validation-dirs", nargs='+', default="/Users/khurramjaved96/documentTest64",
+                    help="input Directory of val data")
+
+args = parser.parse_args()
+
+# Define an experiment.
+my_experiment = ex.experiment(args.name, args, args.output_dir)
+
+# Add logging support
+logger = utils.utils.setup_logger(my_experiment.path)
+
+args.cuda = not args.no_cuda and torch.cuda.is_available()
+
+dataset = dataprocessor.DatasetFactory.get_dataset(args.data_dirs, args.dataset)
+
+dataset_val = dataprocessor.DatasetFactory.get_dataset(args.validation_dirs, args.dataset)
+
+# Fix the seed.
+seed = args.seed
+torch.manual_seed(seed)
+if args.cuda:
+    torch.cuda.manual_seed(seed)
+
+train_dataset_loader = dataprocessor.LoaderFactory.get_loader(args.loader, dataset.myData,
+                                                              transform=dataset.train_transform,
+                                                              cuda=args.cuda)
+# Loader used for training data
+val_dataset_loader = dataprocessor.LoaderFactory.get_loader(args.loader, dataset_val.myData,
+                                                            transform=dataset.test_transform,
+                                                            cuda=args.cuda)
+kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
+
+# Iterator to iterate over training data.
+train_iterator = torch.utils.data.DataLoader(train_dataset_loader,
+                                             batch_size=args.batch_size, shuffle=True, **kwargs)
+# Iterator to iterate over training data.
+val_iterator = torch.utils.data.DataLoader(val_dataset_loader,
+                                           batch_size=args.batch_size, shuffle=True, **kwargs)
+
+# Get the required model
+myModel = model.ModelFactory.get_model(args.model_type, args.dataset)
+if args.cuda:
+    myModel.cuda()
+
+# Should I pretrain the model on CIFAR?
+if args.pretrain:
+    trainset = dataprocessor.DatasetFactory.get_dataset(None, "CIFAR")
+    train_iterator_cifar = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
+
+    # Define the optimizer used in the experiment
+    cifar_optimizer = torch.optim.SGD(myModel.parameters(), args.lr, momentum=args.momentum,
+                                      weight_decay=args.decay, nesterov=True)
+
+    # Trainer object used for training
+    cifar_trainer = trainer.CIFARTrainer(train_iterator_cifar, myModel, args.cuda, cifar_optimizer)
+
+    for epoch in range(0, 70):
+        logger.info("Epoch : %d", epoch)
+        cifar_trainer.update_lr(epoch, [30, 45, 60], args.gammas)
+        cifar_trainer.train(epoch)
+
+    # Freeze the model
+    counter = 0
+    for name, param in myModel.named_parameters():
+        # Getting the length of total layers so I can freeze x% of layers
+        gen_len = sum(1 for _ in myModel.parameters())
+        if counter < int(gen_len * 0.5):
+            param.requires_grad = False
+            logger.warning(name)
+        else:
+            logger.info(name)
+        counter += 1
+
+# Define the optimizer used in the experiment
+optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, myModel.parameters()), args.lr,
+                            momentum=args.momentum,
+                            weight_decay=args.decay, nesterov=True)
+
+# Trainer object used for training
+my_trainer = trainer.Trainer(train_iterator, myModel, args.cuda, optimizer)
+
+# Evaluator
+my_eval = trainer.EvaluatorFactory.get_evaluator("rmse", args.cuda)
+# Running epochs_class epochs
+for epoch in range(0, args.epochs):
+    logger.info("Epoch : %d", epoch)
+    my_trainer.update_lr(epoch, args.schedule, args.gammas)
+    my_trainer.train(epoch)
+    my_eval.evaluate(my_trainer.model, val_iterator)
+
+torch.save(myModel.state_dict(), my_experiment.path + args.dataset + "_" + args.model_type+ ".pb")
+my_experiment.store_json()

+ 2 - 0
Recursive-CNNs/trainer/__init__.py

@@ -0,0 +1,2 @@
+from trainer.evaluator import *
+from trainer.trainer import *

Binary
Recursive-CNNs/trainer/__pycache__/__init__.cpython-38.pyc


Binary
Recursive-CNNs/trainer/__pycache__/evaluator.cpython-38.pyc


Binary
Recursive-CNNs/trainer/__pycache__/trainer.cpython-38.pyc


+ 63 - 0
Recursive-CNNs/trainer/evaluator.py

@@ -0,0 +1,63 @@
+''' Incremental-Classifier Learning 
+ Authors : Khurram Javed, Muhammad Talha Paracha
+ Maintainer : Khurram Javed
+ Lab : TUKL-SEECS R&D Lab
+ Email : 14besekjaved@seecs.edu.pk '''
+
+import logging
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+from torchnet.meter import confusionmeter
+from tqdm import tqdm
+
+logger = logging.getLogger('iCARL')
+
+
+class EvaluatorFactory():
+    '''
+    This class is used to get different versions of evaluators
+    '''
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def get_evaluator(testType="rmse", cuda=True):
+        if testType == "rmse":
+            return DocumentMseEvaluator(cuda)
+
+
+
+class DocumentMseEvaluator():
+    '''
+    Evaluator class for softmax classification 
+    '''
+    def __init__(self, cuda):
+        self.cuda = cuda
+
+
+    def evaluate(self, model, iterator):
+        model.eval()
+        lossAvg = None
+        with torch.no_grad():
+            for img, target in tqdm(iterator):
+                if self.cuda:
+                    img, target = img.cuda(), target.cuda()
+
+                response = model(Variable(img))
+                # print (response[0])
+                # print (target[0])
+                loss = F.mse_loss(response, Variable(target.float()))
+                loss = torch.sqrt(loss)
+                if lossAvg is None:
+                    lossAvg = loss
+                else:
+                    lossAvg += loss
+                # logger.debug("Cur loss %s", str(loss))
+
+        lossAvg /= len(iterator)
+        logger.info("Avg Val Loss %s", str((lossAvg).cpu().data.numpy()))
+
+

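The validation metric above is the square root of the batch-mean MSE, accumulated and then averaged over batches. A minimal check of the per-batch computation with made-up numbers:

    import torch
    import torch.nn.functional as F

    pred = torch.tensor([[0.10, 0.20]])
    target = torch.tensor([[0.40, 0.60]])
    print(torch.sqrt(F.mse_loss(pred, target)))  # sqrt(((0.3)**2 + (0.4)**2) / 2) ~= 0.3536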
+ 110 - 0
Recursive-CNNs/trainer/trainer.py

@@ -0,0 +1,110 @@
+''' Pytorch Recursive CNN Trainer
+ Authors : Khurram Javed
+ Maintainer : Khurram Javed
+ Lab : TUKL-SEECS R&D Lab
+ Email : 14besekjaved@seecs.edu.pk '''
+
+from __future__ import print_function
+
+import logging
+
+from torch.autograd import Variable
+
+logger = logging.getLogger('iCARL')
+import torch.nn.functional as F
+import torch
+from tqdm import tqdm
+
+
+class GenericTrainer:
+    '''
+    Base class for trainer; to implement a new training routine, inherit from this. 
+    '''
+
+    def __init__(self):
+        pass
+
+
+
+class Trainer(GenericTrainer):
+    def __init__(self, train_iterator, model, cuda, optimizer):
+        super().__init__()
+        self.cuda = cuda
+        self.train_iterator = train_iterator
+        self.model = model
+        self.optimizer = optimizer
+
+    def update_lr(self, epoch, schedule, gammas):
+        for temp in range(0, len(schedule)):
+            if schedule[temp] == epoch:
+                for param_group in self.optimizer.param_groups:
+                    self.current_lr = param_group['lr']
+                    param_group['lr'] = self.current_lr * gammas[temp]
+                    logger.debug("Changing learning rate from %0.9f to %0.9f", self.current_lr,
+                                 self.current_lr * gammas[temp])
+                    self.current_lr *= gammas[temp]
+
+    def train(self, epoch):
+        self.model.train()
+        lossAvg = None
+        for img, target in tqdm(self.train_iterator):
+            if self.cuda:
+                img, target = img.cuda(), target.cuda()
+            self.optimizer.zero_grad()
+            response = self.model(Variable(img))
+            # print (response[0])
+            # print (target[0])
+            loss = F.mse_loss(response, Variable(target.float()))
+            loss = torch.sqrt(loss)
+            if lossAvg is None:
+                lossAvg = loss
+            else:
+                lossAvg += loss
+            # logger.debug("Cur loss %s", str(loss))
+            loss.backward()
+            self.optimizer.step()
+
+        lossAvg /= len(self.train_iterator)
+        logger.info("Avg Loss %s", str((lossAvg).cpu().data.numpy()))
+
+
+class CIFARTrainer(GenericTrainer):
+    def __init__(self, train_iterator, model, cuda, optimizer):
+        super().__init__()
+        self.cuda = cuda
+        self.train_iterator = train_iterator
+        self.model = model
+        self.optimizer = optimizer
+        self.criterion = torch.nn.CrossEntropyLoss()
+
+    def update_lr(self, epoch, schedule, gammas):
+        for temp in range(0, len(schedule)):
+            if schedule[temp] == epoch:
+                for param_group in self.optimizer.param_groups:
+                    self.current_lr = param_group['lr']
+                    param_group['lr'] = self.current_lr * gammas[temp]
+                    logger.debug("Changing learning rate from %0.9f to %0.9f", self.current_lr,
+                                 self.current_lr * gammas[temp])
+                    self.current_lr *= gammas[temp]
+
+    def train(self, epoch):
+        self.model.train()
+        train_loss = 0
+        correct = 0
+        total = 0
+        for inputs, targets in tqdm(self.train_iterator):
+            if self.cuda:
+                inputs, targets = inputs.cuda(), targets.cuda()
+            self.optimizer.zero_grad()
+            outputs = self.model(Variable(inputs), pretrain=True)
+            loss = self.criterion(outputs, Variable(targets))
+            loss.backward()
+            self.optimizer.step()
+
+            train_loss += loss.item()
+            _, predicted = outputs.max(1)
+            total += targets.size(0)
+            correct += predicted.eq(targets).sum().item()
+
+        logger.info("Accuracy : %s", str((correct * 100) / total))
+        return correct / total
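
A minimal sketch of driving the Trainer above with toy stand-ins; the dataset, model and import path below are assumptions for illustration, not repository code (assumed to be run from the Recursive-CNNs root):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from trainer.trainer import Trainer  # assumed import path when run from the repo root

    # Toy stand-ins: 8 random 32x32 RGB crops with 8-dim corner-coordinate targets.
    images = torch.rand(8, 3, 32, 32)
    targets = torch.rand(8, 8)
    loader = DataLoader(TensorDataset(images, targets), batch_size=4)

    # Any regression head works for the sketch; the real networks live in model/.
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 8))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    trainer = Trainer(loader, model, cuda=False, optimizer=optimizer)
    for epoch in range(2):
        trainer.update_lr(epoch, schedule=[1], gammas=[0.1])
        trainer.train(epoch)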

+ 2 - 0
Recursive-CNNs/utils/__init__.py

@@ -0,0 +1,2 @@
+from utils import utils
+from utils import colorer

Binary
Recursive-CNNs/utils/__pycache__/__init__.cpython-38.pyc


Binary
Recursive-CNNs/utils/__pycache__/colorer.cpython-38.pyc


Binary
Recursive-CNNs/utils/__pycache__/utils.cpython-38.pyc


+ 114 - 0
Recursive-CNNs/utils/colorer.py

@@ -0,0 +1,114 @@
+#!/usr/bin/env python
+# encoding: utf-8
+import logging
+
+# Source : # http://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output
+
+# now we patch Python code to add color support to logging.StreamHandler
+def add_coloring_to_emit_windows(fn):
+    # add methods we need to the class
+    def _out_handle(self):
+        import ctypes
+        return ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
+
+    out_handle = property(_out_handle)
+
+    def _set_color(self, code):
+        import ctypes
+        # Constants from the Windows API
+        self.STD_OUTPUT_HANDLE = -11
+        hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
+        ctypes.windll.kernel32.SetConsoleTextAttribute(hdl, code)
+
+    setattr(logging.StreamHandler, '_set_color', _set_color)
+
+    def new(*args):
+        FOREGROUND_BLUE = 0x0001  # text color contains blue.
+        FOREGROUND_GREEN = 0x0002  # text color contains green.
+        FOREGROUND_RED = 0x0004  # text color contains red.
+        FOREGROUND_INTENSITY = 0x0008  # text color is intensified.
+        FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED
+        # winbase.h
+        STD_INPUT_HANDLE = -10
+        STD_OUTPUT_HANDLE = -11
+        STD_ERROR_HANDLE = -12
+
+        # wincon.h
+        FOREGROUND_BLACK = 0x0000
+        FOREGROUND_BLUE = 0x0001
+        FOREGROUND_GREEN = 0x0002
+        FOREGROUND_CYAN = 0x0003
+        FOREGROUND_RED = 0x0004
+        FOREGROUND_MAGENTA = 0x0005
+        FOREGROUND_YELLOW = 0x0006
+        FOREGROUND_GREY = 0x0007
+        FOREGROUND_INTENSITY = 0x0008  # foreground color is intensified.
+
+        BACKGROUND_BLACK = 0x0000
+        BACKGROUND_BLUE = 0x0010
+        BACKGROUND_GREEN = 0x0020
+        BACKGROUND_CYAN = 0x0030
+        BACKGROUND_RED = 0x0040
+        BACKGROUND_MAGENTA = 0x0050
+        BACKGROUND_YELLOW = 0x0060
+        BACKGROUND_GREY = 0x0070
+        BACKGROUND_INTENSITY = 0x0080  # background color is intensified.
+
+        levelno = args[1].levelno
+        if (levelno >= 50):
+            color = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
+        elif (levelno >= 40):
+            color = FOREGROUND_RED | FOREGROUND_INTENSITY
+        elif (levelno >= 30):
+            color = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
+        elif (levelno >= 20):
+            color = FOREGROUND_GREEN
+        elif (levelno >= 10):
+            color = FOREGROUND_MAGENTA
+        else:
+            color = FOREGROUND_WHITE
+        args[0]._set_color(color)
+
+        ret = fn(*args)
+        args[0]._set_color(FOREGROUND_WHITE)
+        # print "after"
+        return ret
+
+    return new
+
+
+def add_coloring_to_emit_ansi(fn):
+    # add methods we need to the class
+    def new(*args):
+        levelno = args[1].levelno
+        if (levelno >= 50):
+            color = '\x1b[31m'  # red
+        elif (levelno >= 40):
+            color = '\x1b[31m'  # red
+        elif (levelno >= 30):
+            color = '\x1b[33m'  # yellow
+        elif (levelno >= 20):
+            color = '\x1b[32m'  # green
+        elif (levelno >= 10):
+            color = '\x1b[35m'  # pink
+        else:
+            color = '\x1b[0m'  # normal
+        args[1].msg = color + args[1].msg + '\x1b[0m'  # normal
+        # print "after"
+        return fn(*args)
+
+    return new
+
+
+import platform
+
+if platform.system() == 'Windows':
+    # Windows does not support ANSI escapes, so we use Windows API calls to set the console color
+    logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
+else:
+    # all non-Windows platforms support ANSI escapes, so we use them
+    logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit)
+    # log = logging.getLogger()
+    # log.addFilter(log_filter())
+    # //hdlr = logging.StreamHandler()
+    # //hdlr.setFormatter(formatter())
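
Importing the module is all that is needed, since the patch to logging.StreamHandler.emit is applied at import time; a minimal sketch (assuming it runs from the Recursive-CNNs root with the requirements installed):

    import logging
    from utils import colorer  # noqa: F401 -- imported only for its side effect

    logging.basicConfig(level=logging.DEBUG)
    log = logging.getLogger("color-demo")
    log.debug("shown in magenta/pink")
    log.info("shown in green")
    log.warning("shown in yellow")
    log.error("shown in red")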

+ 296 - 0
Recursive-CNNs/utils/utils.py

@@ -0,0 +1,296 @@
+''' Document Localization using Recursive CNN
+ Maintainer : Khurram Javed
+ Email : kjaved@ualberta.ca '''
+
+import random
+
+import cv2
+import numpy as np
+import Polygon
+
+def unison_shuffled_copies(a, b):
+    assert len(a) == len(b)
+    p = np.random.permutation(len(a))
+    return a[p], b[p]
+
+
+def intersection(a, b, img):
+    img1 = np.zeros_like(img)
+
+    cv2.fillConvexPoly(img1, a, (255, 0, 0))
+    img1 = np.sum(img1, axis=2)
+
+    img1 = img1 / 255
+
+    img2 = np.zeros_like(img)
+    cv2.fillConvexPoly(img2, b, (255, 0, 0))
+    img2 = np.sum(img2, axis=2)
+    img2 = img2 / 255
+
+    inte = img1 * img2
+    union = np.logical_or(img1, img2)
+    iou = np.sum(inte) / np.sum(union)
+    print(iou)
+    return iou
+
+
+def intersection_with_correction(a, b, img):
+    img1 = np.zeros_like(img)
+    cv2.fillConvexPoly(img1, a, (255, 0, 0))
+
+    img2 = np.zeros_like(img)
+    cv2.fillConvexPoly(img2, b, (255, 0, 0))
+    min_x = min(a[0][0], a[1][0], a[2][0], a[3][0])
+    min_y = min(a[0][1], a[1][1], a[2][1], a[3][1])
+    max_x = max(a[0][0], a[1][0], a[2][0], a[3][0])
+    max_y = max(a[0][1], a[1][1], a[2][1], a[3][1])
+
+    dst = np.array(((min_x, min_y), (max_x, min_y), (max_x, max_y), (min_x, max_y)))
+    mat = cv2.getPerspectiveTransform(a.astype(np.float32), dst.astype(np.float32))
+    img1 = cv2.warpPerspective(img1, mat, tuple((img.shape[0], img.shape[1])))
+    img2 = cv2.warpPerspective(img2, mat, tuple((img.shape[0], img.shape[1])))
+
+    img1 = np.sum(img1, axis=2)
+    img1 = img1 / 255
+    img2 = np.sum(img2, axis=2)
+    img2 = img2 / 255
+
+    inte = img1 * img2
+    union = np.logical_or(img1, img2)
+    iou = np.sum(inte) / np.sum(union)
+    return iou
+
+def intersection_with_correction_smart_doc_implementation(gt, prediction, img):
+
+    # Reference : https://github.com/jchazalon/smartdoc15-ch1-eval
+
+    gt = sort_gt(gt)
+    prediction = sort_gt(prediction)
+    img1 = np.zeros_like(img)
+    cv2.fillConvexPoly(img1, gt, (255, 0, 0))
+
+    target_width = 2100
+    target_height = 2970
+    # Referential: (0,0) at TL, x > 0 toward right and y > 0 toward bottom
+    # Corner order: TL, BL, BR, TR
+    # object_coord_target = np.float32([[0, 0], [0, target_height], [target_width, target_height], [target_width, 0]])
+    object_coord_target = np.array(np.float32([[0, 0], [target_width, 0], [target_width, target_height],[0, target_height]]))
+    # print (gt, object_coord_target)
+    H = cv2.getPerspectiveTransform(gt.astype(np.float32).reshape(-1, 1, 2), object_coord_target.reshape(-1, 1, 2))
+
+    # 2/ Apply to test result to project in target referential
+    test_coords = cv2.perspectiveTransform(prediction.astype(np.float32).reshape(-1, 1, 2), H)
+
+    # 3/ Compute intersection between target region and test result region
+    # poly = Polygon.Polygon([(0,0),(1,0),(0,1)])
+    poly_target = Polygon.Polygon(object_coord_target.reshape(-1, 2))
+    poly_test = Polygon.Polygon(test_coords.reshape(-1, 2))
+    poly_inter = poly_target & poly_test
+
+    area_target = poly_target.area()
+    area_test = poly_test.area()
+    area_inter = poly_inter.area()
+
+    area_union = area_test + area_target - area_inter
+    # Little hack to cope with float precision issues when dealing with polygons:
+    #   If intersection area is close enough to target area or GT area, but slightly greater,
+    #   then fix it, assuming it is due to rounding issues.
+    area_min = min(area_target, area_test)
+    if area_min < area_inter and area_min * 1.0000000001 > area_inter:
+        area_inter = area_min
+        print("Capping area_inter.")
+
+    jaccard_index = area_inter / area_union
+    return jaccard_index
+
+
+
+def __rotateImage(image, angle):
+    rot_mat = cv2.getRotationMatrix2D((image.shape[1] / 2, image.shape[0] / 2), angle, 1)
+    result = cv2.warpAffine(image, rot_mat, (image.shape[1], image.shape[0]), flags=cv2.INTER_LINEAR)
+    return result, rot_mat
+
+
+def rotate(img, gt, angle):
+    img, mat = __rotateImage(img, angle)
+    gt = gt.astype(np.float64)
+    for a in range(0, 4):
+        gt[a] = np.dot(mat[..., 0:2], gt[a]) + mat[..., 2]
+    return img, gt
+
+
+def random_crop(img, gt):
+    ptr1 = (min(gt[0][0], gt[1][0], gt[2][0], gt[3][0]),
+            min(gt[0][1], gt[1][1], gt[2][1], gt[3][1]))
+
+    ptr2 = ((max(gt[0][0], gt[1][0], gt[2][0], gt[3][0]),
+             max(gt[0][1], gt[1][1], gt[2][1], gt[3][1])))
+
+    start_x = np.random.randint(0, int(max(ptr1[0] - 1, 1)))
+    start_y = np.random.randint(0, int(max(ptr1[1] - 1, 1)))
+
+    end_x = np.random.randint(int(min(ptr2[0] + 1, img.shape[1] - 1)), img.shape[1])
+    end_y = np.random.randint(int(min(ptr2[1] + 1, img.shape[0] - 1)), img.shape[0])
+
+    img = img[start_y:end_y, start_x:end_x]
+
+    myGt = gt - (start_x, start_y)
+    myGt = myGt * (1.0 / img.shape[1], 1.0 / img.shape[0])
+
+    myGtTemp = myGt * myGt
+    sum_array = myGtTemp.sum(axis=1)
+    tl_index = np.argmin(sum_array)
+    tl = myGt[tl_index]
+    tr = myGt[(tl_index + 1) % 4]
+    br = myGt[(tl_index + 2) % 4]
+    bl = myGt[(tl_index + 3) % 4]
+
+    return img, (tl, tr, br, bl)
+
+
+def get_corners(img, gt):
+    gt = gt.astype(int)
+    list_of_points = {}
+    myGt = gt
+
+    myGtTemp = myGt * myGt
+    sum_array = myGtTemp.sum(axis=1)
+
+    tl_index = np.argmin(sum_array)
+    tl = myGt[tl_index]
+    tr = myGt[(tl_index + 1) % 4]
+    br = myGt[(tl_index + 2) % 4]
+    bl = myGt[(tl_index + 3) % 4]
+
+    list_of_points["tr"] = tr
+    list_of_points["tl"] = tl
+    list_of_points["br"] = br
+    list_of_points["bl"] = bl
+    gt_list = []
+    images_list = []
+    for k, v in list_of_points.items():
+
+        if (k == "tl"):
+            cords_x = __get_cords(v[0], 0, list_of_points["tr"][0], buf=10, size=abs(list_of_points["tr"][0] - v[0]))
+            cords_y = __get_cords(v[1], 0, list_of_points["bl"][1], buf=10, size=abs(list_of_points["bl"][1] - v[1]))
+            # print cords_y, cords_x
+            gt = (v[0] - cords_x[0], v[1] - cords_y[0])
+
+            cut_image = img[cords_y[0]:cords_y[1], cords_x[0]:cords_x[1]]
+
+        if (k == "tr"):
+            cords_x = __get_cords(v[0], list_of_points["tl"][0], img.shape[1], buf=10,
+                                  size=abs(list_of_points["tl"][0] - v[0]))
+            cords_y = __get_cords(v[1], 0, list_of_points["br"][1], buf=10, size=abs(list_of_points["br"][1] - v[1]))
+            # print cords_y, cords_x
+            gt = (v[0] - cords_x[0], v[1] - cords_y[0])
+
+            cut_image = img[cords_y[0]:cords_y[1], cords_x[0]:cords_x[1]]
+
+        if (k == "bl"):
+            cords_x = __get_cords(v[0], 0, list_of_points["br"][0], buf=10,
+                                  size=abs(list_of_points["br"][0] - v[0]))
+            cords_y = __get_cords(v[1], list_of_points["tl"][1], img.shape[0], buf=10,
+                                  size=abs(list_of_points["tl"][1] - v[1]))
+            # print cords_y, cords_x
+            gt = (v[0] - cords_x[0], v[1] - cords_y[0])
+
+            cut_image = img[cords_y[0]:cords_y[1], cords_x[0]:cords_x[1]]
+
+        if (k == "br"):
+            cords_x = __get_cords(v[0], list_of_points["bl"][0], img.shape[1], buf=10,
+                                  size=abs(list_of_points["bl"][0] - v[0]))
+            cords_y = __get_cords(v[1], list_of_points["tr"][1], img.shape[0], buf=10,
+                                  size=abs(list_of_points["tr"][1] - v[1]))
+            # print cords_y, cords_x
+            gt = (v[0] - cords_x[0], v[1] - cords_y[0])
+
+            cut_image = img[cords_y[0]:cords_y[1], cords_x[0]:cords_x[1]]
+
+        # cv2.circle(cut_image, gt, 2, (255, 0, 0), 6)
+        mah_size = cut_image.shape
+        cut_image = cv2.resize(cut_image, (300, 300))
+        a = int(gt[0] * 300 / mah_size[1])
+        b = int(gt[1] * 300 / mah_size[0])
+        images_list.append(cut_image)
+        gt_list.append((a, b))
+    return images_list, gt_list
+
+
+def __get_cords(cord, min_start, max_end, size=299, buf=5, random_scale=True):
+    # size = max(abs(cord-min_start), abs(cord-max_end))
+    iter = 0
+    if (random_scale):
+        size /= random.randint(1, 4)
+    while (max_end - min_start) < size:
+        size = size * .9
+    temp = -1
+    while (temp < 1):
+        temp = random.normalvariate(size / 2, size / 6)
+    x_start = max(cord - temp, min_start)
+    x_start = int(x_start)
+    if x_start >= cord:
+        print("XSTART AND CORD", x_start, cord)
+    assert (x_start < cord)
+    while ((x_start < min_start) or (x_start + size > max_end) or (x_start + size <= cord)):
+        # x_start = random.randint(int(min(max(min_start, int(cord - size + buf)), cord - buf - 1)), cord - buf)
+        temp = -1
+        while (temp < 1):
+            temp = random.normalvariate(size / 2, size / 6)
+        temp = max(temp, 1)
+        x_start = max(cord - temp, min_start)
+        x_start = int(x_start)
+        size = size * .995
+        iter += 1
+        if (iter == 1000):
+            x_start = int(cord - (size / 2))
+            print("Gets here")
+            break
+    assert (x_start >= 0)
+    if x_start >= cord:
+        print("XSTART AND CORD", x_start, cord)
+    assert (x_start < cord)
+    assert (x_start + size <= max_end)
+    assert (x_start + size > cord)
+    return (x_start, int(x_start + size))
+
+
+def setup_logger(path):
+    import logging
+    logger = logging.getLogger('iCARL')
+    logger.setLevel(logging.DEBUG)
+
+    fh = logging.FileHandler(path + ".log")
+    fh.setLevel(logging.DEBUG)
+
+    fh2 = logging.FileHandler("../temp.log")
+    fh2.setLevel(logging.DEBUG)
+
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.DEBUG)
+
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    fh.setFormatter(formatter)
+    fh2.setFormatter(formatter)
+
+    logger.addHandler(fh)
+    logger.addHandler(fh2)
+    logger.addHandler(ch)
+    return logger
+
+
+def sort_gt(gt):
+    '''
+    Sort the ground-truth corner labels so that TL is the corner with the smallest distance from the origin.
+    :param gt: 4x2 array of corner points (x, y), assumed to be in cyclic order
+    :return: corners reordered as (tl, tr, br, bl)
+    '''
+    myGtTemp = gt * gt
+    sum_array = myGtTemp.sum(axis=1)
+    tl_index = np.argmin(sum_array)
+    tl = gt[tl_index]
+    tr = gt[(tl_index + 1) % 4]
+    br = gt[(tl_index + 2) % 4]
+    bl = gt[(tl_index + 3) % 4]
+
+    return np.asarray((tl, tr, br, bl))
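
A worked example of sort_gt: the input corners are assumed to already be in cyclic order, starting from an arbitrary corner, and the function rotates them so the corner nearest the origin (TL) comes first:

    import numpy as np
    from utils.utils import sort_gt

    # Corners as (x, y), cyclic order starting at BR: BR, BL, TL, TR.
    gt = np.array([[90, 80], [10, 80], [10, 10], [90, 10]])
    print(sort_gt(gt))
    # [[10 10]    <- TL (smallest x^2 + y^2)
    #  [90 10]    <- TR
    #  [90 80]    <- BR
    #  [10 80]]   <- BL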

+ 54 - 0
doc_clean_up/.vscode/launch.json

@@ -0,0 +1,54 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "convert_model_to_tflite.py",
+            "type": "python",
+            "request": "launch",
+            "program": "${workspaceFolder}/convert_model_to_tflite.py",
+            "console": "integratedTerminal",
+            "args": ["--ckpt_path=/Users/chenlong/Desktop/models/model_ssmi_all_3.pt"],
+            // "justMyCode": false
+        },
+        {
+            "name": "tflite_infer.py",
+            "type": "python",
+            "request": "launch",
+            "program": "${workspaceFolder}/tflite_infer.py",
+            "console": "integratedTerminal"
+        },
+        {
+            "name": "train",
+            "type": "python",
+            "request": "launch",
+            "program": "${workspaceFolder}/train.py",
+            "console": "integratedTerminal",
+            "args": [
+                "--develop=true",
+                "--lr=1e-3",
+                "--retrain=true",
+            ]
+            // "justMyCode": false
+        },
+        {
+            "name": "generate_dataset",
+            "type": "python",
+            "request": "launch",
+            "program": "${workspaceFolder}/generate_dataset.py",
+            "console": "integratedTerminal",
+            // "justMyCode": false
+        },
+        {
+            "name": "infer",
+            "type": "python",
+            "request": "launch",
+            "program": "${workspaceFolder}/infer.py",
+            "args": ["--ckpt_path=/Users/chenlong/Desktop/models/model_ssmi_all_5.pt"],
+            "console": "integratedTerminal",
+            // "justMyCode": false
+        }
+    ]
+}

+ 96 - 0
doc_clean_up/MS_SSIM_L1_loss.py

@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Dec  3 00:28:15 2020
+
+@author: Yunpeng Li, Tianjin University
+"""
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MS_SSIM_L1_LOSS(nn.Module):
+    # Have to use cuda, otherwise the speed is too slow.
+    def __init__(self, gaussian_sigmas=[0.5, 1.0, 2.0, 4.0, 8.0],
+                 data_range = 1.0,
+                 K=(0.01, 0.03),
+                 alpha=0.025,
+                 compensation=200.0,
+                 device=torch.device('cpu')):
+        super(MS_SSIM_L1_LOSS, self).__init__()
+        self.DR = data_range
+        self.C1 = (K[0] * data_range) ** 2
+        self.C2 = (K[1] * data_range) ** 2
+        self.pad = int(2 * gaussian_sigmas[-1])
+        self.alpha = alpha
+        self.compensation=compensation
+        filter_size = int(4 * gaussian_sigmas[-1] + 1)
+        g_masks = torch.zeros((3*len(gaussian_sigmas), 1, filter_size, filter_size))
+        for idx, sigma in enumerate(gaussian_sigmas):
+            # r0,g0,b0,r1,g1,b1,...,rM,gM,bM
+            g_masks[3*idx+0, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma)
+            g_masks[3*idx+1, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma)
+            g_masks[3*idx+2, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma)
+        self.g_masks = g_masks.to(device)
+
+    def _fspecial_gauss_1d(self, size, sigma):
+        """Create 1-D gauss kernel
+        Args:
+            size (int): the size of gauss kernel
+            sigma (float): sigma of normal distribution
+
+        Returns:
+            torch.Tensor: 1D kernel (size)
+        """
+        coords = torch.arange(size).to(dtype=torch.float)
+        coords -= size // 2
+        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
+        g /= g.sum()
+        return g.reshape(-1)
+
+    def _fspecial_gauss_2d(self, size, sigma):
+        """Create 2-D gauss kernel
+        Args:
+            size (int): the size of gauss kernel
+            sigma (float): sigma of normal distribution
+
+        Returns:
+            torch.Tensor: 2D kernel (size x size)
+        """
+        gaussian_vec = self._fspecial_gauss_1d(size, sigma)
+        return torch.outer(gaussian_vec, gaussian_vec)
+
+    def forward(self, x, y):
+        b, c, h, w = x.shape
+        mux = F.conv2d(x, self.g_masks, groups=3, padding=self.pad)
+        muy = F.conv2d(y, self.g_masks, groups=3, padding=self.pad)
+
+        mux2 = mux * mux
+        muy2 = muy * muy
+        muxy = mux * muy
+
+        sigmax2 = F.conv2d(x * x, self.g_masks, groups=3, padding=self.pad) - mux2
+        sigmay2 = F.conv2d(y * y, self.g_masks, groups=3, padding=self.pad) - muy2
+        sigmaxy = F.conv2d(x * y, self.g_masks, groups=3, padding=self.pad) - muxy
+
+        # l(j), cs(j) in MS-SSIM
+        l  = (2 * muxy    + self.C1) / (mux2    + muy2    + self.C1)  # [B, 15, H, W]
+        cs = (2 * sigmaxy + self.C2) / (sigmax2 + sigmay2 + self.C2)
+
+        lM = l[:, -1, :, :] * l[:, -2, :, :] * l[:, -3, :, :]
+        PIcs = cs.prod(dim=1)
+
+        loss_ms_ssim = 1 - lM*PIcs  # [B, H, W]
+
+        loss_l1 = F.l1_loss(x, y, reduction='none')  # [B, 3, H, W]
+        # average l1 loss in 3 channels
+        gaussian_l1 = F.conv2d(loss_l1, self.g_masks.narrow(dim=0, start=-3, length=3),
+                               groups=3, padding=self.pad).mean(1)  # [B, H, W]
+
+        loss_mix = self.alpha * loss_ms_ssim + (1 - self.alpha) * gaussian_l1 / self.DR
+        loss_mix = self.compensation*loss_mix
+
+        return loss_mix.mean()
+
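
A minimal sketch of calling the loss; inputs are assumed to be NCHW tensors in [0, 1] (per the comment above, CPU works but is slow):

    import torch
    from MS_SSIM_L1_loss import MS_SSIM_L1_LOSS

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = MS_SSIM_L1_LOSS(device=device)

    pred = torch.rand(2, 3, 64, 64, device=device, requires_grad=True)
    gt = torch.rand(2, 3, 64, 64, device=device)

    loss = criterion(pred, gt)  # scalar mixing the MS-SSIM term and the Gaussian-weighted L1 term
    loss.backward()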

+ 75 - 0
doc_clean_up/convert_model_to_tflite.py

@@ -0,0 +1,75 @@
+#!/usr/bin/env python
+import argparse
+import os
+import shutil
+import onnx
+import torch
+import torch.backends._nnapi.prepare
+import torch.utils.bundled_inputs
+import torch.utils.mobile_optimizer
+from onnx_tf.backend import prepare
+import tensorflow as tf
+from model import M64ColorNet
+from torch.nn.utils import prune
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--ckpt_path',
+                    type=str,
+                    help='This is the path where to store the ckpt file',
+                    default="output/model.pt")
+
+
+def convert_to_tflite(out_dir: str, model: torch.nn.Module):
+    dummy_input = torch.randn((1, 3, 256, 256))
+    onnx_path = f"{out_dir}/converted.onnx"
+    torch.onnx.export(model, dummy_input, onnx_path, verbose=True,
+                      input_names=['input'], output_names=['output'])
+
+    tf_path = f"{out_dir}/tf_model"
+    onnx_model = onnx.load(onnx_path)
+    # prepare function converts an ONNX model to an internal representation
+    # of the computational graph called TensorflowRep and returns
+    # the converted representation.
+    tf_rep = prepare(onnx_model)  # creating TensorflowRep object
+    # export_graph function obtains the graph proto corresponding to the ONNX
+    # model associated with the backend representation and serializes
+    # to a protobuf file.
+    tf_rep.export_graph(tf_path)
+
+    converter = tf.lite.TFLiteConverter.from_saved_model(tf_path)
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    tf_lite_model = converter.convert()
+    tflite_path = f"{out_dir}/doc_clean.tflite"
+    with open(tflite_path, 'wb') as f:
+        f.write(tf_lite_model)
+
+
+def convert_to_tflite_with_tiny(out_dir: str, fileName:str, model: torch.nn.Module):
+    from tinynn.converter import TFLiteConverter
+    dummy_input = torch.rand((1, 3, 256, 256))
+
+    # output_path = os.path.join(out_dir, 'out', 'mbv1_224.tflite')
+    tflite_path = f"{out_dir}/{fileName}"
+
+    # When converting quantized models, please ensure the quantization backend is set.
+    # torch.backends.quantized.engine = 'qnnpack'
+
+    # The code section below is used to convert the model to the TFLite format
+    # If you want to perform dynamic quantization on the float models,
+    # you may pass the following arguments.
+    #   `quantize_target_type='int8', hybrid_quantization_from_float=True, hybrid_per_channel=False`
+    # As for static quantization (e.g. quantization-aware training and post-training quantization),
+    # please refer to the code examples in the `examples/quantization` folder.
+    converter = TFLiteConverter(model, dummy_input, tflite_path)
+    converter.convert()
+
+if __name__ == "__main__":
+    out_dir = "output_tflite"
+    shutil.rmtree(out_dir, ignore_errors=True)
+    os.mkdir(out_dir)
+    args = parser.parse_args()
+    model, _, _, _, ssim, psnr = M64ColorNet.load_trained_model(args.ckpt_path)
+    name = os.path.basename(args.ckpt_path).split(".")[0]
+    fileName = f"ssim_{round(ssim, 2)}_psnr_{round(psnr, 2)}_{name}.tflite"
+    # convert_to_tflite(out_dir, model)
+    convert_to_tflite_with_tiny(out_dir, fileName, model)
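
Typical invocation, assuming a checkpoint produced by doc_clean_up/train.py and with tinynn installed for the TinyNN path:

    python convert_model_to_tflite.py --ckpt_path=output/model.pt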

+ 80 - 0
doc_clean_up/dataset.py

@@ -0,0 +1,80 @@
+from torch.utils.data import Dataset
+from PIL import Image
+from torchvision import transforms
+from typing import List, Tuple
+import imgaug.augmenters as iaa
+import numpy as np
+from sklearn.model_selection import train_test_split
+
+class UnNormalize(object):
+    def __init__(self, mean, std):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, tensor):
+        """
+        Args:
+            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+        Returns:
+            Tensor: Normalized image.
+        """
+        for t, m, s in zip(tensor, self.mean, self.std):
+            t.mul_(s).add_(m)
+            # The normalize code -> t.sub_(m).div_(s)
+        return tensor
+
+class DocCleanDataset(Dataset):
+
+    @staticmethod
+    def prepareDataset(dataset:str, shuffle=True):
+        # imgs_dir = "dataset/raw_data/imgs_Trainblocks"
+        with open(f"{dataset}/train_block_names.txt") as train_block_names_file:
+            image_names = train_block_names_file.read().splitlines()
+            train_img_names, eval_img_names, _, _ = train_test_split(
+                image_names, image_names, test_size=0.2, random_state=1, shuffle=shuffle)
+            return train_img_names, eval_img_names, dataset
+
+    def __init__(self, img_names: List[str], imgs_dir: str, normalized_tuple: Tuple[List[float], List[float]] = None, dev=False, img_aug=False):
+        if dev:
+            num = int(len(img_names) * 0.01)
+            img_names = img_names[0:num]
+        self.img_names = img_names
+        self.imgs_dir = imgs_dir
+        if normalized_tuple:
+            mean, std = normalized_tuple
+            self.normalized = transforms.Compose([
+                transforms.ToTensor(), 
+                transforms.Normalize(mean=mean, std=std)
+            ])
+            self.aug_seq = iaa.Sometimes(0.7, iaa.OneOf([
+                    iaa.SaltAndPepper(p=(0.0, 0.05)),
+                    iaa.imgcorruptlike.MotionBlur(severity=2),
+                    iaa.SigmoidContrast(gain=(3, 10), cutoff=(0.4, 0.6)),
+                    iaa.imgcorruptlike.JpegCompression(severity=2),
+                    iaa.GammaContrast((0.5, 2.0)),
+                    iaa.LogContrast(gain=(0.5, 0.9)),
+                    iaa.GaussianBlur(sigma=(0, 1)),
+                    iaa.imgcorruptlike.SpeckleNoise(severity=1),
+                    iaa.AdditiveGaussianNoise(scale=(0.03*255, 0.2*255), per_channel=True),
+                    iaa.Add((-20, 20), per_channel=0.5),
+                    iaa.AddToBrightness((-30, 30))
+                ]))
+        self.img_aug = img_aug
+        self.toTensor = transforms.ToTensor()
+
+    def __len__(self):
+        return len(self.img_names)
+
+    def __getitem__(self, index):
+        img = Image.open(f"{self.imgs_dir}/{self.img_names[index]}")
+        gt = Image.open(f"{self.imgs_dir}/gt{self.img_names[index]}")
+        if hasattr(self, 'normalized'):
+            img_np = np.array(img)
+            if self.img_aug == True:
+                img_np = self.aug_seq.augment_images([np.array(img)])[0]
+            normalized_img = self.normalized(img_np)
+            img = self.toTensor(img_np)
+        else:
+            img = self.toTensor(img)
+            normalized_img = img
+        return img, normalized_img, self.toTensor(gt)
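
A minimal sketch of wiring DocCleanDataset into a DataLoader; the directory and the mean/std values are placeholders (train.py computes the real statistics):

    from torch.utils.data import DataLoader
    from dataset import DocCleanDataset

    dataset_dir = "dataset/raw_data/imgs_Trainblocks"  # produced by generate_dataset.py
    train_names, eval_names, imgs_dir = DocCleanDataset.prepareDataset(dataset_dir)

    mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]  # placeholder statistics
    train_set = DocCleanDataset(train_names, imgs_dir, normalized_tuple=(mean, std), img_aug=True)
    train_loader = DataLoader(train_set, batch_size=16, shuffle=True)

    imgs, normalized_imgs, gt_imgs = next(iter(train_loader))  # each: [16, 3, 256, 256]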

+ 265 - 0
doc_clean_up/generate_dataset.py

@@ -0,0 +1,265 @@
+import os
+from tqdm import tqdm
+import cv2
+import numpy as np
+import albumentations as A
+import random
+import shutil
+import argparse
+
+# path parameters
+parser = argparse.ArgumentParser()
+
+parser.add_argument('--data_dir',
+                    type=str,
+                    help='Raw training data.',
+                    default="raw_data")
+
+transform = A.Compose([
+    A.OneOf([
+            A.ISONoise(p=0.4),
+            A.JpegCompression(quality_lower=50, quality_upper=70,
+                              always_apply=False, p=0.8),
+            ], p=0.6),
+    A.OneOf([
+            A.MotionBlur(blur_limit=10, p=.8),
+            A.MedianBlur(blur_limit=3, p=0.75),
+            A.GaussianBlur(blur_limit=7, p=0.75),
+            ], p=0.8),
+    A.OneOf([
+            A.RandomBrightnessContrast(
+                brightness_limit=0.3, contrast_limit=0.3, p=0.75),
+            A.RandomShadow(num_shadows_lower=1,
+                           num_shadows_upper=18, shadow_dimension=6, p=0.85),
+            ], p=0.8),
+])
+
+
+def getListOfFiles(dirName):
+    print(dirName)
+    # return the names of the entries (files and sub-directories) in the given directory
+    listOfFile = os.listdir(dirName)
+    allFiles = list()
+    # Iterate over all the entries
+    for entry in listOfFile:
+        allFiles.append(entry)
+    return allFiles
+
+
+def ImageResize(image, factor=0.6):
+    width = int(image.shape[1] * factor)
+    height = int(image.shape[0] * factor)
+    dim = (width, height)
+    # print(image.shape)
+    resized = cv2.resize(image, dim, interpolation=cv2.INTER_LANCZOS4)
+    # print(resized.shape)
+    return resized
+
+
+def GetOverlappingBlocks(im, M=256, N=256, Part=8):
+    tiles = []
+    tile = np.zeros((M, N, 3), dtype=np.uint8)
+    #tile[:,:,2] = 255
+
+    x = 0
+    y = 0
+    x_start = 0
+    y_start = 0
+    while y < im.shape[0]:
+        while x < im.shape[1]:
+            if(x != 0):
+                x_start = x - int(N/Part)
+            if(y != 0):
+                y_start = y - int(M/Part)
+            if(y_start+M > im.shape[0]):
+                if(x_start+N > im.shape[1]):
+                    tile[0:im.shape[0]-y_start, 0:im.shape[1]-x_start,
+                         :] = im[y_start:im.shape[0], x_start:im.shape[1], :]
+                else:
+                    tile[0:im.shape[0]-y_start, 0:N,
+                         :] = im[y_start:im.shape[0], x_start:x_start+N, :]
+            else:
+                if(x_start+N > im.shape[1]):
+                    tile[0:M, 0:im.shape[1]-x_start,
+                         :] = im[y_start:y_start+M, x_start:im.shape[1], :]
+                else:
+                    tile[0:M, 0:N, :] = im[y_start:y_start +
+                                           M, x_start:x_start+N, :]
+
+            #pre_tile = cv2.cvtColor(PreProcessInput(cv2.cvtColor(tile, cv2.COLOR_RGB2BGR)), cv2.COLOR_BGR2RGB)
+            # tiles.append(load_tf_img(pre_tile,M))
+
+            # tiles.append(load_tf_img(tile,M))
+            tiles.append(tile)
+
+            tile = np.zeros((M, N, 3), dtype=np.uint8)
+            #tile[:,:,2] = 255
+            x = x_start + N
+        y = y_start + M
+        x = 0
+        x_start = 0
+    return tiles
+
+
+def GenerateTrainingBlocks(data_folder, gt_folder, dataset_path='./dataset', M=256, N=256):
+    print(data_folder)
+    print('Generating training blocks!!!')
+    train_path = dataset_path + '/' + data_folder + '_Trainblocks'
+
+    if not os.path.exists(train_path):
+        os.makedirs(train_path)
+
+    train_filenames = train_path + '/train_block_names.txt'
+    f = open(train_filenames, 'w')
+
+    data_path = data_folder
+    gt_path = gt_folder
+    # data_path = dataset_path + '/' + data_folder
+    # gt_path = dataset_path + '/' + gt_folder
+
+    print(data_path)
+
+    filenames = getListOfFiles(data_path)
+    cnt = 0
+    print(filenames)
+    for name in tqdm(filenames):
+        print(name)
+        if name == '.DS_Store':
+            continue
+        arr = name.split(".")
+        gt_filename = gt_path + '/' + arr[0] + "_mask."+arr[1]
+        in_filename = data_path + '/' + name
+        print(gt_filename)
+        print(in_filename)
+        gt_image_initial = cv2.imread(gt_filename)
+        in_image_initial = cv2.imread(in_filename)
+        if gt_image_initial.shape[0] + gt_image_initial.shape[1] > in_image_initial.shape[0]+in_image_initial.shape[1]:
+            gt_image_initial = cv2.resize(gt_image_initial, (in_image_initial.shape[1], in_image_initial.shape[0]))
+        else:
+            in_image_initial = cv2.resize(in_image_initial, (gt_image_initial.shape[1], gt_image_initial.shape[0]))
+        print(gt_image_initial.shape, in_image_initial.shape)
+        # cv2.imshow("img", in_image_initial)
+        # cv2.imshow("gt", gt_image_initial)
+        # cv2.waitKey(0)
+        # cv2.destroyAllWindows()
+        for scale in [0.7, 1.0, 1.4]:
+            gt_image = ImageResize(gt_image_initial, scale)
+            in_image = ImageResize(in_image_initial, scale)
+            h, w, c = in_image.shape
+            gt_img = GetOverlappingBlocks(gt_image, Part=8)
+            in_img = GetOverlappingBlocks(in_image, Part=8)
+            for i in range(len(gt_img)):
+                train_img_path = train_path + '/block_' + str(cnt) + '.png'
+                gt_img_path = train_path + '/gtblock_' + str(cnt) + '.png'
+                cv2.imwrite(train_img_path, in_img[i])
+                # cv2.imwrite(train_img_path,PreProcessInput(in_img[i]))
+                cv2.imwrite(gt_img_path, gt_img[i])
+                t_name = 'block_' + str(cnt) + '.png'
+                f.write(t_name)
+                f.write('\n')
+                cnt += 1
+            Random_Block_Number_PerImage = int(len(gt_img)/5)
+            for i in range(Random_Block_Number_PerImage):
+
+                if(in_image.shape[0]-M > 1 and in_image.shape[1]-N > 1):
+                    y = random.randint(1, in_image.shape[0]-M)
+                    x = random.randint(1, in_image.shape[1]-N)
+                    in_part_img = in_image[y:y+M, x:x+N, :].copy()
+                    gt_part_img = gt_image[y:y+M, x:x+N, :].copy()
+                    train_img_path = train_path + '/block_' + str(cnt) + '.png'
+                    gt_img_path = train_path + '/gtblock_' + str(cnt) + '.png'
+                    in_part_img = cv2.cvtColor(in_part_img, cv2.COLOR_BGR2RGB)
+                    augmented_image = transform(image=in_part_img)['image']
+                    augmented_image = cv2.cvtColor(
+                        augmented_image, cv2.COLOR_RGB2BGR)
+
+                    cv2.imwrite(train_img_path, augmented_image)
+                    cv2.imwrite(gt_img_path, gt_part_img)
+                    t_name = 'block_' + str(cnt) + '.png'
+                    f.write(t_name)
+                    f.write('\n')
+                    cnt += 1
+                else:
+                    break  # NOTE: the padding fallback below is unreachable because of this break
+                    in_part_img = np.zeros((M, N, 3), dtype=np.uint8)
+                    gt_part_img = np.zeros((M, N, 3), dtype=np.uint8)
+                    in_part_img[:, :, :] = 255
+                    gt_part_img[:, :, :] = 255
+
+                    if(in_image.shape[0]-M <= 1 and in_image.shape[1]-N > 1):
+                        y = 0
+                        x = random.randint(1, in_image.shape[1]-N)
+                        in_part_img[:h, :, :] = in_image[:, x:x+N, :].copy()
+                        gt_part_img[:h, :, :] = gt_image[:, x:x+N, :].copy()
+                    if(in_image.shape[0]-M > 1 and in_image.shape[1]-N <= 1):
+                        x = 0
+                        y = random.randint(1, in_image.shape[0]-M)
+                        in_part_img[:, :w, :] = in_image[y:y+M, :, :].copy()
+                        gt_part_img[:, :w, :] = gt_image[y:y+M, :, :].copy()
+
+                    train_img_path = train_path + '/block_' + str(cnt) + '.png'
+                    gt_img_path = train_path + '/gtblock_' + str(cnt) + '.png'
+                    in_part_img = cv2.cvtColor(in_part_img, cv2.COLOR_BGR2RGB)
+                    augmented_image = transform(image=in_part_img)['image']
+                    augmented_image = cv2.cvtColor(
+                        augmented_image, cv2.COLOR_RGB2BGR)
+
+                    cv2.imwrite(train_img_path, augmented_image)
+                    cv2.imwrite(gt_img_path, gt_part_img)
+                    t_name = 'block_' + str(cnt) + '.png'
+                    f.write(t_name)
+                    f.write('\n')
+                    cnt += 1
+        # print(cnt)
+
+    f.close()
+
+    print('Total number of training blocks generated: ', cnt)
+
+    return train_path, train_filenames
+
+def CombineToImage(imgs,h,w,ch,Part=8):
+    Image = np.zeros((h,w,ch),dtype=np.float32)
+    Image_flag = np.zeros((h,w),dtype=bool)
+    i = 0
+    j = 0
+    i_end = 0
+    j_end = 0
+    for k in range(len(imgs)):
+        #part_img = np.copy(imgs[k,:,:,:])
+        part_img = np.copy(imgs[k])
+        hh,ww,cc = part_img.shape
+        i_end = min(h,i + hh)
+        j_end = min(w,j + ww)
+        
+        
+        for m in range(hh):
+            for n in range(ww):
+                if(i+m<h):
+                    if(j+n<w):
+                        if(Image_flag[i+m,j+n]):
+                            
+                            Image[i+m,j+n,:] = (Image[i+m,j+n,:] + part_img[m,n,:])/2
+                        else:
+                            Image[i+m,j+n,:] = np.copy(part_img[m,n,:])
+
+        Image_flag[i:i_end,j:j_end] = True
+        j =  min(w-1,j + ww - int(ww/Part))
+        #print(i,j,w)
+        #print(k,len(imgs))
+        if(j_end==w):
+            j = 0
+            i = min(h-1,i + hh - int(hh/Part))
+    Image = Image*255.0
+    return Image.astype(np.uint8)
+if __name__ == "__main__":
+    # img = cv2.imread("raw_data/gt/189.jpg")
+    args = parser.parse_args()
+    data_folder = f"{args.data_dir}/imgs"
+    gt_folder = f"{args.data_dir}/gt"
+    dataset = "dataset"
+    shutil.rmtree(dataset, ignore_errors=True)
+    os.mkdir(dataset)
+    train_path, train_filenames = GenerateTrainingBlocks(
+        data_folder=data_folder, gt_folder=gt_folder, dataset_path=dataset)
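
A round-trip sketch of the two block helpers: split an image into overlapping 256x256 tiles, then stitch them back. CombineToImage expects float tiles in [0, 1] because it rescales by 255 at the end; the image path below is a placeholder:

    import cv2
    import numpy as np
    from generate_dataset import GetOverlappingBlocks, CombineToImage

    img = cv2.imread("raw_data/imgs/example.jpg")  # placeholder path
    tiles = GetOverlappingBlocks(img, M=256, N=256, Part=8)
    tiles = [t.astype(np.float32) / 255.0 for t in tiles]

    h, w, c = img.shape
    restored = CombineToImage(tiles, h, w, c, Part=8)
    cv2.imwrite("restored.png", restored)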

+ 114 - 0
doc_clean_up/infer.py

@@ -0,0 +1,114 @@
+from dataset import DocCleanDataset
+from model import M64ColorNet
+import torch
+import argparse
+import glob
+from torchvision import transforms
+from PIL import Image
+from skimage import io
+from skimage.filters.rank import mean_bilateral
+from skimage import morphology
+import cv2
+import numpy as np
+import torchvision.transforms as T
+import os
+import shutil
+from generate_dataset import GetOverlappingBlocks, CombineToImage
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--ckpt_path',
+                    type=str,
+                    help='This is the path where to store the ckpt file',
+                    default="output/model.pt")
+parser.add_argument('--img_dir',
+                    type=str,
+                    help='This is a folder where to store images to infer',
+                    default="infer_imgs")
+
+def padCropImg(img):
+    
+    H = img.shape[0]
+    W = img.shape[1]
+
+    patchRes = 128
+    pH = patchRes
+    pW = patchRes
+    ovlp = int(patchRes * 0.125)
+
+    padH = (int((H - patchRes)/(patchRes - ovlp) + 1) * (patchRes - ovlp) + patchRes) - H
+    padW = (int((W - patchRes)/(patchRes - ovlp) + 1) * (patchRes - ovlp) + patchRes) - W
+
+    padImg = cv2.copyMakeBorder(img, 0, padH, 0, padW, cv2.BORDER_REPLICATE)
+
+    ynum = int((padImg.shape[0] - pH)/(pH - ovlp)) + 1
+    xnum = int((padImg.shape[1] - pW)/(pW - ovlp)) + 1
+
+    totalPatch = np.zeros((ynum, xnum, patchRes, patchRes, 3), dtype=np.uint8)
+
+    for j in range(0, ynum):
+        for i in range(0, xnum):
+
+            x = int(i * (pW - ovlp))
+            y = int(j * (pH - ovlp))
+
+            totalPatch[j, i] = padImg[y:int(y + patchRes), x:int(x + patchRes)]
+
+    return totalPatch
+
+def preProcess(img):
+    img[:,:,0] = mean_bilateral(img[:,:,0], morphology.disk(20), s0=10, s1=10)
+    img[:,:,1] = mean_bilateral(img[:,:,1], morphology.disk(20), s0=10, s1=10)
+    img[:,:,2] = mean_bilateral(img[:,:,2], morphology.disk(20), s0=10, s1=10)
+    
+    return img
+
+def infer(model, img_path, output, transform, device, block_size=(256,256)):
+    out_name = os.path.basename(img_path).split(".")[0]
+    in_clr = cv2.imread(img_path,1)
+    start_time = cv2.getTickCount()
+    M = block_size[0]
+    N = block_size[1]
+    rgb_img = cv2.cvtColor(in_clr, cv2.COLOR_BGR2RGB)
+    part = 8
+    patches = GetOverlappingBlocks(rgb_img.copy(),M,N,part)
+    preds = []
+    with torch.no_grad():
+        for idx, patch in enumerate(patches):
+            input = transform(patch).to(device)
+            pred = model(input.unsqueeze(0))
+            pred = pred.cpu().detach().numpy()[0].transpose(1,2,0)
+            # cv2.imwrite(f"{output}/{out_name}_{idx}.png", cv2.cvtColor(patch, cv2.COLOR_BGR2RGB))
+            # cv2.imwrite(f"{output}/{out_name}_{idx}_pred.png", cv2.cvtColor(pred * 255, cv2.COLOR_BGR2RGB))
+            preds.append(pred)
+        # print(f"pred idx={idx}")
+    # h, w, c = preds[0].shape
+    h, w, c = in_clr.shape
+    image = CombineToImage(preds, h, w, c)
+    c_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+    print(f"image_name:{os.path.basename(img_path)} doc clean time:{(cv2.getTickCount() - start_time)/ cv2.getTickFrequency()}")
+    cv2.imwrite(f"{output}/{out_name}.png", c_image)
+
+def infer_test(output:str, img_dir:str, ckpt_path:str, model_cls):
+    shutil.rmtree(output, ignore_errors=True)
+    os.makedirs(output)
+    model = model_cls()
+    ckpt_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
+    mean = ckpt_dict["mean"]
+    std = ckpt_dict["std"]
+    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
+    model.load_state_dict(ckpt_dict["model_state_dict"])
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    model.to(device)
+    model.eval()
+
+    # img_paths = glob.glob(f"{img_dir}/*.jpg") + glob.glob(f"{img_dir}/*.png") +glob.glob(f"{img_dir}/*.JPG")
+    img_paths = glob.glob(f"{img_dir}/002.JPG") + glob.glob(f"{img_dir}/21654792158_.pic.jpg")
+    for img_path in img_paths:
+        infer(model, img_path, output, transform, device)
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    ckpt_name = os.path.basename(args.ckpt_path).split(".")[0]
+    output = f"{args.img_dir}/output_{ckpt_name}"
+    infer_test(output, args.img_dir, args.ckpt_path, M64ColorNet)
+    
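
Typical invocation, mirroring the launch.json entry above (both paths are machine-specific; note that infer_test currently hard-codes two file names in img_paths):

    python infer.py --ckpt_path=output/model.pt --img_dir=infer_imgs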

+ 65 - 0
doc_clean_up/loss.py

@@ -0,0 +1,65 @@
+import torch
+from MS_SSIM_L1_loss import MS_SSIM_L1_LOSS
+from vgg19 import VGG19
+
+def gram(x):
+    (bs, ch, h, w) = x.size()
+    f = x.view(bs, ch, w * h)
+    f_T = f.transpose(1, 2)
+    G = f.bmm(f_T) / (ch * h * w)
+    return G
+
+# class PixelLevelLoss(torch.nn.Module):
+#     def __init__(self, device):
+#         super(PixelLevelLoss, self).__init__()
+#         # self.l1_loss = torch.nn.L1Loss()
+#         self.l1_loss = MS_SSIM_L1_LOSS(device=device)
+
+#     def forward(self, pred, gt):
+#         # pred_yuv = rgb_to_ycbcr(pred)
+#         # gt_yuv = rgb_to_ycbcr(gt)
+#         # loss = torch.norm(torch.sub(pred_yuv, gt_yuv), p=1)
+#         # loss = l1_loss(pred_yuv, gt_yuv)
+#         loss = self.l1_loss(pred, gt)
+#         return loss 
+
+# https://zhuanlan.zhihu.com/p/92102879
+class PerceptualLoss(torch.nn.Module):
+    def __init__(self):
+        super(PerceptualLoss, self).__init__()
+        # self.l1_loss = torch.nn.L1Loss()
+        self.l1_loss = torch.nn.MSELoss()  # named l1_loss for historical reasons; MSE is actually used
+    
+
+    def tv_loss(self, y_hat):
+        return 0.5 * (torch.abs(y_hat[:, :, 1:, :] - y_hat[:, :, :-1, :]).mean() +
+                  torch.abs(y_hat[:, :, :, 1:] - y_hat[:, :, :, :-1]).mean())
+
+    def forward(self, y_hat, contents, style_pred_list, style_gt_list):
+        content_pred, content_gt = contents
+        _, c, h, w = content_pred.shape
+        content_loss = self.l1_loss(content_pred, content_gt) / float(c * h * w)     
+
+        style_loss = 0
+        for style_pred, style_gt in zip(style_pred_list, style_gt_list):
+            style_loss += self.l1_loss(gram(style_pred), gram(style_gt))
+
+        tv_l = self.tv_loss(y_hat)
+        return content_loss, style_loss, tv_l
+
+class DocCleanLoss(torch.nn.Module):
+
+    def __init__(self, device) -> None:
+        super(DocCleanLoss, self).__init__()
+        self.vgg19 = VGG19()
+        self.vgg19.to(device)
+        self.vgg19.eval()
+        self.pixel_level_loss = MS_SSIM_L1_LOSS(device=device)
+        self.perceptual_loss = PerceptualLoss()
+
+    def forward(self, pred_imgs, gt_imgs):
+        p_l_loss = self.pixel_level_loss(pred_imgs, gt_imgs)
+        contents, style_pred_list, style_gt_list = self.vgg19(pred_imgs, gt_imgs)
+        content_loss, style_loss, tv_l = self.perceptual_loss(pred_imgs, contents, style_pred_list, style_gt_list)
+        # return 1e1*p_l_loss + 1e-1*content_loss + 1e1*style_loss, p_l_loss, content_loss, style_loss
+        return 1e1*p_l_loss + 1e1*content_loss + 1e-1*style_loss + 1e1*tv_l, p_l_loss, content_loss, style_loss
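
A minimal sketch of the combined loss; it assumes the VGG19 wrapper from vgg19.py is importable and that its pretrained weights are available:

    import torch
    from loss import DocCleanLoss

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = DocCleanLoss(device)

    pred = torch.rand(2, 3, 256, 256, device=device, requires_grad=True)
    gt = torch.rand(2, 3, 256, 256, device=device)

    total, pixel_loss, content_loss, style_loss = criterion(pred, gt)
    total.backward()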

+ 124 - 0
doc_clean_up/model.py

@@ -0,0 +1,124 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, channels):
+        super(ResidualBlock, self).__init__()
+        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(channels)
+        self.relu6_1 = nn.ReLU6(inplace=True)
+        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(channels)
+        self.relu6_2 = nn.ReLU6(inplace=True)
+        self.relu6_latest = nn.ReLU6(inplace=True)
+
+    def forward(self, x):
+        residual = self.conv1(x)
+        residual = self.bn1(residual)
+        residual = self.relu6_1(residual)
+
+        residual = self.conv2(residual)
+        residual = self.bn2(residual)
+        residual = self.relu6_2(residual)
+
+        add = x + residual
+        return self.relu6_latest(add)
+
+class M64ColorNet(nn.Module):
+
+    def __init__(self):
+        super(M64ColorNet, self).__init__()
+
+        self.block1 = nn.Sequential(
+            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, bias=False, stride=1),
+            nn.BatchNorm2d(16),
+            nn.ReLU6(inplace=True)
+        )
+        self.block2 = nn.Sequential(
+            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, bias=False, stride=2),
+            nn.BatchNorm2d(32),
+            nn.ReLU6(inplace=True)
+        )
+        self.block3 = nn.Sequential(
+            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1, bias=False, stride=2),
+            nn.BatchNorm2d(64),
+            nn.ReLU6(inplace=True)
+        )
+        self.block4 = ResidualBlock(64)
+        self.block5 = ResidualBlock(64)
+        self.block6 = ResidualBlock(64)
+        self.block7 = ResidualBlock(64)
+        self.block8 = ResidualBlock(64)
+        self.block9 = nn.Sequential(
+            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(64),
+            nn.ReLU6(inplace=True)
+        )
+        self.block10 = nn.Sequential(
+            # nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1, bias=False),
+            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, padding=1, bias=False, stride=2, output_padding=1),
+            nn.BatchNorm2d(32),
+            nn.ReLU6(inplace=True)
+        )
+        self.block11 = nn.Sequential(
+            # nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding=1, bias=False),
+            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, padding=1, bias=False, stride=2, output_padding=1),
+            nn.BatchNorm2d(16),
+            nn.ReLU6(inplace=True)
+        )
+        self.block12 = nn.Sequential(
+            nn.Conv2d(in_channels=16, out_channels=3, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(3),
+            nn.ReLU6(inplace=True)
+        )
+        # self.dropout = nn.Dropout(0.4)
+
+    def forward(self, x):
+        input = x
+
+        x = self.block1(x)
+        input2 = x
+
+        x = self.block2(x)
+        input3 = x
+
+        x = self.block3(x)
+        input4 = x
+
+        x = self.block4(x)
+        x = self.block5(x)
+        x = self.block6(x)
+        x = self.block7(x)
+        x = self.block8(x)
+        x = input4 + x
+        
+        x = self.block9(x)
+
+        x = self.block10(x)
+        x = input3 + x
+
+        x = self.block11(x)
+        x = input2 + x
+
+        x = self.block12(x)
+        x = input + x
+
+        x = torch.sigmoid(x)
+        return x
+
+    @staticmethod    
+    def load_trained_model(ckpt_path):
+        ckpt_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
+        model = M64ColorNet()
+        model.load_state_dict(ckpt_dict["model_state_dict"])
+        model.eval()
+        return model, ckpt_dict["mean"], ckpt_dict["std"], ckpt_dict["loss"], ckpt_dict["ssim_score"], ckpt_dict["psnr_score"]
+
+
+
+
+
+
+
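
A quick shape check of the network, which keeps the input resolution and returns a sigmoid-activated image:

    import torch
    from model import M64ColorNet

    model = M64ColorNet().eval()
    with torch.no_grad():
        y = model(torch.rand(1, 3, 256, 256))
    print(y.shape)  # torch.Size([1, 3, 256, 256]), values in [0, 1]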

+ 66 - 0
doc_clean_up/requirements.txt

@@ -0,0 +1,66 @@
+absl-py==1.0.0
+albumentations==1.1.0
+autopep8==1.6.0
+cachetools==5.1.0
+certifi==2021.10.8
+charset-normalizer==2.0.12
+cycler==0.11.0
+fonttools==4.33.3
+google-auth==2.6.6
+google-auth-oauthlib==0.4.6
+grpcio==1.46.3
+idna==3.3
+imagecorruptions==1.1.2
+imageio==2.19.2
+imgaug==0.4.0
+importlib-metadata==4.11.4
+joblib==1.1.0
+kiwisolver==1.4.2
+Markdown==3.3.7
+matplotlib==3.5.2
+networkx==2.8
+numpy==1.22.3
+oauthlib==3.2.0
+opencv-python==4.5.5.64
+opencv-python-headless==4.5.5.64
+packaging==21.3
+pandas==1.4.2
+Pillow==9.1.0
+protobuf==3.20.1
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pycodestyle==2.8.0
+pyparsing==3.0.9
+python-dateutil==2.8.2
+pytz==2022.1
+PyWavelets==1.3.0
+PyYAML==6.0
+qudida==0.0.4
+requests==2.27.1
+requests-oauthlib==1.3.1
+rsa==4.8
+scikit-image==0.19.2
+scikit-learn==1.1.0
+scipy==1.8.0
+Shapely==1.8.2
+six==1.16.0
+sklearn==0.0
+tensorboard==2.9.0
+tensorboard-data-server==0.6.1
+tensorboard-plugin-wit==1.8.1
+threadpoolctl==3.1.0
+tifffile==2022.5.4
+toml==0.10.2
+# torch==1.11.0
+torch-tb-profiler==0.4.0
+torchattacks==3.2.6
+# torchaudio==0.11.0
+torchinfo==1.6.6
+torchmetrics==0.9.0
+torchsummary==1.5.1
+# torchvision==0.12.0
+tqdm==4.64.0
+typing_extensions==4.2.0
+urllib3==1.26.9
+Werkzeug==2.1.2
+zipp==3.8.0

+ 22 - 0
doc_clean_up/tflite_infer.py

@@ -0,0 +1,22 @@
+import numpy as np
+import tensorflow as tf
+
+# Load the TFLite model and allocate tensors
+interpreter = tf.lite.Interpreter(model_path="torch_script_model/doc_clean.tflite")
+interpreter.allocate_tensors()
+
+# Get input and output tensors
+input_details = interpreter.get_input_details()
+output_details = interpreter.get_output_details()
+
+# Test the model on random input data
+input_shape = input_details[0]['shape']
+input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
+interpreter.set_tensor(input_details[0]['index'], input_data)
+
+interpreter.invoke()
+
+# get_tensor() returns a copy of the tensor data
+# use tensor() in order to get a pointer to the tensor
+output_data = interpreter.get_tensor(output_details[0]['index'])
+print(output_data)

+ 295 - 0
doc_clean_up/train.py

@@ -0,0 +1,295 @@
+from infer import infer_test
+from torchmetrics.functional import peak_signal_noise_ratio, structural_similarity_index_measure
+from torchvision import transforms
+from torch.utils.tensorboard import SummaryWriter
+import argparse
+import torchvision.transforms as T
+import shutil
+import os
+from matplotlib import pyplot as plt
+from model import M64ColorNet
+from loss import DocCleanLoss
+from torch.utils.data import DataLoader
+from dataset import DocCleanDataset
+import torch
+from tqdm import tqdm
+from nni.compression.pytorch.pruning import L1NormPruner
+from nni.compression.pytorch.speedup import ModelSpeedup
+import matplotlib
+matplotlib.use('Agg')
+# from torchinfo import summary
+writer = SummaryWriter()
+
+
+def boolean_string(s):
+    ''' Check s string is true or false.
+    Args:
+        s: the string
+    Returns:
+        boolean
+    '''
+    s = s.lower()
+    if s not in {'false', 'true'}:
+        raise ValueError('Not a valid boolean string')
+    return s == 'true'
+
+
+# path parameters
+parser = argparse.ArgumentParser()
+parser.add_argument('--develop',
+                    type=boolean_string,
+                    help='Develop mode turn off by default',
+                    default=False)
+parser.add_argument('--lr',
+                    type=float,
+                    help='Initial learning rate',
+                    default=1e-3)
+parser.add_argument('--batch_size',
+                    type=int,
+                    help='Training batch size',
+                    default=16)
+parser.add_argument('--retrain',
+                    type=boolean_string,
+                    help='Whether to restore the checkpoint',
+                    default=False)
+parser.add_argument('--epochs',
+                    type=int,
+                    help='Max training epoch',
+                    default=500)
+parser.add_argument('--dataset',
+                    type=str,
+                    help='Path to the generated training-block dataset',
+                    default="dataset/raw_data/imgs_Trainblocks")
+parser.add_argument('--shuffle',
+                    type=boolean_string,
+                    help='Whether to shuffle dataset',
+                    default=True)
+
+
+def saveEvalImg(img_dir: str, batch_idx: int, imgs, pred_imgs, gt_imgs, normalized_imgs):
+    transform = T.ToPILImage()
+    for idx, (img, normalized_img, pred_img, gt_img) in enumerate(zip(imgs, normalized_imgs, pred_imgs, gt_imgs)):
+        img = transform(img)
+        normalized_img = transform(normalized_img)
+        pred_img = transform(pred_img)
+        gt_img = transform(gt_img)
+        f, axarr = plt.subplots(1, 4)
+        axarr[0].imshow(img)
+        axarr[0].title.set_text('orig')
+        axarr[1].imshow(normalized_img)
+        axarr[1].title.set_text('normal')
+        axarr[2].imshow(pred_img)
+        axarr[2].title.set_text('pred')
+        axarr[3].imshow(gt_img)
+        axarr[3].title.set_text('gt')
+        f.savefig(f"{img_dir}/{batch_idx:04d}_{idx}.jpg")
+        plt.close()
+
+
+def evaluator(model:torch.nn.Module, epoch:int, test_loader:DataLoader, tag:str):
+    img_dir = f"{output}/{tag}/{epoch}"
+    if os.path.exists(img_dir):
+        shutil.rmtree(img_dir, ignore_errors=True)
+    os.makedirs(img_dir)
+    valid_loss = 0
+    model.eval()
+    eval_criterion = DocCleanLoss(device)
+    with torch.no_grad():
+        ssim_score = 0
+        psnr_score = 0
+        for index, (imgs, normalized_imgs, gt_imgs) in enumerate(tqdm(test_loader)):
+            imgs = imgs.to(device)
+            gt_imgs = gt_imgs.to(device)
+            normalized_imgs = normalized_imgs.to(device)
+            pred_imgs = model(normalized_imgs)
+            ssim_score += structural_similarity_index_measure(
+                pred_imgs, gt_imgs).item()
+            psnr_score += peak_signal_noise_ratio(pred_imgs, gt_imgs).item()
+            loss, _, _, _ = eval_criterion(pred_imgs, gt_imgs)
+            valid_loss += loss.item()
+            if index % 30 == 0:
+                saveEvalImg(img_dir=img_dir, batch_idx=index, imgs=imgs,
+                            pred_imgs=pred_imgs, gt_imgs=gt_imgs, normalized_imgs=normalized_imgs)
+        data_len = len(test_loader)
+        valid_loss = valid_loss / data_len
+        psnr_score = psnr_score / data_len
+        ssim_score = ssim_score / data_len
+    return valid_loss, psnr_score, ssim_score
+
+
+def batch_mean_std(loader):
+    nb_samples = 0.
+    channel_mean = torch.zeros(3)
+    channel_std = torch.zeros(3)
+    for images, _, _ in tqdm(loader):
+        # scale image to be between 0 and 1
+        N, C, H, W = images.shape[:4]
+        data = images.view(N, C, -1)
+
+        channel_mean += data.mean(2).sum(0)
+        channel_std += data.std(2).sum(0)
+        nb_samples += N
+
+    channel_mean /= nb_samples
+    channel_std /= nb_samples
+    return channel_mean, channel_std
+
+
+def saveCkpt(model, model_path, epoch, optimizer, scheduler, validation_loss, mean, std, psnr_score, ssim_score):
+    torch.save({
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'scheduler_state_dict': scheduler.state_dict(),
+        'loss': validation_loss,
+        'mean': mean,
+        'std': std,
+        'psnr_score': psnr_score,
+        'ssim_score': ssim_score
+    }, model_path)
+
+def trainer(model:torch.nn.Module, criterion:DocCleanLoss, optimizer:torch.optim.Adam, tag:str, epoch:int):
+    # One training epoch: forward pass, DocCleanLoss, backward pass, optimizer step; sample panels are dumped every 200 batches.
+    model.train()
+    running_loss = 0
+    running_content_loss = 0
+    running_style_loss = 0
+    running_pixel_loss = 0
+    img_dir = f"{output}/{tag}/{epoch}"
+    if os.path.exists(img_dir):
+        shutil.rmtree(img_dir, ignore_errors=True)
+    os.makedirs(img_dir)
+    for index, (imgs, normalized_imgs, gt_imgs) in enumerate(tqdm(train_loader)):
+        optimizer.zero_grad()
+        imgs = imgs.to(device)
+        gt_imgs = gt_imgs.to(device)
+        normalized_imgs = normalized_imgs.to(device)
+        pred_imgs = model(normalized_imgs)
+        loss, p_l_loss, content_loss, style_loss = criterion(
+            pred_imgs, gt_imgs)
+        loss.backward()
+        optimizer.step()
+        running_loss += loss.item()
+        running_pixel_loss += p_l_loss.item()
+        running_content_loss += content_loss.item()
+        running_style_loss += style_loss.item()
+        if index % 200 == 0:
+            saveEvalImg(img_dir=img_dir, batch_idx=index, imgs=imgs,
+                        pred_imgs=pred_imgs, gt_imgs=gt_imgs, normalized_imgs=normalized_imgs)
+    return running_loss, running_pixel_loss, running_content_loss, running_style_loss
+
+def model_pruning():
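+    # Prune Conv2d layers to 80% sparsity with nni's L1NormPruner, rebuild the graph via ModelSpeedup, then fine-tune for a few epochs.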
+    model, mean, std = M64ColorNet.load_trained_model("output/model.pt")
+    model.to(device)
+    # Compress this model.
+    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
+    pruner = L1NormPruner(model, config_list)
+    _, masks = pruner.compress()
+
+    print('\nThe accuracy with masks:')
+    evaluator(model, 0, test_loader, "masks")
+
+    pruner._unwrap_model()
+    ModelSpeedup(model, dummy_input=torch.rand(1, 3, 256, 256).to(device), masks_file=masks).speedup_model()
+
+    print('\nThe accuracy after speedup:')
+    evaluator(model, 0, test_loader, "speedup")
+
+    # A new optimizer is needed because the model's modules are replaced during speedup.
+    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
+    criterion = DocCleanLoss(device=device)
+    print('\nFinetune the model after speedup:')
+    for i in range(5):
+        trainer(model, criterion, optimizer, "train_finetune", i)
+        evaluator(model, i, test_loader, "eval_finetune")
+
+def pretrain():
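+    # Full training loop: resume from output/model.pt if it exists, train and validate each epoch, and checkpoint whenever validation loss improves.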
+    print(f"device={device} \
+            develop={args.develop} \
+            lr={args.lr} \
+            mean={mean} \
+            std={std} \
+            shuffle={args.shuffle}")
+    model_cls = M64ColorNet
+    model = model_cls()
+    model.to(device)
+    # summary(model, input_size=(batch_size, 3, 256, 256))
+    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
+    scheduler = torch.optim.lr_scheduler.StepLR(
+        optimizer, step_size=15, gamma=0.8)
+
+    model_path = f"{output}/model.pt"
+    current_epoch = 1
+    previous_loss = float('inf')
+    criterion = DocCleanLoss(device)
+    if os.path.exists(model_path):
+        checkpoint = torch.load(model_path)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+        current_epoch = checkpoint['epoch'] + 1
+        previous_loss = checkpoint['loss']
+    for epoch in range(current_epoch, current_epoch+args.epochs):
+        running_loss, running_pixel_loss, running_content_loss, running_style_loss = trainer(model, criterion, optimizer, "train", epoch)
+        train_loss = running_loss / len(train_loader)
+        train_content_loss = running_content_loss / len(train_loader)
+        train_style_loss = running_style_loss / len(train_loader)
+        train_pixel_loss = running_pixel_loss / len(train_loader)
+        # evaluate
+        validation_loss, psnr_score, ssim_score = evaluator(model, epoch, test_loader, "eval")
+        writer.add_scalar("Loss/train", train_loss, epoch)
+        writer.add_scalar("Loss/validation", validation_loss, epoch)
+        writer.add_scalar("metric/psnr", psnr_score, epoch)
+        writer.add_scalar("metric/ssim", ssim_score, epoch)
+        if previous_loss > validation_loss:
+            # model_path always holds the latest best checkpoint, used to resume training.
+            saveCkpt(model, model_path, epoch, optimizer, scheduler, validation_loss, mean, std, psnr_score, ssim_score)
+            # Also keep a separate checkpoint for this epoch.
+            saveCkpt(model, f"{output}/model_{epoch}.pt", epoch, optimizer, scheduler, validation_loss, mean, std, psnr_score, ssim_score)
+            infer_test(f"{output}/infer_test/{epoch}",
+                       "infer_imgs", model_path, model_cls)
+            previous_loss = validation_loss
+        scheduler.step()
+        print(
+            f"epoch:{epoch} \
+            train_loss:{round(train_loss, 4)} \
+            validation_loss:{round(validation_loss, 4)} \
+            pixel_loss:{round(train_pixel_loss, 4)} \
+            content_loss:{round(train_content_loss, 8)} \
+            style_loss:{round(train_style_loss, 4)} \
+            lr:{round(optimizer.param_groups[0]['lr'], 5)} \
+            psnr:{round(psnr_score, 3)} \
+            ssim:{round(ssim_score, 3)}"
+        )
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    train_img_names, eval_img_names, imgs_dir = DocCleanDataset.prepareDataset(args.dataset, args.shuffle)
+    output = "output"
+    if args.retrain:
+        shutil.rmtree(output, ignore_errors=True)
+    if not os.path.exists(output):
+        os.mkdir(output)
+    print(
+        f"trainset num:{len(train_img_names)}\nevalset num:{len(eval_img_names)}")
+
+    dataset = DocCleanDataset(
+        img_names=train_img_names, imgs_dir=imgs_dir, dev=args.develop)
+    mean, std = batch_mean_std(DataLoader(
+        dataset=dataset, batch_size=args.batch_size))
+    # mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+    # transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
+    train_set = DocCleanDataset(
+        img_names=train_img_names, imgs_dir=imgs_dir, normalized_tuple=(mean, std), dev=args.develop, img_aug=True)
+    test_set = DocCleanDataset(
+        img_names=eval_img_names, imgs_dir=imgs_dir, normalized_tuple=(mean, std), dev=args.develop)
+    train_loader = DataLoader(
+        dataset=train_set, batch_size=args.batch_size, shuffle=args.shuffle)
+    test_loader = DataLoader(dataset=test_set, batch_size=args.batch_size)
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+    pretrain()
+
+    # model_pruning()
+
+    writer.flush()
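
Not part of the commit above, but for reference: a minimal inference sketch showing how a checkpoint written by saveCkpt() could be restored via M64ColorNet.load_trained_model() (the same call model_pruning() uses) and applied to a new image. The input file name is purely illustrative, and the M64ColorNet import path lies outside this excerpt.

import torch
import torchvision.transforms as T
from PIL import Image

# M64ColorNet comes from this project; its import path is not shown in this diff.
model, mean, std = M64ColorNet.load_trained_model("output/model.pt")
model.eval()

preprocess = T.Compose([
    T.Resize((256, 256)),              # input size used by the dummy input in model_pruning()
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),   # normalization stats stored in the checkpoint
])

img = preprocess(Image.open("page.jpg").convert("RGB")).unsqueeze(0)  # illustrative file name
with torch.no_grad():
    cleaned = model(img)               # same output space as pred_imgs in evaluator()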

+ 69 - 0
doc_clean_up/vgg19.py

@@ -0,0 +1,69 @@
+import torchvision.models as models
+from torchsummary import summary
+import torch
+
+class VGG19(torch.nn.Module):
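+    # Frozen, ImageNet-pretrained VGG-19 sliced at relu1_1..relu5_1.
+    # forward() returns content features (relu1_2) and per-layer style features for both the prediction and the ground truth.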
+    def __init__(self):
+        super(VGG19, self).__init__()
+        vgg_net = models.vgg19(pretrained=True)
+        # summary(vgg_net, (3, 224, 224))
+        features = vgg_net.features
+        self.relu_1_1 = torch.nn.Sequential()
+        self.relu_1_2 = torch.nn.Sequential()
+        self.relu_2_1 = torch.nn.Sequential()
+        self.relu_3_1 = torch.nn.Sequential()
+        self.relu_4_1 = torch.nn.Sequential()
+        self.relu_5_1 = torch.nn.Sequential()
+
+        for x in range(0, 2):
+            self.relu_1_1.add_module(str(x), features[x])
+        for x in range(2, 4):
+            self.relu_1_2.add_module(str(x), features[x])
+        for x in range(4, 7):
+            self.relu_2_1.add_module(str(x), features[x])
+        for x in range(7, 12):
+            self.relu_3_1.add_module(str(x), features[x])
+        for x in range(12, 21):
+            self.relu_4_1.add_module(str(x), features[x])
+        for x in range(21, 30):
+            self.relu_5_1.add_module(str(x), features[x])
+        
+        # don't need the gradients, just want the features
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, pred, gt):
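+        # Push pred and gt through successive slices, keeping relu1_2 activations as content features
+        # and relu1_1/relu2_1/relu3_1/relu4_1/relu5_1 activations as style features.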
+        h_pred = self.relu_1_1(pred)
+        h_gt = self.relu_1_1(gt)
+        style_pred_1 = h_pred
+        style_gt_1 = h_gt
+
+        h_pred = self.relu_1_2(h_pred)        
+        h_gt = self.relu_1_2(h_gt)
+        content_pred = h_pred
+        content_gt = h_gt
+
+        h_pred = self.relu_2_1(h_pred)        
+        h_gt = self.relu_2_1(h_gt)
+        style_pred_2 = h_pred
+        style_gt_2 = h_gt
+
+        h_pred = self.relu_3_1(h_pred)        
+        h_gt = self.relu_3_1(h_gt)
+        style_pred_3 = h_pred
+        style_gt_3 = h_gt
+
+        h_pred = self.relu_4_1(h_pred)        
+        h_gt = self.relu_4_1(h_gt)
+        style_pred_4 = h_pred
+        style_gt_4 = h_gt
+
+        h_pred = self.relu_5_1(h_pred)        
+        h_gt = self.relu_5_1(h_gt)
+        style_pred_5 = h_pred
+        style_gt_5 = h_gt
+
+        contents = (content_pred, content_gt)
+        style_gt_list = [style_gt_1, style_gt_2, style_gt_3, style_gt_4, style_gt_5]
+        style_pred_list = [style_pred_1, style_pred_2, style_pred_3, style_pred_4, style_pred_5]
+        return contents, style_pred_list, style_gt_list
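
DocCleanLoss itself is outside this excerpt; as a hedged sketch only, a common way such perceptual objectives consume VGG19.forward()'s outputs is an L1 content term plus a Gram-matrix style term over the five style feature pairs. The distance and weighting below are assumptions, not this project's definitive implementation.

import torch
import torch.nn.functional as F

def gram_matrix(feat: torch.Tensor) -> torch.Tensor:
    # Normalized Gram matrix of an (N, C, H, W) feature map.
    n, c, h, w = feat.shape
    f = feat.view(n, c, h * w)
    return f @ f.transpose(1, 2) / (c * h * w)

def perceptual_terms(vgg: VGG19, pred: torch.Tensor, gt: torch.Tensor):
    # Content loss at relu1_2, style loss summed over relu1_1..relu5_1 Gram matrices.
    (content_pred, content_gt), style_pred, style_gt = vgg(pred, gt)
    content_loss = F.l1_loss(content_pred, content_gt)
    style_loss = sum(F.l1_loss(gram_matrix(p), gram_matrix(g))
                     for p, g in zip(style_pred, style_gt))
    return content_loss, style_loss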

Binary
document/交接文档.docx