detr.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. from .meta_arch import BaseArch
  19. from ppdet.core.workspace import register, create
  20. __all__ = ['DETR']
  21. @register
  22. class DETR(BaseArch):
  23. __category__ = 'architecture'
  24. __inject__ = ['post_process']
  25. __shared__ = ['exclude_post_process']
  26. def __init__(self,
  27. backbone,
  28. transformer,
  29. detr_head,
  30. post_process='DETRBBoxPostProcess',
  31. exclude_post_process=False):
  32. super(DETR, self).__init__()
  33. self.backbone = backbone
  34. self.transformer = transformer
  35. self.detr_head = detr_head
  36. self.post_process = post_process
  37. self.exclude_post_process = exclude_post_process
  38. @classmethod
  39. def from_config(cls, cfg, *args, **kwargs):
  40. # backbone
  41. backbone = create(cfg['backbone'])
  42. # transformer
  43. kwargs = {'input_shape': backbone.out_shape}
  44. transformer = create(cfg['transformer'], **kwargs)
  45. # head
  46. kwargs = {
  47. 'hidden_dim': transformer.hidden_dim,
  48. 'nhead': transformer.nhead,
  49. 'input_shape': backbone.out_shape
  50. }
  51. detr_head = create(cfg['detr_head'], **kwargs)
  52. return {
  53. 'backbone': backbone,
  54. 'transformer': transformer,
  55. "detr_head": detr_head,
  56. }
  57. def _forward(self):
  58. # Backbone
  59. body_feats = self.backbone(self.inputs)
  60. # Transformer
  61. pad_mask = self.inputs['pad_mask'] if self.training else None
  62. out_transformer = self.transformer(body_feats, pad_mask, self.inputs)
  63. # DETR Head
  64. if self.training:
  65. return self.detr_head(out_transformer, body_feats, self.inputs)
  66. else:
  67. preds = self.detr_head(out_transformer, body_feats)
  68. if self.exclude_post_process:
  69. bboxes, logits, masks = preds
  70. return bboxes, logits
  71. else:
  72. bbox, bbox_num = self.post_process(
  73. preds, self.inputs['im_shape'], self.inputs['scale_factor'])
  74. return bbox, bbox_num
  75. def get_loss(self):
  76. losses = self._forward()
  77. losses.update({
  78. 'loss':
  79. paddle.add_n([v for k, v in losses.items() if 'log' not in k])
  80. })
  81. return losses
  82. def get_pred(self):
  83. bbox_pred, bbox_num = self._forward()
  84. output = {
  85. "bbox": bbox_pred,
  86. "bbox_num": bbox_num,
  87. }
  88. return output