decoder.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import torch
  2. from torch import Tensor
  3. from torch import nn
  4. from torch.nn import functional as F
  5. from .onnx_helper import (
  6. CustomOnnxCropToMatchSizeOp
  7. )
  8. class RecurrentDecoder(nn.Module):
  9. def __init__(self, feature_channels, decoder_channels):
  10. super().__init__()
  11. self.avgpool = AvgPool()
  12. self.decode4 = BottleneckBlock(feature_channels[3])
  13. self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 3, decoder_channels[0])
  14. self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 3, decoder_channels[1])
  15. self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 3, decoder_channels[2])
  16. self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])
  17. def forward(self, s0, f1, f2, f3, f4, r1, r2, r3, r4):
  18. s1, s2, s3 = self.avgpool(s0)
  19. x4, r4 = self.decode4(f4, r4)
  20. x3, r3 = self.decode3(x4, f3, s3, r3)
  21. x2, r2 = self.decode2(x3, f2, s2, r2)
  22. x1, r1 = self.decode1(x2, f1, s1, r1)
  23. x0 = self.decode0(x1, s0)
  24. return x0, r1, r2, r3, r4
  25. class AvgPool(nn.Module):
  26. def __init__(self):
  27. super().__init__()
  28. self.avgpool = nn.AvgPool2d(2, 2, count_include_pad=False, ceil_mode=True)
  29. def forward_single_frame(self, s0):
  30. s1 = self.avgpool(s0)
  31. s2 = self.avgpool(s1)
  32. s3 = self.avgpool(s2)
  33. return s1, s2, s3
  34. def forward_time_series(self, s0):
  35. B, T = s0.shape[:2]
  36. s0 = s0.flatten(0, 1)
  37. s1, s2, s3 = self.forward_single_frame(s0)
  38. s1 = s1.unflatten(0, (B, T))
  39. s2 = s2.unflatten(0, (B, T))
  40. s3 = s3.unflatten(0, (B, T))
  41. return s1, s2, s3
  42. def forward(self, s0):
  43. if s0.ndim == 5:
  44. return self.forward_time_series(s0)
  45. else:
  46. return self.forward_single_frame(s0)
  47. class BottleneckBlock(nn.Module):
  48. def __init__(self, channels):
  49. super().__init__()
  50. self.channels = channels
  51. self.gru = ConvGRU(channels // 2)
  52. def forward(self, x, r):
  53. a, b = x.split(self.channels // 2, dim=-3)
  54. b, r = self.gru(b, r)
  55. x = torch.cat([a, b], dim=-3)
  56. return x, r
  57. class UpsamplingBlock(nn.Module):
  58. def __init__(self, in_channels, skip_channels, src_channels, out_channels):
  59. super().__init__()
  60. self.out_channels = out_channels
  61. self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
  62. self.conv = nn.Sequential(
  63. nn.Conv2d(in_channels + skip_channels + src_channels, out_channels, 3, 1, 1, bias=False),
  64. nn.BatchNorm2d(out_channels),
  65. nn.ReLU(True),
  66. )
  67. self.gru = ConvGRU(out_channels // 2)
  68. def forward_single_frame(self, x, f, s, r):
  69. x = self.upsample(x)
  70. if not torch.onnx.is_in_onnx_export():
  71. x = x[:, :, :s.size(2), :s.size(3)]
  72. else:
  73. x = CustomOnnxCropToMatchSizeOp.apply(x, s)
  74. x = torch.cat([x, f, s], dim=1)
  75. x = self.conv(x)
  76. a, b = x.split(self.out_channels // 2, dim=1)
  77. b, r = self.gru(b, r)
  78. x = torch.cat([a, b], dim=1)
  79. return x, r
  80. def forward_time_series(self, x, f, s, r):
  81. B, T, _, H, W = s.shape
  82. x = x.flatten(0, 1)
  83. f = f.flatten(0, 1)
  84. s = s.flatten(0, 1)
  85. x = self.upsample(x)
  86. if not torch.onnx.is_in_onnx_export():
  87. x = x[:, :, :H, :W]
  88. else:
  89. x = CustomOnnxCropToMatchSizeOp.apply(x, s)
  90. x = torch.cat([x, f, s], dim=1)
  91. x = self.conv(x)
  92. x = x.unflatten(0, (B, T))
  93. a, b = x.split(self.out_channels // 2, dim=2)
  94. b, r = self.gru(b, r)
  95. x = torch.cat([a, b], dim=2)
  96. return x, r
  97. def forward(self, x, f, s, r):
  98. if x.ndim == 5:
  99. return self.forward_time_series(x, f, s, r)
  100. else:
  101. return self.forward_single_frame(x, f, s, r)
  102. class OutputBlock(nn.Module):
  103. def __init__(self, in_channels, src_channels, out_channels):
  104. super().__init__()
  105. self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
  106. self.conv = nn.Sequential(
  107. nn.Conv2d(in_channels + src_channels, out_channels, 3, 1, 1, bias=False),
  108. nn.BatchNorm2d(out_channels),
  109. nn.ReLU(True),
  110. nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
  111. nn.BatchNorm2d(out_channels),
  112. nn.ReLU(True),
  113. )
  114. def forward_single_frame(self, x, s):
  115. x = self.upsample(x)
  116. if not torch.onnx.is_in_onnx_export():
  117. x = x[:, :, :s.size(1), :s.size(2)]
  118. else:
  119. x = CustomOnnxCropToMatchSizeOp.apply(x, s)
  120. x = torch.cat([x, s], dim=1)
  121. x = self.conv(x)
  122. return x
  123. def forward_time_series(self, x, s):
  124. B, T, _, H, W = s.shape
  125. x = x.flatten(0, 1)
  126. s = s.flatten(0, 1)
  127. x = self.upsample(x)
  128. x = x[:, :, :H, :W]
  129. x = torch.cat([x, s], dim=1)
  130. x = self.conv(x)
  131. x = x.unflatten(0, (B, T))
  132. return x
  133. def forward(self, x, s):
  134. if x.ndim == 5:
  135. return self.forward_time_series(x, s)
  136. else:
  137. return self.forward_single_frame(x, s)
  138. class ConvGRU(nn.Module):
  139. def __init__(self,
  140. channels: int,
  141. kernel_size: int = 3,
  142. padding: int = 1):
  143. super().__init__()
  144. self.channels = channels
  145. self.ih = nn.Sequential(
  146. nn.Conv2d(channels * 2, channels * 2, kernel_size, padding=padding),
  147. nn.Sigmoid()
  148. )
  149. self.hh = nn.Sequential(
  150. nn.Conv2d(channels * 2, channels, kernel_size, padding=padding),
  151. nn.Tanh()
  152. )
  153. def forward_single_frame(self, x, h):
  154. r, z = self.ih(torch.cat([x, h], dim=1)).split(self.channels, dim=1)
  155. c = self.hh(torch.cat([x, r * h], dim=1))
  156. h = (1 - z) * h + z * c
  157. return h, h
  158. def forward_time_series(self, x, h):
  159. o = []
  160. for xt in x.unbind(dim=1):
  161. ot, h = self.forward_single_frame(xt, h)
  162. o.append(ot)
  163. o = torch.stack(o, dim=1)
  164. return o, h
  165. def forward(self, x, h):
  166. h = h.expand_as(x)
  167. if x.ndim == 5:
  168. return self.forward_time_series(x, h)
  169. else:
  170. return self.forward_single_frame(x, h)
  171. class Projection(nn.Module):
  172. def __init__(self, in_channels, out_channels):
  173. super().__init__()
  174. self.conv = nn.Conv2d(in_channels, out_channels, 1)
  175. def forward_single_frame(self, x):
  176. return self.conv(x)
  177. def forward_time_series(self, x):
  178. B, T = x.shape[:2]
  179. return self.conv(x.flatten(0, 1)).unflatten(0, (B, T))
  180. def forward(self, x):
  181. if x.ndim == 5:
  182. return self.forward_time_series(x)
  183. else:
  184. return self.forward_single_frame(x)