12345678910111213141516171819202122232425262728293031323334353637 |
- import torch
- import torch.nn as nn
class DownsampleA(nn.Module):
    """Parameter-free shortcut: strided spatial subsampling plus zero channels.

    The input is subsampled with a 1x1 average pool, which simply keeps every
    ``stride``-th pixel. The result is then concatenated along dim 1 with an
    all-zero tensor of the same shape, doubling the size of dim 1.

    Args:
        nIn: Number of input channels. Not used by the computation.
        nOut: Number of output channels. Not used by the computation. The
            output always has ``2 * x.size(1)`` entries along dim 1.
            NOTE(review): callers presumably pass ``nOut == 2 * nIn`` — confirm.
        stride: Spatial stride. Only 2 is supported.

    Raises:
        ValueError: If ``stride`` is not 2.
    """

    def __init__(self, nIn, nOut, stride):
        super(DownsampleA, self).__init__()
        # Explicit check instead of ``assert`` so it still fires under ``python -O``.
        if stride != 2:
            raise ValueError(f"DownsampleA only supports stride 2, got {stride}")
        # kernel_size=1 makes this a pure strided subsample (no actual averaging).
        self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)

    def forward(self, x):
        """Subsample ``x`` and append an equally sized block of zeros along dim 1.

        Assumes ``x`` is a 4-D ``(N, C, H, W)`` batch, since dim 1 is treated
        as the channel axis. TODO confirm with callers.
        """
        x = self.avg(x)
        # zeros_like (rather than x.mul(0)) yields true zeros even where x holds
        # NaN/Inf, and adds no multiply node to the autograd graph.
        return torch.cat((x, torch.zeros_like(x)), 1)
class DownsampleC(nn.Module):
    """Projection shortcut: a bias-free 1x1 convolution with the given stride.

    Args:
        nIn: Number of input channels.
        nOut: Number of output channels.
        stride: Spatial stride of the 1x1 convolution.

    Raises:
        ValueError: If ``stride == 1`` and ``nIn == nOut``. In that case the
            shape would not change, so this projection is not needed.
    """

    def __init__(self, nIn, nOut, stride):
        super(DownsampleC, self).__init__()
        # Explicit check instead of ``assert`` so it still fires under ``python -O``.
        if stride == 1 and nIn == nOut:
            raise ValueError(
                "DownsampleC requires stride != 1 or nIn != nOut "
                f"(got stride={stride}, nIn={nIn}, nOut={nOut})"
            )
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        """Apply the strided 1x1 projection to ``x``."""
        return self.conv(x)
class DownsampleD(nn.Module):
    """Learned shortcut: 2x2 stride-2 convolution followed by batch norm.

    Odd spatial sizes are zero-padded by one row and/or column on the
    bottom/right before the convolution. The output spatial size is therefore
    ``ceil(H / 2) x ceil(W / 2)``, the same as a 1x1 stride-2 pool or
    convolution produces. Without the padding, a 2x2 stride-2 convolution
    would give ``floor`` sizes on odd inputs. Even-sized inputs are not padded.

    Args:
        nIn: Number of input channels.
        nOut: Number of output channels.
        stride: Spatial stride. Only 2 is supported.

    Raises:
        ValueError: If ``stride`` is not 2.
    """

    def __init__(self, nIn, nOut, stride):
        super(DownsampleD, self).__init__()
        # Explicit check instead of ``assert`` so it still fires under ``python -O``.
        if stride != 2:
            raise ValueError(f"DownsampleD only supports stride 2, got {stride}")
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(nOut)

    def forward(self, x):
        """Convolve and batch-normalize ``x``, padding odd H/W by one first."""
        pad_h = x.size(-2) % 2
        pad_w = x.size(-1) % 2
        if pad_h or pad_w:
            # Pad order is (left, right, top, bottom) for the last two dims.
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h))
        x = self.conv(x)
        x = self.bn(x)
        return x
|