# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division

import unittest

import paddle
import paddle.nn.functional as F

# add python path of PaddleDetection to sys.path
import os
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4)))
if parent_path not in sys.path:
    sys.path.append(parent_path)

from ppdet.modeling.losses import YOLOv3Loss
from ppdet.data.transform.op_helper import jaccard_overlap
from ppdet.modeling.bbox_utils import iou_similarity

import numpy as np

np.random.seed(0)


def _split_output(output, an_num, num_classes):
    """
    Split output feature map to x, y, w, h, objectness, classification
    along channel dimension
    """
    x = paddle.strided_slice(
        output,
        axes=[1],
        starts=[0],
        ends=[output.shape[1]],
        strides=[5 + num_classes])
    y = paddle.strided_slice(
        output,
        axes=[1],
        starts=[1],
        ends=[output.shape[1]],
        strides=[5 + num_classes])
    w = paddle.strided_slice(
        output,
        axes=[1],
        starts=[2],
        ends=[output.shape[1]],
        strides=[5 + num_classes])
    h = paddle.strided_slice(
        output,
        axes=[1],
        starts=[3],
        ends=[output.shape[1]],
        strides=[5 + num_classes])
    obj = paddle.strided_slice(
        output,
        axes=[1],
        starts=[4],
        ends=[output.shape[1]],
        strides=[5 + num_classes])
    clss = []
    stride = output.shape[1] // an_num
    for m in range(an_num):
        clss.append(
            paddle.slice(
                output,
                axes=[1],
                starts=[stride * m + 5],
                ends=[stride * m + 5 + num_classes]))
    cls = paddle.transpose(paddle.stack(clss, axis=1), perm=[0, 1, 3, 4, 2])
    return (x, y, w, h, obj, cls)
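
# For reference: with an_num = 3 and num_classes = 80, `output` has
# 3 * (5 + 80) = 255 channels, one [x, y, w, h, obj, cls_0 .. cls_79] block per
# anchor, so the strided slices above pick channels 0/85/170 for x,
# 1/86/171 for y, and so on, while each per-anchor class slice starts at
# channel 85 * m + 5. A rough NumPy sketch of the same gather, assuming an
# (N, 255, H, W) array `out`:
#
#     x = out[:, 0::85, :, :]                  # (N, 3, H, W)
#     obj = out[:, 4::85, :, :]                # (N, 3, H, W)
#     cls = np.stack(
#         [out[:, 85 * m + 5:85 * (m + 1), :, :] for m in range(3)], axis=1)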


def _split_target(target):
    """
    split target to x, y, w, h, objectness, classification
    along dimension 2

    target is in shape [N, an_num, 6 + class_num, H, W]
    """
    tx = target[:, :, 0, :, :]
    ty = target[:, :, 1, :, :]
    tw = target[:, :, 2, :, :]
    th = target[:, :, 3, :, :]
    tscale = target[:, :, 4, :, :]
    tobj = target[:, :, 5, :, :]
    tcls = paddle.transpose(target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2])
    tcls.stop_gradient = True
    return (tx, ty, tw, th, tscale, tobj, tcls)
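
# For reference, dimension 2 of `target` (built by gt2yolotarget below) holds,
# per anchor and grid cell:
#     0-3 : tx, ty, tw, th regression targets
#     4   : tscale = 2.0 - gw * gh, so smaller boxes get a larger loss weight
#     5   : tobj, the gt_score of the matched ground-truth box (0 elsewhere)
#     6+  : one-hot classification target of length num_classes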


def _calc_obj_loss(output, obj, tobj, gt_box, batch_size, anchors, num_classes,
                   downsample, ignore_thresh, scale_x_y):
    # If a prediction bbox overlaps any gt bbox with IoU over ignore_thresh,
    # its objectness loss is ignored. The process is as follows:

    # 1. get pred bbox, the same as YOLOv3 infer mode, use yolo_box here
    # NOTE: img_size is set as 1.0 to get normalized pred bbox
    bbox, prob = paddle.vision.ops.yolo_box(
        x=output,
        img_size=paddle.ones(
            shape=[batch_size, 2], dtype="int32"),
        anchors=anchors,
        class_num=num_classes,
        conf_thresh=0.,
        downsample_ratio=downsample,
        clip_bbox=False,
        scale_x_y=scale_x_y)

    # 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox
    #    and gt bbox in each sample
    if batch_size > 1:
        preds = paddle.split(bbox, batch_size, axis=0)
        gts = paddle.split(gt_box, batch_size, axis=0)
    else:
        preds = [bbox]
        gts = [gt_box]
        probs = [prob]
    ious = []
    for pred, gt in zip(preds, gts):

        def box_xywh2xyxy(box):
            x = box[:, 0]
            y = box[:, 1]
            w = box[:, 2]
            h = box[:, 3]
            return paddle.stack(
                [
                    x - w / 2.,
                    y - h / 2.,
                    x + w / 2.,
                    y + h / 2.,
                ], axis=1)

        pred = paddle.squeeze(pred, axis=[0])
        gt = box_xywh2xyxy(paddle.squeeze(gt, axis=[0]))
        ious.append(iou_similarity(pred, gt))
    iou = paddle.stack(ious, axis=0)

    # 3. Get iou_mask by IoU between gt bbox and prediction bbox,
    #    Get obj_mask by tobj (holds gt_score), calculate objectness loss
    max_iou = paddle.max(iou, axis=-1)
    iou_mask = paddle.cast(max_iou <= ignore_thresh, dtype="float32")
    output_shape = paddle.shape(output)
    an_num = len(anchors) // 2
    iou_mask = paddle.reshape(iou_mask, (-1, an_num, output_shape[2],
                                         output_shape[3]))
    iou_mask.stop_gradient = True

    # NOTE: tobj holds gt_score, obj_mask holds object existence mask
    obj_mask = paddle.cast(tobj > 0., dtype="float32")
    obj_mask.stop_gradient = True

    # For positive objectness grids, objectness loss should be calculated.
    # For negative objectness grids, objectness loss is calculated only when
    # iou_mask == 1.0
    obj_sigmoid = F.sigmoid(obj)
    loss_obj = F.binary_cross_entropy(obj_sigmoid, obj_mask, reduction='none')
    loss_obj_pos = paddle.sum(loss_obj * tobj, axis=[1, 2, 3])
    loss_obj_neg = paddle.sum(loss_obj * (1.0 - obj_mask) * iou_mask,
                              axis=[1, 2, 3])
    return loss_obj_pos, loss_obj_neg
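
# The masking above follows the YOLOv3 ignore rule: a grid cell contributes a
# positive objectness term when tobj > 0, a negative term only when its best
# IoU against every gt box is at most ignore_thresh, and nothing otherwise.
# A rough sketch of the same rule with plain arrays (illustrative names only,
# `bce` standing for an elementwise binary cross-entropy):
#
#     neg_weight = (max_iou <= ignore_thresh) * (tobj == 0.)
#     loss = bce(sigmoid(obj), tobj > 0.) * (tobj + neg_weight)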


def fine_grained_loss(output,
                      target,
                      gt_box,
                      batch_size,
                      num_classes,
                      anchors,
                      ignore_thresh,
                      downsample,
                      scale_x_y=1.,
                      eps=1e-10):
    an_num = len(anchors) // 2
    x, y, w, h, obj, cls = _split_output(output, an_num, num_classes)
    tx, ty, tw, th, tscale, tobj, tcls = _split_target(target)

    tscale_tobj = tscale * tobj

    if (abs(scale_x_y - 1.0) < eps):
        x = F.sigmoid(x)
        y = F.sigmoid(y)
        loss_x = F.binary_cross_entropy(x, tx, reduction='none') * tscale_tobj
        loss_x = paddle.sum(loss_x, axis=[1, 2, 3])
        loss_y = F.binary_cross_entropy(y, ty, reduction='none') * tscale_tobj
        loss_y = paddle.sum(loss_y, axis=[1, 2, 3])
    else:
        dx = scale_x_y * F.sigmoid(x) - 0.5 * (scale_x_y - 1.0)
        dy = scale_x_y * F.sigmoid(y) - 0.5 * (scale_x_y - 1.0)
        loss_x = paddle.abs(dx - tx) * tscale_tobj
        loss_x = paddle.sum(loss_x, axis=[1, 2, 3])
        loss_y = paddle.abs(dy - ty) * tscale_tobj
        loss_y = paddle.sum(loss_y, axis=[1, 2, 3])

    # NOTE: the (w, h) loss is refined to an L1 loss
    loss_w = paddle.abs(w - tw) * tscale_tobj
    loss_w = paddle.sum(loss_w, axis=[1, 2, 3])
    loss_h = paddle.abs(h - th) * tscale_tobj
    loss_h = paddle.sum(loss_h, axis=[1, 2, 3])

    loss_obj_pos, loss_obj_neg = _calc_obj_loss(
        output, obj, tobj, gt_box, batch_size, anchors, num_classes,
        downsample, ignore_thresh, scale_x_y)

    cls = F.sigmoid(cls)
    loss_cls = F.binary_cross_entropy(cls, tcls, reduction='none')
    tobj = paddle.unsqueeze(tobj, axis=-1)
    loss_cls = paddle.multiply(loss_cls, tobj)
    loss_cls = paddle.sum(loss_cls, axis=[1, 2, 3, 4])

    loss_xys = paddle.mean(loss_x + loss_y)
    loss_whs = paddle.mean(loss_w + loss_h)
    loss_objs = paddle.mean(loss_obj_pos + loss_obj_neg)
    loss_clss = paddle.mean(loss_cls)

    losses_all = {
        "loss_xy": paddle.sum(loss_xys),
        "loss_wh": paddle.sum(loss_whs),
        "loss_loc": paddle.sum(loss_xys) + paddle.sum(loss_whs),
        "loss_obj": paddle.sum(loss_objs),
        "loss_cls": paddle.sum(loss_clss),
    }
    return losses_all, x, y, tx, ty
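
# Note on the two x/y branches above: with scale_x_y == 1 the x/y predictions
# are compared against tx/ty with sigmoid cross-entropy, while scale_x_y != 1
# decodes the center offset as
#     dx = scale_x_y * sigmoid(x) - 0.5 * (scale_x_y - 1.0)
# (the same grid-sensitive decoding yolo_box applies at inference) and
# penalizes it with an L1 term. The w/h terms are plain L1 on the raw
# log-space offsets in both cases.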


def gt2yolotarget(gt_bbox, gt_class, gt_score, anchors, mask, num_classes,
                  size, stride):
    grid_h, grid_w = size
    h, w = grid_h * stride, grid_w * stride
    an_hw = np.array(anchors) / np.array([[w, h]])
    target = np.zeros(
        (len(mask), 6 + num_classes, grid_h, grid_w), dtype=np.float32)
    for b in range(gt_bbox.shape[0]):
        gx, gy, gw, gh = gt_bbox[b, :]
        cls = gt_class[b]
        score = gt_score[b]
        if gw <= 0. or gh <= 0. or score <= 0.:
            continue

        # find best match anchor index
        best_iou = 0.
        best_idx = -1
        for an_idx in range(an_hw.shape[0]):
            iou = jaccard_overlap(
                [0., 0., gw, gh],
                [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
            if iou > best_iou:
                best_iou = iou
                best_idx = an_idx

        gi = int(gx * grid_w)
        gj = int(gy * grid_h)

        # gt bbox should be regressed in this layer if its best match
        # anchor index is in the anchor mask of this layer
        if best_idx in mask:
            best_n = mask.index(best_idx)

            # x, y, w, h, scale
            target[best_n, 0, gj, gi] = gx * grid_w - gi
            target[best_n, 1, gj, gi] = gy * grid_h - gj
            target[best_n, 2, gj, gi] = np.log(gw * w / anchors[best_idx][0])
            target[best_n, 3, gj, gi] = np.log(gh * h / anchors[best_idx][1])
            target[best_n, 4, gj, gi] = 2.0 - gw * gh

            # objectness records gt_score
            # if target[best_n, 5, gj, gi] > 0:
            #     print('find 1 duplicate')
            target[best_n, 5, gj, gi] = score

            # classification
            target[best_n, 6 + cls, gj, gi] = 1.

    return target
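
# Worked illustration (hypothetical numbers): a gt box with normalized
# (gx, gy, gw, gh) = (0.5, 0.5, 0.25, 0.33) on a 19x19 grid with stride 32
# (so w = h = 608) lands in cell (gi, gj) = (9, 9). Its best-matching anchor
# comes out as anchors[7] = [156, 198]; with mask = [6, 7, 8] that gives
# best_n = 1 and
#     target[1, 0, 9, 9] = 0.5 * 19 - 9 = 0.5
#     target[1, 2, 9, 9] = log(0.25 * 608 / 156)
#     target[1, 4, 9, 9] = 2.0 - 0.25 * 0.33
#     target[1, 5, 9, 9] = score
#     target[1, 6 + cls, 9, 9] = 1.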


class TestYolov3LossOp(unittest.TestCase):
    def setUp(self):
        self.initTestCase()
        x = np.random.uniform(0, 1, self.x_shape).astype('float64')
        gtbox = np.random.random(size=self.gtbox_shape).astype('float64')
        gtlabel = np.random.randint(0, self.class_num, self.gtbox_shape[:2])
        gtmask = np.random.randint(0, 2, self.gtbox_shape[:2])
        gtbox = gtbox * gtmask[:, :, np.newaxis]
        gtlabel = gtlabel * gtmask
        gtscore = np.ones(self.gtbox_shape[:2]).astype('float64')
        if self.gtscore:
            gtscore = np.random.random(self.gtbox_shape[:2]).astype('float64')

        target = []
        for box, label, score in zip(gtbox, gtlabel, gtscore):
            target.append(
                gt2yolotarget(box, label, score, self.anchors,
                              self.anchor_mask, self.class_num,
                              (self.h, self.w), self.downsample_ratio))
        self.target = np.array(target).astype('float64')

        self.mask_anchors = []
        for i in self.anchor_mask:
            self.mask_anchors.extend(self.anchors[i])
        self.x = x
        self.gtbox = gtbox
        self.gtlabel = gtlabel
        self.gtscore = gtscore

    def initTestCase(self):
        self.b = 8
        self.h = 19
        self.w = 19
        self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                        [59, 119], [116, 90], [156, 198], [373, 326]]
        self.anchor_mask = [6, 7, 8]
        self.na = len(self.anchor_mask)
        self.class_num = 80
        self.ignore_thresh = 0.7
        self.downsample_ratio = 32
        self.x_shape = (self.b, len(self.anchor_mask) * (5 + self.class_num),
                        self.h, self.w)
        self.gtbox_shape = (self.b, 40, 4)
        self.gtscore = True
        self.use_label_smooth = False
        self.scale_x_y = 1.

    def test_loss(self):
        x, gtbox, gtlabel, gtscore, target = (
            self.x, self.gtbox, self.gtlabel, self.gtscore, self.target)
        yolo_loss = YOLOv3Loss(
            ignore_thresh=self.ignore_thresh,
            label_smooth=self.use_label_smooth,
            num_classes=self.class_num,
            downsample=self.downsample_ratio,
            scale_x_y=self.scale_x_y)
        x = paddle.to_tensor(x.astype(np.float32))
        gtbox = paddle.to_tensor(gtbox.astype(np.float32))
        gtlabel = paddle.to_tensor(gtlabel.astype(np.float32))
        gtscore = paddle.to_tensor(gtscore.astype(np.float32))
        t = paddle.to_tensor(target.astype(np.float32))
        anchor = [self.anchors[i] for i in self.anchor_mask]
        (yolo_loss1, px, py, tx, ty) = fine_grained_loss(
            output=x,
            target=t,
            gt_box=gtbox,
            batch_size=self.b,
            num_classes=self.class_num,
            anchors=self.mask_anchors,
            ignore_thresh=self.ignore_thresh,
            downsample=self.downsample_ratio,
            scale_x_y=self.scale_x_y)
        yolo_loss2 = yolo_loss.yolov3_loss(
            x, t, gtbox, anchor, self.downsample_ratio, self.scale_x_y)
        for k in yolo_loss2:
            self.assertAlmostEqual(
                float(yolo_loss1[k]), float(yolo_loss2[k]), delta=1e-2, msg=k)


class TestYolov3LossNoGTScore(TestYolov3LossOp):
    def initTestCase(self):
        self.b = 1
        self.h = 76
        self.w = 76
        self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                        [59, 119], [116, 90], [156, 198], [373, 326]]
        self.anchor_mask = [0, 1, 2]
        self.na = len(self.anchor_mask)
        self.class_num = 80
        self.ignore_thresh = 0.7
        self.downsample_ratio = 8
        self.x_shape = (self.b, len(self.anchor_mask) * (5 + self.class_num),
                        self.h, self.w)
        self.gtbox_shape = (self.b, 40, 4)
        self.gtscore = False
        self.use_label_smooth = False
        self.scale_x_y = 1.


class TestYolov3LossWithScaleXY(TestYolov3LossOp):
    def initTestCase(self):
        self.b = 5
        self.h = 38
        self.w = 38
        self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                        [59, 119], [116, 90], [156, 198], [373, 326]]
        self.anchor_mask = [3, 4, 5]
        self.na = len(self.anchor_mask)
        self.class_num = 80
        self.ignore_thresh = 0.7
        self.downsample_ratio = 16
        self.x_shape = (self.b, len(self.anchor_mask) * (5 + self.class_num),
                        self.h, self.w)
        self.gtbox_shape = (self.b, 40, 4)
        self.gtscore = True
        self.use_label_smooth = False
        self.scale_x_y = 1.2


if __name__ == "__main__":
    unittest.main()