# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# function:
#     operators to process samples,
#     e.g. decode/resize/crop image

from __future__ import absolute_import

try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence

import cv2
import numpy as np
import math
import copy

from ...modeling.keypoint_utils import get_affine_mat_kernel, warp_affine_joints, get_affine_transform, affine_transform, get_warp_matrix
from ppdet.core.workspace import serializable
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)

registered_ops = []

__all__ = [
    'RandomAffine', 'KeyPointFlip', 'TagGenerate', 'ToHeatmaps',
    'NormalizePermute', 'EvalAffine', 'RandomFlipHalfBodyTransform',
    'TopDownAffine', 'ToHeatmapsTopDown', 'ToHeatmapsTopDown_DARK',
    'ToHeatmapsTopDown_UDP', 'TopDownEvalAffine',
    'AugmentationbyInformantionDropping', 'SinglePoseAffine', 'NoiseJitter',
    'FlipPose'
]


def register_keypointop(cls):
    return serializable(cls)
@register_keypointop
class KeyPointFlip(object):
    """Flip the image horizontally with probability flip_prob, and flip the coords with it.
    The left and right coords must be exchanged on flip, since a right keypoint
    becomes a left keypoint in the flipped image.

    Args:
        flip_permutation (list[17]): the left-right exchange order list corresponding to [0,1,2,...,16]
        hmsize (list[2]): output heatmap's shape list of the different scale outputs of HigherHRNet
        flip_prob (float): the probability of flipping the image
        records (dict): the dict containing the image, mask and coords

    Returns:
        records (dict): contains the image, mask and coords after the transform
    """

    def __init__(self, flip_permutation, hmsize, flip_prob=0.5):
        super(KeyPointFlip, self).__init__()
        assert isinstance(flip_permutation, Sequence)
        self.flip_permutation = flip_permutation
        self.flip_prob = flip_prob
        self.hmsize = hmsize

    def __call__(self, records):
        image = records['image']
        kpts_lst = records['joints']
        mask_lst = records['mask']
        flip = np.random.random() < self.flip_prob
        if flip:
            image = image[:, ::-1]
            for idx, hmsize in enumerate(self.hmsize):
                if len(mask_lst) > idx:
                    mask_lst[idx] = mask_lst[idx][:, ::-1]
                if kpts_lst[idx].ndim == 3:
                    kpts_lst[idx] = kpts_lst[idx][:, self.flip_permutation]
                else:
                    kpts_lst[idx] = kpts_lst[idx][self.flip_permutation]
                kpts_lst[idx][..., 0] = hmsize - kpts_lst[idx][..., 0]
                kpts_lst[idx] = kpts_lst[idx].astype(np.int64)
                # zero out the visibility of keypoints pushed outside the heatmap
                kpts_lst[idx][kpts_lst[idx][..., 0] >= hmsize, 2] = 0
                kpts_lst[idx][kpts_lst[idx][..., 1] >= hmsize, 2] = 0
                kpts_lst[idx][kpts_lst[idx][..., 0] < 0, 2] = 0
                kpts_lst[idx][kpts_lst[idx][..., 1] < 0, 2] = 0
        records['image'] = image
        records['joints'] = kpts_lst
        records['mask'] = mask_lst
        return records
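
# Usage sketch (illustrative only; the 17-point COCO left-right permutation
# below is an assumption taken from the usual COCO convention, not defined in
# this file):
#
#     flip_op = KeyPointFlip(
#         flip_permutation=[0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13,
#                           16, 15],
#         hmsize=[128, 256],
#         flip_prob=0.5)
#     records = flip_op({'image': image, 'joints': joints_list, 'mask': mask_list})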
@register_keypointop
class RandomAffine(object):
    """Apply an affine transform to image, mask and coords
    to achieve a rotate, scale and shift effect on the training image.

    Args:
        max_degree (float): the max absolute rotation degree to apply, transform range is [-max_degree, max_degree]
        scale (list[2]): the scale range to apply, transform range is [min, max]
        max_shift (float): the max absolute shift ratio to apply, transform range is [-max_shift*imagesize, max_shift*imagesize]
        hmsize (list[2]): output heatmap's shape list of the different scale outputs of HigherHRNet
        trainsize (int): the standard length used for training; the 'scale_type' side of [h, w] is resized to trainsize
        scale_type (str): which side of [h, w] to use for trainsize, chosen between 'short' and 'long'
        records (dict): the dict containing the image, mask and coords

    Returns:
        records (dict): contains the image, mask and coords after the transform
    """

    def __init__(self,
                 max_degree=30,
                 scale=[0.75, 1.5],
                 max_shift=0.2,
                 hmsize=[128, 256],
                 trainsize=512,
                 scale_type='short'):
        super(RandomAffine, self).__init__()
        self.max_degree = max_degree
        self.min_scale = scale[0]
        self.max_scale = scale[1]
        self.max_shift = max_shift
        self.hmsize = hmsize
        self.trainsize = trainsize
        self.scale_type = scale_type

    def _get_affine_matrix(self, center, scale, res, rot=0):
        """Generate transformation matrix."""
        h = scale
        t = np.zeros((3, 3), dtype=np.float32)
        t[0, 0] = float(res[1]) / h
        t[1, 1] = float(res[0]) / h
        t[0, 2] = res[1] * (-float(center[0]) / h + .5)
        t[1, 2] = res[0] * (-float(center[1]) / h + .5)
        t[2, 2] = 1
        if rot != 0:
            rot = -rot  # To match direction of rotation from cropping
            rot_mat = np.zeros((3, 3), dtype=np.float32)
            rot_rad = rot * np.pi / 180
            sn, cs = np.sin(rot_rad), np.cos(rot_rad)
            rot_mat[0, :2] = [cs, -sn]
            rot_mat[1, :2] = [sn, cs]
            rot_mat[2, 2] = 1
            # Need to rotate around center
            t_mat = np.eye(3)
            t_mat[0, 2] = -res[1] / 2
            t_mat[1, 2] = -res[0] / 2
            t_inv = t_mat.copy()
            t_inv[:2, 2] *= -1
            t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
        return t

    def __call__(self, records):
        image = records['image']
        keypoints = records['joints']
        heatmap_mask = records['mask']

        degree = (np.random.random() * 2 - 1) * self.max_degree
        shape = np.array(image.shape[:2][::-1])
        center = np.array(shape) / 2
        aug_scale = np.random.random() * (self.max_scale - self.min_scale
                                          ) + self.min_scale
        if self.scale_type == 'long':
            scale = max(shape[0], shape[1]) / 1.0
        elif self.scale_type == 'short':
            scale = min(shape[0], shape[1]) / 1.0
        else:
            raise ValueError('Unknown scale type: {}'.format(self.scale_type))
        roi_size = aug_scale * scale
        dx = int(0)
        dy = int(0)
        if self.max_shift > 0:
            dx = np.random.randint(-self.max_shift * roi_size,
                                   self.max_shift * roi_size)
            dy = np.random.randint(-self.max_shift * roi_size,
                                   self.max_shift * roi_size)
        center += np.array([dx, dy])
        input_size = 2 * center

        keypoints[..., :2] *= shape
        heatmap_mask *= 255
        kpts_lst = []
        mask_lst = []

        image_affine_mat = self._get_affine_matrix(
            center, roi_size, (self.trainsize, self.trainsize), degree)[:2]
        image = cv2.warpAffine(
            image,
            image_affine_mat, (self.trainsize, self.trainsize),
            flags=cv2.INTER_LINEAR)
        for hmsize in self.hmsize:
            kpts = copy.deepcopy(keypoints)
            mask_affine_mat = self._get_affine_matrix(
                center, roi_size, (hmsize, hmsize), degree)[:2]
            if heatmap_mask is not None:
                mask = cv2.warpAffine(heatmap_mask, mask_affine_mat,
                                      (hmsize, hmsize))
                mask = ((mask / 255) > 0.5).astype(np.float32)
            kpts[..., 0:2] = warp_affine_joints(kpts[..., 0:2].copy(),
                                                mask_affine_mat)
            # zero out the visibility of keypoints mapped outside the heatmap
            kpts[np.trunc(kpts[..., 0]) >= hmsize, 2] = 0
            kpts[np.trunc(kpts[..., 1]) >= hmsize, 2] = 0
            kpts[np.trunc(kpts[..., 0]) < 0, 2] = 0
            kpts[np.trunc(kpts[..., 1]) < 0, 2] = 0
            kpts_lst.append(kpts)
            mask_lst.append(mask)
        records['image'] = image
        records['joints'] = kpts_lst
        records['mask'] = mask_lst
        return records
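
# Sanity sketch (illustrative, not executed at import): the matrix returned by
# _get_affine_matrix maps the crop center onto the center of the output, e.g.
#
#     op = RandomAffine()
#     t = op._get_affine_matrix(center=(320, 240), scale=480, res=(512, 512))
#     # np.dot(t, [320, 240, 1]) -> approximately [256, 256, 1]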
@register_keypointop
class EvalAffine(object):
    """Apply an affine transform to the image,
    resizing the short side of [h, w] to the standard size for eval.

    Args:
        size (int): the standard length; the short side of [h, w] is resized to it
        records (dict): the dict containing the image, mask and coords

    Returns:
        records (dict): contains the image, mask and coords after the transform
    """

    def __init__(self, size, stride=64):
        super(EvalAffine, self).__init__()
        self.size = size
        self.stride = stride

    def __call__(self, records):
        image = records['image']
        mask = records['mask'] if 'mask' in records else None
        s = self.size
        h, w, _ = image.shape
        trans, size_resized = get_affine_mat_kernel(h, w, s, inv=False)
        image_resized = cv2.warpAffine(image, trans, size_resized)
        if mask is not None:
            mask = cv2.warpAffine(mask, trans, size_resized)
            records['mask'] = mask
        if 'joints' in records:
            del records['joints']
        records['image'] = image_resized
        return records
@register_keypointop
class NormalizePermute(object):
    """Normalize the image by mean/std and permute it from HWC to CHW layout."""

    def __init__(self,
                 mean=[123.675, 116.28, 103.53],
                 std=[58.395, 57.120, 57.375],
                 is_scale=True):
        super(NormalizePermute, self).__init__()
        self.mean = mean
        self.std = std
        self.is_scale = is_scale

    def __call__(self, records):
        image = records['image']
        image = image.astype(np.float32)
        if self.is_scale:
            image /= 255.
        image = image.transpose((2, 0, 1))
        mean = np.array(self.mean, dtype=np.float32)
        std = np.array(self.std, dtype=np.float32)
        invstd = 1. / std
        # in-place (v - m) * s per channel; __isub__/__imul__ avoid temporaries
        for v, m, s in zip(image, mean, invstd):
            v.__isub__(m).__imul__(s)
        records['image'] = image
        return records
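
# Equivalence sketch (illustrative; assumes an HWC uint8 image named img and
# 0-1 range statistics, which are assumptions, not this class's defaults):
#
#     op = NormalizePermute(mean=[0.485, 0.456, 0.406],
#                           std=[0.229, 0.224, 0.225], is_scale=True)
#     out = op({'image': img})['image']
#     # out[c] == (img.astype(np.float32)[..., c] / 255. - mean[c]) / std[c]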
@register_keypointop
class TagGenerate(object):
    """Record the gt coords used by the AE loss to sample tag values from the tagmaps.

    Args:
        num_joints (int): the number of keypoints in the training dataset
        max_people (int): the maximum number of people supported when sampling the AE loss
        records (dict): the dict containing the image, mask and coords

    Returns:
        records (dict): contains the gt coords used in the tagmap
    """

    def __init__(self, num_joints, max_people=30):
        super(TagGenerate, self).__init__()
        self.max_people = max_people
        self.num_joints = num_joints

    def __call__(self, records):
        kpts_lst = records['joints']
        kpts = kpts_lst[0]
        tagmap = np.zeros((self.max_people, self.num_joints, 4), dtype=np.int64)
        inds = np.where(kpts[..., 2] > 0)
        p, j = inds[0], inds[1]
        visible = kpts[inds]
        # tagmap is [p, j, 4], where the last dim is (j, y, x, valid)
        tagmap[p, j, 0] = j
        tagmap[p, j, 1] = visible[..., 1]  # y
        tagmap[p, j, 2] = visible[..., 0]  # x
        tagmap[p, j, 3] = 1
        records['tagmap'] = tagmap
        del records['joints']
        return records
@register_keypointop
class ToHeatmaps(object):
    """Generate the gaussian heatmaps of the keypoints for the heatmap loss.

    Args:
        num_joints (int): the number of keypoints in the training dataset
        hmsize (list[2]): output heatmap's shape list of the different scale outputs of HigherHRNet
        sigma (float): the std of the generated gaussian kernel
        records (dict): the dict containing the image, mask and coords

    Returns:
        records (dict): contains the heatmaps used by the heatmap loss
    """

    def __init__(self, num_joints, hmsize, sigma=None):
        super(ToHeatmaps, self).__init__()
        self.num_joints = num_joints
        self.hmsize = np.array(hmsize)
        if sigma is None:
            sigma = hmsize[0] // 64
        self.sigma = sigma
        # precompute a (6*sigma+3)-wide gaussian kernel centered at 3*sigma+1
        r = 6 * sigma + 3
        x = np.arange(0, r, 1, np.float32)
        y = x[:, None]
        x0, y0 = 3 * sigma + 1, 3 * sigma + 1
        self.gaussian = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))

    def __call__(self, records):
        kpts_lst = records['joints']
        mask_lst = records['mask']
        for idx, hmsize in enumerate(self.hmsize):
            mask = mask_lst[idx]
            kpts = kpts_lst[idx]
            heatmaps = np.zeros((self.num_joints, hmsize, hmsize))
            inds = np.where(kpts[..., 2] > 0)
            visible = kpts[inds].astype(np.int64)[..., :2]
            ul = np.round(visible - 3 * self.sigma - 1)
            br = np.round(visible + 3 * self.sigma + 2)
            sul = np.maximum(0, -ul)
            sbr = np.minimum(hmsize, br) - ul
            dul = np.clip(ul, 0, hmsize - 1)
            dbr = np.clip(br, 0, hmsize)
            for i in range(len(visible)):
                if visible[i][0] < 0 or visible[i][1] < 0 or visible[i][
                        0] >= hmsize or visible[i][1] >= hmsize:
                    continue
                dx1, dy1 = dul[i]
                dx2, dy2 = dbr[i]
                sx1, sy1 = sul[i]
                sx2, sy2 = sbr[i]
                # paste the in-bounds crop of the precomputed kernel
                heatmaps[inds[1][i], dy1:dy2, dx1:dx2] = np.maximum(
                    self.gaussian[sy1:sy2, sx1:sx2],
                    heatmaps[inds[1][i], dy1:dy2, dx1:dx2])
            records['heatmap_gt{}x'.format(idx + 1)] = heatmaps
            records['mask_{}x'.format(idx + 1)] = mask
        del records['mask']
        return records
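
# Boundary sketch (illustrative): for a keypoint near the edge, only the
# in-bounds part of the precomputed kernel is pasted. E.g. with sigma=2 the
# kernel is 15x15 centered at index 7; a joint at x=1 on a 128-wide heatmap
# uses source columns sx1:sx2 = 6:15 and destination columns dx1:dx2 = 0:9,
# both 9 columns wide.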
@register_keypointop
class RandomFlipHalfBodyTransform(object):
    """Apply data augmentations to the image and coords
    to achieve flip, scale, rotate and half-body transform effects on the training image.

    Args:
        trainsize (list): [w, h], image target size
        upper_body_ids (list): the upper body joint ids
        flip_pairs (list): the left-right joints exchange order list
        pixel_std (int): the pixel std of the scale
        scale (float): the scale factor to transform the image
        rot (int): the rotation factor to transform the image
        num_joints_half_body (int): the joint-count threshold for the half-body transform
        prob_half_body (float): the probability of the half-body transform
        flip (bool): whether to flip the image

    Returns:
        records (dict): contains the image and coords after the transform
    """

    def __init__(self,
                 trainsize,
                 upper_body_ids,
                 flip_pairs,
                 pixel_std,
                 scale=0.35,
                 rot=40,
                 num_joints_half_body=8,
                 prob_half_body=0.3,
                 flip=True,
                 rot_prob=0.6):
        super(RandomFlipHalfBodyTransform, self).__init__()
        self.trainsize = trainsize
        self.upper_body_ids = upper_body_ids
        self.flip_pairs = flip_pairs
        self.pixel_std = pixel_std
        self.scale = scale
        self.rot = rot
        self.num_joints_half_body = num_joints_half_body
        self.prob_half_body = prob_half_body
        self.flip = flip
        self.aspect_ratio = trainsize[0] * 1.0 / trainsize[1]
        self.rot_prob = rot_prob

    def halfbody_transform(self, joints, joints_vis):
        upper_joints = []
        lower_joints = []
        for joint_id in range(joints.shape[0]):
            if joints_vis[joint_id][0] > 0:
                if joint_id in self.upper_body_ids:
                    upper_joints.append(joints[joint_id])
                else:
                    lower_joints.append(joints[joint_id])
        if np.random.randn() < 0.5 and len(upper_joints) > 2:
            selected_joints = upper_joints
        else:
            selected_joints = lower_joints if len(
                lower_joints) > 2 else upper_joints
        if len(selected_joints) < 2:
            return None, None
        selected_joints = np.array(selected_joints, dtype=np.float32)
        center = selected_joints.mean(axis=0)[:2]
        left_top = np.amin(selected_joints, axis=0)
        right_bottom = np.amax(selected_joints, axis=0)
        w = right_bottom[0] - left_top[0]
        h = right_bottom[1] - left_top[1]
        # pad the tight box to the training aspect ratio before scaling
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array(
            [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
            dtype=np.float32)
        scale = scale * 1.5
        return center, scale

    def flip_joints(self, joints, joints_vis, width, matched_parts):
        joints[:, 0] = width - joints[:, 0] - 1
        for pair in matched_parts:
            joints[pair[0], :], joints[pair[1], :] = \
                joints[pair[1], :], joints[pair[0], :].copy()
            joints_vis[pair[0], :], joints_vis[pair[1], :] = \
                joints_vis[pair[1], :], joints_vis[pair[0], :].copy()
        return joints * joints_vis, joints_vis

    def __call__(self, records):
        image = records['image']
        joints = records['joints']
        joints_vis = records['joints_vis']
        c = records['center']
        s = records['scale']
        r = 0
        if (np.sum(joints_vis[:, 0]) > self.num_joints_half_body and
                np.random.rand() < self.prob_half_body):
            c_half_body, s_half_body = self.halfbody_transform(joints,
                                                               joints_vis)
            if c_half_body is not None and s_half_body is not None:
                c, s = c_half_body, s_half_body
        sf = self.scale
        rf = self.rot
        s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
        r = np.clip(np.random.randn() * rf, -rf * 2,
                    rf * 2) if np.random.random() <= self.rot_prob else 0
        if self.flip and np.random.random() <= 0.5:
            image = image[:, ::-1, :]
            joints, joints_vis = self.flip_joints(
                joints, joints_vis, image.shape[1], self.flip_pairs)
            c[0] = image.shape[1] - c[0] - 1
        records['image'] = image
        records['joints'] = joints
        records['joints_vis'] = joints_vis
        records['center'] = c
        records['scale'] = s
        records['rotate'] = r
        return records
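
# Usage sketch (illustrative; the COCO-style joint ids and flip pairs below
# are assumptions taken from the usual COCO convention, not defined here):
#
#     op = RandomFlipHalfBodyTransform(
#         trainsize=[192, 256],
#         upper_body_ids=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
#         flip_pairs=[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
#                     [11, 12], [13, 14], [15, 16]],
#         pixel_std=200)
#     records = op(records)  # needs image/joints/joints_vis/center/scale keys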
@register_keypointop
class AugmentationbyInformantionDropping(object):
    """AID: Augmentation by Information Dropping. Please refer
    to https://arxiv.org/abs/2008.07139

    Args:
        prob_cutout (float): the probability of the Cutout augmentation
        offset_factor (float): offset factor of the cutout center
        num_patch (int): number of patches to cut out
        records (dict): the dict containing the image and coords

    Returns:
        records (dict): contains the image and coords after the transform
    """

    def __init__(self,
                 trainsize,
                 prob_cutout=0.0,
                 offset_factor=0.2,
                 num_patch=1):
        self.prob_cutout = prob_cutout
        self.offset_factor = offset_factor
        self.num_patch = num_patch
        self.trainsize = trainsize

    def _cutout(self, img, joints, joints_vis):
        height, width, _ = img.shape
        img = img.reshape((height * width, -1))
        feat_x_int = np.arange(0, width)
        feat_y_int = np.arange(0, height)
        feat_x_int, feat_y_int = np.meshgrid(feat_x_int, feat_y_int)
        feat_x_int = feat_x_int.reshape((-1, ))
        feat_y_int = feat_y_int.reshape((-1, ))
        for _ in range(self.num_patch):
            # zero out a circular patch near a randomly chosen visible joint
            vis_idx, _ = np.where(joints_vis > 0)
            occlusion_joint_id = np.random.choice(vis_idx)
            center = joints[occlusion_joint_id, 0:2]
            offset = np.random.randn(2) * self.trainsize[0] * self.offset_factor
            center = center + offset
            radius = np.random.uniform(0.1, 0.2) * self.trainsize[0]
            x_offset = (center[0] - feat_x_int) / radius
            y_offset = (center[1] - feat_y_int) / radius
            dis = x_offset**2 + y_offset**2
            keep_pos = np.where((dis <= 1) & (dis >= 0))[0]
            img[keep_pos, :] = 0
        img = img.reshape((height, width, -1))
        return img

    def __call__(self, records):
        img = records['image']
        joints = records['joints']
        joints_vis = records['joints_vis']
        if np.random.rand() < self.prob_cutout:
            img = self._cutout(img, joints, joints_vis)
        records['image'] = img
        return records
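
# Usage sketch (illustrative): enable AID with a 30% cutout probability; the
# parameter values are examples, not recommendations from this file.
#
#     aid = AugmentationbyInformantionDropping(
#         trainsize=[192, 256], prob_cutout=0.3, offset_factor=0.2, num_patch=1)
#     records = aid(records)  # needs image/joints/joints_vis keys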
@register_keypointop
class TopDownAffine(object):
    """Apply an affine transform to the image and coords.

    Args:
        trainsize (list): [w, h], the standard size used for training
        use_udp (bool): whether to use Unbiased Data Processing
        records (dict): the dict containing the image and coords

    Returns:
        records (dict): contains the image and coords after the transform
    """

    def __init__(self, trainsize, use_udp=False):
        self.trainsize = trainsize
        self.use_udp = use_udp

    def __call__(self, records):
        image = records['image']
        joints = records['joints']
        joints_vis = records['joints_vis']
        rot = records['rotate'] if 'rotate' in records else 0
        if self.use_udp:
            trans = get_warp_matrix(
                rot, records['center'] * 2.0,
                [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0],
                records['scale'] * 200.0)
            image = cv2.warpAffine(
                image,
                trans, (int(self.trainsize[0]), int(self.trainsize[1])),
                flags=cv2.INTER_LINEAR)
            joints[:, 0:2] = warp_affine_joints(joints[:, 0:2].copy(), trans)
        else:
            trans = get_affine_transform(records['center'], records['scale'] *
                                         200, rot, self.trainsize)
            image = cv2.warpAffine(
                image,
                trans, (int(self.trainsize[0]), int(self.trainsize[1])),
                flags=cv2.INTER_LINEAR)
            for i in range(joints.shape[0]):
                if joints_vis[i, 0] > 0.0:
                    joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
        records['image'] = image
        records['joints'] = joints
        return records
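
# Pipeline sketch (illustrative ordering, not mandated by this file): a
# top-down training pipeline typically chains the random transform, the affine
# crop, and the heatmap encoding, e.g.
#
#     for op in [RandomFlipHalfBodyTransform(...),
#                TopDownAffine(trainsize=[192, 256]),
#                ToHeatmapsTopDown(hmsize=[48, 64], sigma=2)]:
#         records = op(records)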
@register_keypointop
class SinglePoseAffine(object):
    """Apply an affine transform to the image and coords.

    Args:
        trainsize (list): [w, h], the standard size used for training
        use_udp (bool): whether to use Unbiased Data Processing
        records (dict): the dict containing the image and coords

    Returns:
        records (dict): contains the image and coords after the transform
    """

    def __init__(self,
                 trainsize,
                 rotate=[1.0, 30],
                 scale=[1.0, 0.25],
                 use_udp=False):
        self.trainsize = trainsize
        self.use_udp = use_udp
        self.rot_prob = rotate[0]
        self.rot_range = rotate[1]
        self.scale_prob = scale[0]
        self.scale_ratio = scale[1]

    def __call__(self, records):
        image = records['image']
        if 'joints_2d' in records:
            joints = records['joints_2d']
            joints_vis = records[
                'joints_vis'] if 'joints_vis' in records else np.ones(
                    (len(joints), 1))
        rot = 0
        s = 1.
        if np.random.random() < self.rot_prob:
            rot = np.clip(np.random.randn() * self.rot_range,
                          -self.rot_range * 2, self.rot_range * 2)
        if np.random.random() < self.scale_prob:
            s = np.clip(np.random.randn() * self.scale_ratio + 1,
                        1 - self.scale_ratio, 1 + self.scale_ratio)

        if self.use_udp:
            trans = get_warp_matrix(
                rot,
                np.array(records['bbox_center']) * 2.0,
                [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0],
                records['bbox_scale'] * 200.0 * s)
            image = cv2.warpAffine(
                image,
                trans, (int(self.trainsize[0]), int(self.trainsize[1])),
                flags=cv2.INTER_LINEAR)
            if 'joints_2d' in records:
                joints[:, 0:2] = warp_affine_joints(joints[:, 0:2].copy(),
                                                    trans)
        else:
            trans = get_affine_transform(
                np.array(records['bbox_center']),
                records['bbox_scale'] * s * 200, rot, self.trainsize)
            image = cv2.warpAffine(
                image,
                trans, (int(self.trainsize[0]), int(self.trainsize[1])),
                flags=cv2.INTER_LINEAR)
            if 'joints_2d' in records:
                for i in range(len(joints)):
                    if joints_vis[i, 0] > 0.0:
                        joints[i, 0:2] = affine_transform(joints[i, 0:2],
                                                          trans)

        if 'joints_3d' in records:
            pose3d = records['joints_3d']
            if not rot == 0:
                # rotate the 3d pose in the x-y plane by the same angle
                trans_3djoints = np.eye(3)
                rot_rad = -rot * np.pi / 180
                sn, cs = np.sin(rot_rad), np.cos(rot_rad)
                trans_3djoints[0, :2] = [cs, -sn]
                trans_3djoints[1, :2] = [sn, cs]
                pose3d[:, :3] = np.einsum('ij,kj->ki', trans_3djoints,
                                          pose3d[:, :3])
                records['joints_3d'] = pose3d

        records['image'] = image
        if 'joints_2d' in records:
            records['joints_2d'] = joints
        return records
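
# Rotation sketch (illustrative): in the 3d step of SinglePoseAffine above,
# np.einsum('ij,kj->ki', R, P) applies R to every row of P, i.e. it equals
# P @ R.T. For rot = 90 (so rot_rad = -pi/2), a point at (1, 0, 0) maps to
# approximately (0, -1, 0).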
@register_keypointop
class NoiseJitter(object):
    """Apply per-channel noise jitter to the image.

    Args:
        noise_factor (float): the noise factor ratio used to generate the jitter

    Returns:
        records (dict): contains the image after the transform
    """

    def __init__(self, noise_factor=0.4):
        self.noise_factor = noise_factor

    def __call__(self, records):
        # sample one multiplicative factor per channel, then clamp to [0, 255]
        self.pn = np.random.uniform(1 - self.noise_factor,
                                    1 + self.noise_factor, 3)
        rgb_img = records['image']
        rgb_img[:, :, 0] = np.minimum(
            255.0, np.maximum(0.0, rgb_img[:, :, 0] * self.pn[0]))
        rgb_img[:, :, 1] = np.minimum(
            255.0, np.maximum(0.0, rgb_img[:, :, 1] * self.pn[1]))
        rgb_img[:, :, 2] = np.minimum(
            255.0, np.maximum(0.0, rgb_img[:, :, 2] * self.pn[2]))
        records['image'] = rgb_img
        return records
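
# Equivalence sketch (illustrative): the three clamped multiplications above
# are roughly the vectorized
#
#     rgb_img = np.clip(rgb_img * self.pn, 0.0, 255.0)
#
# modulo dtype: the in-place per-channel form keeps the array's original dtype.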
@register_keypointop
class FlipPose(object):
    """Randomly flip the image and the 2d/3d joints.

    Args:
        flip_prob (float): the probability of flipping the image
        img_res (int): the image width, used to mirror the x coords
        num_joints (int): the number of joints, which selects the flip permutation

    Returns:
        records (dict): contains the image and coords after the transform
    """

    def __init__(self, flip_prob=0.5, img_res=224, num_joints=14):
        self.flip_prob = flip_prob
        self.img_res = img_res
        if num_joints == 24:
            self.perm = [
                5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17,
                18, 19, 21, 20, 23, 22
            ]
        elif num_joints == 14:
            self.perm = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13]
        else:
            logger.error("unsupported num_joints in flip: {}".format(
                num_joints))

    def __call__(self, records):
        if np.random.random() < self.flip_prob:
            img = records['image']
            img = np.fliplr(img)

            if 'joints_2d' in records:
                joints_2d = records['joints_2d']
                joints_2d = joints_2d[self.perm]
                joints_2d[:, 0] = self.img_res - joints_2d[:, 0]
                records['joints_2d'] = joints_2d

            if 'joints_3d' in records:
                joints_3d = records['joints_3d']
                joints_3d = joints_3d[self.perm]
                joints_3d[:, 0] = -joints_3d[:, 0]
                records['joints_3d'] = joints_3d

            records['image'] = img
        return records
@register_keypointop
class TopDownEvalAffine(object):
    """Apply an affine transform to the image and coords for eval.

    Args:
        trainsize (list): [w, h], the standard size used for training
        use_udp (bool): whether to use Unbiased Data Processing
        records (dict): the dict containing the image and coords

    Returns:
        records (dict): contains the image and coords after the transform
    """

    def __init__(self, trainsize, use_udp=False):
        self.trainsize = trainsize
        self.use_udp = use_udp

    def __call__(self, records):
        image = records['image']
        rot = 0
        imshape = records['im_shape'][::-1]
        center = imshape / 2.
        scale = imshape
        if self.use_udp:
            trans = get_warp_matrix(
                rot, center * 2.0,
                [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale)
            image = cv2.warpAffine(
                image,
                trans, (int(self.trainsize[0]), int(self.trainsize[1])),
                flags=cv2.INTER_LINEAR)
        else:
            trans = get_affine_transform(center, scale, rot, self.trainsize)
            image = cv2.warpAffine(
                image,
                trans, (int(self.trainsize[0]), int(self.trainsize[1])),
                flags=cv2.INTER_LINEAR)
        records['image'] = image
        return records
@register_keypointop
class ToHeatmapsTopDown(object):
    """Generate the gaussian heatmaps of the keypoints for the heatmap loss.

    Args:
        hmsize (list): [w, h], output heatmap's size
        sigma (float): the std of the generated gaussian kernel
        records (dict): the dict containing the image and coords

    Returns:
        records (dict): contains the heatmaps used by the heatmap loss
    """

    def __init__(self, hmsize, sigma):
        super(ToHeatmapsTopDown, self).__init__()
        self.hmsize = np.array(hmsize)
        self.sigma = sigma

    def __call__(self, records):
        """refer to
        https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
        Copyright (c) Microsoft, under the MIT License.
        """
        joints = records['joints']
        joints_vis = records['joints_vis']
        num_joints = joints.shape[0]
        image_size = np.array(
            [records['image'].shape[1], records['image'].shape[0]])
        target_weight = np.ones((num_joints, 1), dtype=np.float32)
        target_weight[:, 0] = joints_vis[:, 0]
        target = np.zeros(
            (num_joints, self.hmsize[1], self.hmsize[0]), dtype=np.float32)
        tmp_size = self.sigma * 3
        feat_stride = image_size / self.hmsize
        for joint_id in range(num_joints):
            mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
            mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
            # check that any part of the gaussian is in-bounds
            ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
            br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
            if ul[0] >= self.hmsize[0] or ul[1] >= self.hmsize[1] or br[
                    0] < 0 or br[1] < 0:
                # if not, mark the joint as not to be supervised
                target_weight[joint_id] = 0
                continue
            # generate gaussian
            size = 2 * tmp_size + 1
            x = np.arange(0, size, 1, np.float32)
            y = x[:, np.newaxis]
            x0 = y0 = size // 2
            # the gaussian is not normalized: we want the center value to equal 1
            g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * self.sigma**2))
            # usable gaussian range
            g_x = max(0, -ul[0]), min(br[0], self.hmsize[0]) - ul[0]
            g_y = max(0, -ul[1]), min(br[1], self.hmsize[1]) - ul[1]
            # image range
            img_x = max(0, ul[0]), min(br[0], self.hmsize[0])
            img_y = max(0, ul[1]), min(br[1], self.hmsize[1])
            v = target_weight[joint_id]
            if v > 0.5:
                target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[
                    0]:g_y[1], g_x[0]:g_x[1]]
        records['target'] = target
        records['target_weight'] = target_weight
        del records['joints'], records['joints_vis']
        return records
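
# Shape sketch (illustrative): with hmsize=[48, 64] and sigma=2, a record with
# K joints yields
#
#     records['target'].shape        == (K, 64, 48)   # one heatmap per joint
#     records['target_weight'].shape == (K, 1)        # 0 for dropped joints
#
# and each kept heatmap peaks at 1.0 on the quantized joint location.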
@register_keypointop
class ToHeatmapsTopDown_DARK(object):
    """Generate the gaussian heatmaps of the keypoints for the heatmap loss.
    Unlike ToHeatmapsTopDown, the gaussian is rendered over the whole map at
    the unquantized (sub-pixel) joint location, as used by DARK.

    Args:
        hmsize (list): [w, h], output heatmap's size
        sigma (float): the std of the generated gaussian kernel
        records (dict): the dict containing the image and coords

    Returns:
        records (dict): contains the heatmaps used by the heatmap loss
    """

    def __init__(self, hmsize, sigma):
        super(ToHeatmapsTopDown_DARK, self).__init__()
        self.hmsize = np.array(hmsize)
        self.sigma = sigma

    def __call__(self, records):
        joints = records['joints']
        joints_vis = records['joints_vis']
        num_joints = joints.shape[0]
        image_size = np.array(
            [records['image'].shape[1], records['image'].shape[0]])
        target_weight = np.ones((num_joints, 1), dtype=np.float32)
        target_weight[:, 0] = joints_vis[:, 0]
        target = np.zeros(
            (num_joints, self.hmsize[1], self.hmsize[0]), dtype=np.float32)
        tmp_size = self.sigma * 3
        feat_stride = image_size / self.hmsize
        for joint_id in range(num_joints):
            mu_x = joints[joint_id][0] / feat_stride[0]
            mu_y = joints[joint_id][1] / feat_stride[1]
            # check that any part of the gaussian is in-bounds
            ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
            br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
            if ul[0] >= self.hmsize[0] or ul[1] >= self.hmsize[1] or br[
                    0] < 0 or br[1] < 0:
                # if not, mark the joint as not to be supervised
                target_weight[joint_id] = 0
                continue
            x = np.arange(0, self.hmsize[0], 1, np.float32)
            y = np.arange(0, self.hmsize[1], 1, np.float32)
            y = y[:, np.newaxis]
            v = target_weight[joint_id]
            if v > 0.5:
                target[joint_id] = np.exp(-(
                    (x - mu_x)**2 + (y - mu_y)**2) / (2 * self.sigma**2))
        records['target'] = target
        records['target_weight'] = target_weight
        del records['joints'], records['joints_vis']
        return records
@register_keypointop
class ToHeatmapsTopDown_UDP(object):
    """This code is based on:
    https://github.com/HuangJunJie2017/UDP-Pose/blob/master/deep-high-resolution-net.pytorch/lib/dataset/JointsDataset.py

    Generate the gaussian heatmaps of the keypoints for the heatmap loss.
    ref: Huang et al. The Devil is in the Details: Delving into Unbiased Data
    Processing for Human Pose Estimation (CVPR 2020).

    Args:
        hmsize (list): [w, h], output heatmap's size
        sigma (float): the std of the generated gaussian kernel
        records (dict): the dict containing the image and coords

    Returns:
        records (dict): contains the heatmaps used by the heatmap loss
    """

    def __init__(self, hmsize, sigma):
        super(ToHeatmapsTopDown_UDP, self).__init__()
        self.hmsize = np.array(hmsize)
        self.sigma = sigma

    def __call__(self, records):
        joints = records['joints']
        joints_vis = records['joints_vis']
        num_joints = joints.shape[0]
        image_size = np.array(
            [records['image'].shape[1], records['image'].shape[0]])
        target_weight = np.ones((num_joints, 1), dtype=np.float32)
        target_weight[:, 0] = joints_vis[:, 0]
        target = np.zeros(
            (num_joints, self.hmsize[1], self.hmsize[0]), dtype=np.float32)
        tmp_size = self.sigma * 3
        size = 2 * tmp_size + 1
        x = np.arange(0, size, 1, np.float32)
        y = x[:, None]
        # UDP uses the unbiased stride (image size - 1) / (heatmap size - 1)
        feat_stride = (image_size - 1.0) / (self.hmsize - 1.0)
        for joint_id in range(num_joints):
            mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
            mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
            # check that any part of the gaussian is in-bounds
            ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
            br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
            if ul[0] >= self.hmsize[0] or ul[1] >= self.hmsize[1] or br[
                    0] < 0 or br[1] < 0:
                # if not, mark the joint as not to be supervised
                target_weight[joint_id] = 0
                continue
            # shift the kernel center by the sub-pixel residual
            mu_x_ac = joints[joint_id][0] / feat_stride[0]
            mu_y_ac = joints[joint_id][1] / feat_stride[1]
            x0 = y0 = size // 2
            x0 += mu_x_ac - mu_x
            y0 += mu_y_ac - mu_y
            g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * self.sigma**2))
            # usable gaussian range
            g_x = max(0, -ul[0]), min(br[0], self.hmsize[0]) - ul[0]
            g_y = max(0, -ul[1]), min(br[1], self.hmsize[1]) - ul[1]
            # image range
            img_x = max(0, ul[0]), min(br[0], self.hmsize[0])
            img_y = max(0, ul[1]), min(br[1], self.hmsize[1])
            v = target_weight[joint_id]
            if v > 0.5:
                target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[
                    0]:g_y[1], g_x[0]:g_x[1]]
        records['target'] = target
        records['target_weight'] = target_weight
        del records['joints'], records['joints_vis']
        return records
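
# Stride sketch (illustrative): for a 192-wide input and a 48-wide heatmap, the
# biased stride is 192 / 48 = 4.0, while the UDP stride is
# (192 - 1) / (48 - 1) ~= 4.064, so the last input pixel maps exactly onto the
# last heatmap pixel instead of falling outside it.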