decoder.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class Decoder(nn.Module):
  5. """
  6. Decoder upsamples the image by combining the feature maps at all resolutions from the encoder.
  7. Input:
  8. x4: (B, C, H/16, W/16) feature map at 1/16 resolution.
  9. x3: (B, C, H/8, W/8) feature map at 1/8 resolution.
  10. x2: (B, C, H/4, W/4) feature map at 1/4 resolution.
  11. x1: (B, C, H/2, W/2) feature map at 1/2 resolution.
  12. x0: (B, C, H, W) feature map at full resolution.
  13. Output:
  14. x: (B, C, H, W) upsampled output at full resolution.
  15. """
  16. def __init__(self, channels, feature_channels):
  17. super().__init__()
  18. self.conv1 = nn.Conv2d(feature_channels[0] + channels[0], channels[1], 3, padding=1, bias=False)
  19. self.bn1 = nn.BatchNorm2d(channels[1])
  20. self.conv2 = nn.Conv2d(feature_channels[1] + channels[1], channels[2], 3, padding=1, bias=False)
  21. self.bn2 = nn.BatchNorm2d(channels[2])
  22. self.conv3 = nn.Conv2d(feature_channels[2] + channels[2], channels[3], 3, padding=1, bias=False)
  23. self.bn3 = nn.BatchNorm2d(channels[3])
  24. self.conv4 = nn.Conv2d(feature_channels[3] + channels[3], channels[4], 3, padding=1)
  25. self.relu = nn.ReLU(True)
  26. def forward(self, x4, x3, x2, x1, x0):
  27. x = F.interpolate(x4, size=x3.shape[2:], mode='bilinear', align_corners=False)
  28. x = torch.cat([x, x3], dim=1)
  29. x = self.conv1(x)
  30. x = self.bn1(x)
  31. x = self.relu(x)
  32. x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False)
  33. x = torch.cat([x, x2], dim=1)
  34. x = self.conv2(x)
  35. x = self.bn2(x)
  36. x = self.relu(x)
  37. x = F.interpolate(x, size=x1.shape[2:], mode='bilinear', align_corners=False)
  38. x = torch.cat([x, x1], dim=1)
  39. x = self.conv3(x)
  40. x = self.bn3(x)
  41. x = self.relu(x)
  42. x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
  43. x = torch.cat([x, x0], dim=1)
  44. x = self.conv4(x)
  45. return x