resnet32.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
# This is someone else's implementation of ResNet optimized for CIFAR. I have not
# been able to find the original repository again to credit the work properly;
# I will keep looking.
  3. import math
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from torch.nn import init
  7. from .res_utils import DownsampleA
  8. class ResNetBasicblock(nn.Module):
  9. expansion = 1
  10. """
  11. RexNet basicblock (https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua)
  12. """
  13. def __init__(self, inplanes, planes, stride=1, downsample=None):
  14. super(ResNetBasicblock, self).__init__()
  15. self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
  16. self.bn_a = nn.BatchNorm2d(planes)
  17. self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
  18. self.bn_b = nn.BatchNorm2d(planes)
  19. self.downsample = downsample
  20. self.featureSize = 64
  21. def forward(self, x):
  22. residual = x
  23. basicblock = self.conv_a(x)
  24. basicblock = self.bn_a(basicblock)
  25. basicblock = F.relu(basicblock, inplace=True)
  26. basicblock = self.conv_b(basicblock)
  27. basicblock = self.bn_b(basicblock)
  28. if self.downsample is not None:
  29. residual = self.downsample(x)
  30. return F.relu(residual + basicblock, inplace=True)
  31. class CifarResNet(nn.Module):
  32. """
  33. ResNet optimized for the Cifar Dataset, as specified in
  34. https://arxiv.org/abs/1512.03385.pdf
  35. """
  36. def __init__(self, block, depth, num_classes, channels=3):
  37. """ Constructor
  38. Args:
  39. depth: number of layers.
  40. num_classes: number of classes
  41. base_width: base width
  42. """
  43. super(CifarResNet, self).__init__()
  44. self.featureSize = 64
  45. # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
  46. assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
  47. layer_blocks = (depth - 2) // 6
  48. self.num_classes = num_classes
  49. self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
  50. self.bn_1 = nn.BatchNorm2d(16)
  51. self.inplanes = 16
  52. self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
  53. self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
  54. self.stage_3 = self._make_layer(block, 64, layer_blocks, 2)
  55. self.avgpool = nn.AvgPool2d(8)
  56. self.fc = nn.Linear(64 * block.expansion, num_classes)
  57. self.fc2 = nn.Linear(64 * block.expansion, 100)
  58. for m in self.modules():
  59. if isinstance(m, nn.Conv2d):
  60. n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  61. m.weight.data.normal_(0, math.sqrt(2. / n))
  62. # m.bias.data.zero_()
  63. elif isinstance(m, nn.BatchNorm2d):
  64. m.weight.data.fill_(1)
  65. m.bias.data.zero_()
  66. elif isinstance(m, nn.Linear):
  67. init.kaiming_normal(m.weight)
  68. m.bias.data.zero_()
  69. def _make_layer(self, block, planes, blocks, stride=1):
  70. downsample = None
  71. if stride != 1 or self.inplanes != planes * block.expansion:
  72. downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)
  73. layers = []
  74. layers.append(block(self.inplanes, planes, stride, downsample))
  75. self.inplanes = planes * block.expansion
  76. for i in range(1, blocks):
  77. layers.append(block(self.inplanes, planes))
  78. return nn.Sequential(*layers)
  79. def forward(self, x, pretrain:bool=False):
  80. x = self.conv_1_3x3(x)
  81. x = F.relu(self.bn_1(x), inplace=True)
  82. x = self.stage_1(x)
  83. x = self.stage_2(x)
  84. x = self.stage_3(x)
  85. x = self.avgpool(x)
  86. x = x.view(x.size(0), -1)
  87. if pretrain:
  88. return self.fc2(x)
  89. x = self.fc(x)
  90. return x
  91. def resnet20(num_classes=10):
  92. """Constructs a ResNet-20 model for CIFAR-10 (by default)
  93. Args:
  94. num_classes (uint): number of classes
  95. """
  96. model = CifarResNet(ResNetBasicblock, 20, num_classes)
  97. return model
  98. def resnet8(num_classes=10):
  99. """Constructs a ResNet-20 model for CIFAR-10 (by default)
  100. Args:
  101. num_classes (uint): number of classes
  102. """
  103. model = CifarResNet(ResNetBasicblock, 8, num_classes, 3)
  104. return model
  105. def resnet20mnist(num_classes=10):
  106. """Constructs a ResNet-20 model for CIFAR-10 (by default)
  107. Args:
  108. num_classes (uint): number of classes
  109. """
  110. model = CifarResNet(ResNetBasicblock, 20, num_classes, 1)
  111. return model
  112. def resnet32mnist(num_classes=10, channels=1):
  113. model = CifarResNet(ResNetBasicblock, 32, num_classes, channels)
  114. return model
  115. def resnet32(num_classes=10):
  116. """Constructs a ResNet-32 model for CIFAR-10 (by default)
  117. Args:
  118. num_classes (uint): number of classes
  119. """
  120. model = CifarResNet(ResNetBasicblock, 32, num_classes)
  121. return model
  122. def resnet44(num_classes=10):
  123. """Constructs a ResNet-44 model for CIFAR-10 (by default)
  124. Args:
  125. num_classes (uint): number of classes
  126. """
  127. model = CifarResNet(ResNetBasicblock, 44, num_classes)
  128. return model
  129. def resnet56(num_classes=10):
  130. """Constructs a ResNet-56 model for CIFAR-10 (by default)
  131. Args:
  132. num_classes (uint): number of classes
  133. """
  134. model = CifarResNet(ResNetBasicblock, 56, num_classes)
  135. return model
  136. def resnet110(num_classes=10):
  137. """Constructs a ResNet-110 model for CIFAR-10 (by default)
  138. Args:
  139. num_classes (uint): number of classes
  140. """
  141. model = CifarResNet(ResNetBasicblock, 110, num_classes)
  142. return model