warp_mls.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/warp_mls.py
  17. """
  18. import numpy as np
  19. class WarpMLS:
  20. def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.):
  21. self.src = src
  22. self.src_pts = src_pts
  23. self.dst_pts = dst_pts
  24. self.pt_count = len(self.dst_pts)
  25. self.dst_w = dst_w
  26. self.dst_h = dst_h
  27. self.trans_ratio = trans_ratio
  28. self.grid_size = 100
  29. self.rdx = np.zeros((self.dst_h, self.dst_w))
  30. self.rdy = np.zeros((self.dst_h, self.dst_w))
  31. @staticmethod
  32. def __bilinear_interp(x, y, v11, v12, v21, v22):
  33. return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
  34. (1 - y) + v22 * y) * x
  35. def generate(self):
  36. self.calc_delta()
  37. return self.gen_img()
  38. def calc_delta(self):
  39. w = np.zeros(self.pt_count, dtype=np.float32)
  40. if self.pt_count < 2:
  41. return
  42. i = 0
  43. while 1:
  44. if self.dst_w <= i < self.dst_w + self.grid_size - 1:
  45. i = self.dst_w - 1
  46. elif i >= self.dst_w:
  47. break
  48. j = 0
  49. while 1:
  50. if self.dst_h <= j < self.dst_h + self.grid_size - 1:
  51. j = self.dst_h - 1
  52. elif j >= self.dst_h:
  53. break
  54. sw = 0
  55. swp = np.zeros(2, dtype=np.float32)
  56. swq = np.zeros(2, dtype=np.float32)
  57. new_pt = np.zeros(2, dtype=np.float32)
  58. cur_pt = np.array([i, j], dtype=np.float32)
  59. k = 0
  60. for k in range(self.pt_count):
  61. if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
  62. break
  63. w[k] = 1. / (
  64. (i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) +
  65. (j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1]))
  66. sw += w[k]
  67. swp = swp + w[k] * np.array(self.dst_pts[k])
  68. swq = swq + w[k] * np.array(self.src_pts[k])
  69. if k == self.pt_count - 1:
  70. pstar = 1 / sw * swp
  71. qstar = 1 / sw * swq
  72. miu_s = 0
  73. for k in range(self.pt_count):
  74. if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
  75. continue
  76. pt_i = self.dst_pts[k] - pstar
  77. miu_s += w[k] * np.sum(pt_i * pt_i)
  78. cur_pt -= pstar
  79. cur_pt_j = np.array([-cur_pt[1], cur_pt[0]])
  80. for k in range(self.pt_count):
  81. if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
  82. continue
  83. pt_i = self.dst_pts[k] - pstar
  84. pt_j = np.array([-pt_i[1], pt_i[0]])
  85. tmp_pt = np.zeros(2, dtype=np.float32)
  86. tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \
  87. np.sum(pt_j * cur_pt) * self.src_pts[k][1]
  88. tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \
  89. np.sum(pt_j * cur_pt_j) * self.src_pts[k][1]
  90. tmp_pt *= (w[k] / miu_s)
  91. new_pt += tmp_pt
  92. new_pt += qstar
  93. else:
  94. new_pt = self.src_pts[k]
  95. self.rdx[j, i] = new_pt[0] - i
  96. self.rdy[j, i] = new_pt[1] - j
  97. j += self.grid_size
  98. i += self.grid_size
  99. def gen_img(self):
  100. src_h, src_w = self.src.shape[:2]
  101. dst = np.zeros_like(self.src, dtype=np.float32)
  102. for i in np.arange(0, self.dst_h, self.grid_size):
  103. for j in np.arange(0, self.dst_w, self.grid_size):
  104. ni = i + self.grid_size
  105. nj = j + self.grid_size
  106. w = h = self.grid_size
  107. if ni >= self.dst_h:
  108. ni = self.dst_h - 1
  109. h = ni - i + 1
  110. if nj >= self.dst_w:
  111. nj = self.dst_w - 1
  112. w = nj - j + 1
  113. di = np.reshape(np.arange(h), (-1, 1))
  114. dj = np.reshape(np.arange(w), (1, -1))
  115. delta_x = self.__bilinear_interp(
  116. di / h, dj / w, self.rdx[i, j], self.rdx[i, nj],
  117. self.rdx[ni, j], self.rdx[ni, nj])
  118. delta_y = self.__bilinear_interp(
  119. di / h, dj / w, self.rdy[i, j], self.rdy[i, nj],
  120. self.rdy[ni, j], self.rdy[ni, nj])
  121. nx = j + dj + delta_x * self.trans_ratio
  122. ny = i + di + delta_y * self.trans_ratio
  123. nx = np.clip(nx, 0, src_w - 1)
  124. ny = np.clip(ny, 0, src_h - 1)
  125. nxi = np.array(np.floor(nx), dtype=np.int32)
  126. nyi = np.array(np.floor(ny), dtype=np.int32)
  127. nxi1 = np.array(np.ceil(nx), dtype=np.int32)
  128. nyi1 = np.array(np.ceil(ny), dtype=np.int32)
  129. if len(self.src.shape) == 3:
  130. x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3))
  131. y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3))
  132. else:
  133. x = ny - nyi
  134. y = nx - nxi
  135. dst[i:i + h, j:j + w] = self.__bilinear_interp(
  136. x, y, self.src[nyi, nxi], self.src[nyi, nxi1],
  137. self.src[nyi1, nxi], self.src[nyi1, nxi1])
  138. dst = np.clip(dst, 0, 255)
  139. dst = np.array(dst, dtype=np.uint8)
  140. return dst