supcontrast.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. import random
  21. from ppdet.core.workspace import register
  22. __all__ = ['SupContrast']
  23. @register
  24. class SupContrast(nn.Layer):
  25. __shared__ = [
  26. 'num_classes'
  27. ]
  28. def __init__(self, num_classes=80, temperature=2.5, sample_num=4096, thresh=0.75):
  29. super(SupContrast, self).__init__()
  30. self.num_classes = num_classes
  31. self.temperature = temperature
  32. self.sample_num = sample_num
  33. self.thresh = thresh
  34. def forward(self, features, labels, scores):
  35. assert features.shape[0] == labels.shape[0] == scores.shape[0]
  36. positive_mask = (labels < self.num_classes)
  37. positive_features, positive_labels, positive_scores = features[positive_mask], labels[positive_mask], \
  38. scores[positive_mask]
  39. negative_mask = (labels == self.num_classes)
  40. negative_features, negative_labels, negative_scores = features[negative_mask], labels[negative_mask], \
  41. scores[negative_mask]
  42. N = negative_features.shape[0]
  43. S = self.sample_num - positive_mask.sum()
  44. index = paddle.to_tensor(random.sample(range(N), int(S)), dtype='int32')
  45. negative_features = paddle.index_select(x=negative_features, index=index, axis=0)
  46. negative_labels = paddle.index_select(x=negative_labels, index=index, axis=0)
  47. negative_scores = paddle.index_select(x=negative_scores, index=index, axis=0)
  48. features = paddle.concat([positive_features, negative_features], 0)
  49. labels = paddle.concat([positive_labels, negative_labels], 0)
  50. scores = paddle.concat([positive_scores, negative_scores], 0)
  51. if len(labels.shape) == 1:
  52. labels = labels.reshape([-1, 1])
  53. label_mask = paddle.equal(labels, labels.T).detach()
  54. similarity = (paddle.matmul(features, features.T) / self.temperature)
  55. sim_row_max = paddle.max(similarity, axis=1, keepdim=True)
  56. similarity = similarity - sim_row_max
  57. logits_mask = paddle.ones_like(similarity).detach()
  58. logits_mask.fill_diagonal_(0)
  59. exp_sim = paddle.exp(similarity) * logits_mask
  60. log_prob = similarity - paddle.log(exp_sim.sum(axis=1, keepdim=True))
  61. per_label_log_prob = (log_prob * logits_mask * label_mask).sum(1) / label_mask.sum(1)
  62. keep = scores > self.thresh
  63. per_label_log_prob = per_label_log_prob[keep]
  64. loss = -per_label_log_prob
  65. return loss.mean()