res_utils.py 1.0 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. import torch
  2. import torch.nn as nn
  3. class DownsampleA(nn.Module):
  4. def __init__(self, nIn, nOut, stride):
  5. super(DownsampleA, self).__init__()
  6. assert stride == 2
  7. self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)
  8. def forward(self, x):
  9. x = self.avg(x)
  10. return torch.cat((x, x.mul(0)), 1)
  11. class DownsampleC(nn.Module):
  12. def __init__(self, nIn, nOut, stride):
  13. super(DownsampleC, self).__init__()
  14. assert stride != 1 or nIn != nOut
  15. self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
  16. def forward(self, x):
  17. x = self.conv(x)
  18. return x
  19. class DownsampleD(nn.Module):
  20. def __init__(self, nIn, nOut, stride):
  21. super(DownsampleD, self).__init__()
  22. assert stride == 2
  23. self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False)
  24. self.bn = nn.BatchNorm2d(nOut)
  25. def forward(self, x):
  26. x = self.conv(x)
  27. x = self.bn(x)
  28. return x