decoder.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import torch
  2. from torch import Tensor
  3. from torch import nn
  4. from torch.nn import functional as F
  5. from typing import Tuple, Optional
  class RecurrentDecoder(nn.Module):
      """Decode encoder features coarse-to-fine with recurrent (ConvGRU) stages.

      Stage order: a bottleneck stage on ``f4`` (decode4), three upsampling
      stages that each fuse a skip feature and a pooled copy of ``s0``
      (decode3, decode2, decode1), and a final non-recurrent output stage at the
      resolution of ``s0`` (decode0).

      Args:
          feature_channels: channel counts of ``f1``..``f4`` (indexed 0..3).
          decoder_channels: output channel counts of decode3, decode2, decode1
              and decode0 (indexed 0..3).
      """

      def __init__(self, feature_channels, decoder_channels):
          super().__init__()
          # Channel count of every level of the s0 pyramid (presumably an RGB
          # source frame -- confirm with the caller).
          src_channels = 3
          self.avgpool = AvgPool()
          self.decode4 = BottleneckBlock(feature_channels[3])
          self.decode3 = UpsamplingBlock(
              feature_channels[3], feature_channels[2], src_channels, decoder_channels[0])
          self.decode2 = UpsamplingBlock(
              decoder_channels[0], feature_channels[1], src_channels, decoder_channels[1])
          self.decode1 = UpsamplingBlock(
              decoder_channels[1], feature_channels[0], src_channels, decoder_channels[2])
          self.decode0 = OutputBlock(decoder_channels[2], src_channels, decoder_channels[3])

      def forward(self,
                  s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor,
                  r1: Optional[Tensor], r2: Optional[Tensor],
                  r3: Optional[Tensor], r4: Optional[Tensor]):
          """Run the decoder once.

          ``r1``..``r4`` are the recurrent states from the previous call (or
          ``None`` to start from zeros). Returns ``(out, r1, r2, r3, r4)`` where the
          ``r*`` values are the updated states to feed into the next call.
          """
          s1, s2, s3 = self.avgpool(s0)
          # Each stage consumes the previous stage's output, coarsest first.
          out4, r4 = self.decode4(f4, r4)
          out3, r3 = self.decode3(out4, f3, s3, r3)
          out2, r2 = self.decode2(out3, f2, s2, r2)
          out1, r1 = self.decode1(out2, f1, s1, r1)
          out0 = self.decode0(out1, s0)
          return out0, r1, r2, r3, r4
  class AvgPool(nn.Module):
      """Build a three-level pyramid by repeated 2x2 average pooling.

      Accepts a 4-D batch of frames or a 5-D ``(B, T, ...)`` time series. It
      returns ``(s1, s2, s3)``, each level half the spatial size of the
      previous one. Odd sizes are rounded up (``ceil_mode=True``), and partial
      windows are averaged over their real elements only
      (``count_include_pad=False``).
      """

      def __init__(self):
          super().__init__()
          self.avgpool = nn.AvgPool2d(2, 2, count_include_pad=False, ceil_mode=True)

      def forward_single_frame(self, s0):
          half = self.avgpool(s0)
          quarter = self.avgpool(half)
          eighth = self.avgpool(quarter)
          return half, quarter, eighth

      def forward_time_series(self, s0):
          B, T = s0.shape[:2]
          # Pool every frame as one big batch, then restore the (B, T) axes.
          s1, s2, s3 = self.forward_single_frame(s0.flatten(0, 1))
          return s1.unflatten(0, (B, T)), s2.unflatten(0, (B, T)), s3.unflatten(0, (B, T))

      def forward(self, s0):
          if s0.ndim != 5:
              return self.forward_single_frame(s0)
          return self.forward_time_series(s0)
  class BottleneckBlock(nn.Module):
      """Bottleneck stage: run a ConvGRU on part of the channels, pass the rest through.

      The input is split along dim -3 into a pass-through part ``a`` and a
      recurrent part ``b`` (``channels // 2`` channels, matching the ConvGRU
      width). ``b`` is updated by the ConvGRU, and the two parts are concatenated
      back, so the output keeps ``channels`` channels.

      Args:
          channels: number of input (and output) channels along dim -3.
      """

      def __init__(self, channels):
          super().__init__()
          self.channels = channels
          self.gru = ConvGRU(channels // 2)

      def forward(self, x, r: Optional[Tensor]):
          """Return ``(features, updated_state)`` for input ``x`` and state ``r``.

          ``r`` may be ``None``; it is forwarded to the ConvGRU unchanged.
          """
          recurrent = self.channels // 2
          # Explicit sizes guarantee exactly two chunks: with an odd channel count,
          # ``split(channels // 2)`` would produce three and the unpack would fail.
          # For even counts this is identical to the equal-halves split.
          a, b = x.split([self.channels - recurrent, recurrent], dim=-3)
          b, r = self.gru(b, r)
          x = torch.cat([a, b], dim=-3)
          return x, r
  class UpsamplingBlock(nn.Module):
      """Decoder stage: 2x upsample, fuse skip and source features, then a partial ConvGRU.

      ``x`` is bilinearly upsampled by 2 and cropped to the spatial size of ``s``.
      It is then concatenated with ``f`` and ``s`` along the channel dim and
      passed through conv-BN-ReLU to ``out_channels`` channels. The trailing
      ``out_channels // 2`` channels go through a ConvGRU; the rest pass through
      unchanged.

      Args:
          in_channels: channels of ``x`` (the previous stage's output).
          skip_channels: channels of the skip features ``f``.
          src_channels: channels of the pooled source ``s``.
          out_channels: channels of the returned features.
      """

      def __init__(self, in_channels, skip_channels, src_channels, out_channels):
          super().__init__()
          self.out_channels = out_channels
          self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
          self.conv = nn.Sequential(
              nn.Conv2d(in_channels + skip_channels + src_channels, out_channels, 3, 1, 1, bias=False),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(True),
          )
          self.gru = ConvGRU(out_channels // 2)

      def forward_single_frame(self, x, f, s, r: Optional[Tensor]):
          x = self.upsample(x)
          # Optimized for CoreML export: crop only when the sizes actually differ.
          if x.size(2) != s.size(2):
              x = x[:, :, :s.size(2), :]
          if x.size(3) != s.size(3):
              x = x[:, :, :, :s.size(3)]
          x = torch.cat([x, f, s], dim=1)
          x = self.conv(x)
          recurrent = self.out_channels // 2
          # Explicit sizes so an odd out_channels still splits into exactly two
          # parts (an int split size would produce three chunks and fail to unpack).
          a, b = x.split([self.out_channels - recurrent, recurrent], dim=1)
          b, r = self.gru(b, r)
          x = torch.cat([a, b], dim=1)
          return x, r

      def forward_time_series(self, x, f, s, r: Optional[Tensor]):
          B, T, _, H, W = s.shape
          # Fold time into the batch for the 2-D ops; the GRU then runs on (B, T, ...).
          x = x.flatten(0, 1)
          f = f.flatten(0, 1)
          s = s.flatten(0, 1)
          x = self.upsample(x)
          x = x[:, :, :H, :W]
          x = torch.cat([x, f, s], dim=1)
          x = self.conv(x)
          x = x.unflatten(0, (B, T))
          recurrent = self.out_channels // 2
          a, b = x.split([self.out_channels - recurrent, recurrent], dim=2)
          b, r = self.gru(b, r)
          x = torch.cat([a, b], dim=2)
          return x, r

      def forward(self, x, f, s, r: Optional[Tensor]):
          """Return ``(features, updated_state)``; dispatches on ``x.ndim`` (5 = time series)."""
          if x.ndim == 5:
              return self.forward_time_series(x, f, s, r)
          else:
              return self.forward_single_frame(x, f, s, r)
  class OutputBlock(nn.Module):
      """Final decoder stage: 2x upsample, fuse the source, then two conv-BN-ReLU layers.

      Has no recurrent state. Works on 4-D inputs or 5-D ``(B, T, C, H, W)``
      time series, chosen by ``x.ndim``.

      Args:
          in_channels: channels of ``x``.
          src_channels: channels of ``s``.
          out_channels: channels of the returned features.
      """

      def __init__(self, in_channels, src_channels, out_channels):
          super().__init__()
          self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
          self.conv = nn.Sequential(
              nn.Conv2d(in_channels + src_channels, out_channels, 3, 1, 1, bias=False),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(True),
              nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(True),
          )

      def forward_single_frame(self, x, s):
          up = self.upsample(x)
          # Crop the upsampled map to s's spatial size, only on the axes that differ.
          if up.size(2) != s.size(2):
              up = up[:, :, :s.size(2), :]
          if up.size(3) != s.size(3):
              up = up[:, :, :, :s.size(3)]
          return self.conv(torch.cat([up, s], dim=1))

      def forward_time_series(self, x, s):
          B, T, _, H, W = s.shape
          up = self.upsample(x.flatten(0, 1))[:, :, :H, :W]
          fused = torch.cat([up, s.flatten(0, 1)], dim=1)
          return self.conv(fused).unflatten(0, (B, T))

      def forward(self, x, s):
          if x.ndim != 5:
              return self.forward_single_frame(x, s)
          return self.forward_time_series(x, s)
  class ConvGRU(nn.Module):
      """Convolutional GRU cell, optionally unrolled over a time axis.

      Both gate convolutions read the channel-concatenation of input and
      hidden state, so ``x`` and ``h`` must have ``channels`` channels each
      and matching spatial size.

      Args:
          channels: channel count of the input and of the hidden state.
          kernel_size: kernel size of both gate convolutions.
          padding: padding of both gate convolutions.
      """

      def __init__(self,
                   channels: int,
                   kernel_size: int = 3,
                   padding: int = 1):
          super().__init__()
          self.channels = channels
          # Produces the reset (r) and update (z) gates together; split in forward.
          self.ih = nn.Sequential(
              nn.Conv2d(channels * 2, channels * 2, kernel_size, padding=padding),
              nn.Sigmoid()
          )
          # Produces the candidate hidden state.
          self.hh = nn.Sequential(
              nn.Conv2d(channels * 2, channels, kernel_size, padding=padding),
              nn.Tanh()
          )

      def forward_single_frame(self, x, h):
          gates = self.ih(torch.cat([x, h], dim=1))
          r, z = gates.split(self.channels, dim=1)
          candidate = self.hh(torch.cat([x, r * h], dim=1))
          h = (1 - z) * h + z * candidate
          # Output and new state are the same tensor.
          return h, h

      def forward_time_series(self, x, h):
          outputs = []
          # The hidden state is threaded frame to frame along dim 1.
          for frame in x.unbind(dim=1):
              out, h = self.forward_single_frame(frame, h)
              outputs.append(out)
          return torch.stack(outputs, dim=1), h

      def forward(self, x, h: Optional[Tensor]):
          if h is None:
              # Zero initial state: batch from dim 0, (C, H, W) from the last three dims,
              # so this works for both 4-D and 5-D inputs.
              h = torch.zeros((x.size(0), x.size(-3), x.size(-2), x.size(-1)),
                              device=x.device, dtype=x.dtype)
          if x.ndim != 5:
              return self.forward_single_frame(x, h)
          return self.forward_time_series(x, h)
  class Projection(nn.Module):
      """Per-frame 1x1 convolution mapping ``in_channels`` to ``out_channels``.

      Accepts 4-D inputs directly. For 5-D inputs, the first two dims are
      folded into one batch for the convolution and then restored.
      """

      def __init__(self, in_channels, out_channels):
          super().__init__()
          self.conv = nn.Conv2d(in_channels, out_channels, 1)

      def forward_single_frame(self, x):
          return self.conv(x)

      def forward_time_series(self, x):
          B, T = x.shape[:2]
          merged = x.flatten(0, 1)
          projected = self.conv(merged)
          return projected.unflatten(0, (B, T))

      def forward(self, x):
          if x.ndim != 5:
              return self.forward_single_frame(x)
          return self.forward_time_series(x)