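"""ConvLSTM building blocks: a ConvGate module that fuses the four LSTM gate
convolutions into a single pass, a single-step ConvLSTMCell, and a
sequence-level ConvLSTM wrapper with optional bidirectional processing."""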
from typing import Any, Optional, Tuple, Union

import torch
from torch import Tensor, nn, sigmoid, tanh


class ConvGate(nn.Module):
    """Computes all four ConvLSTM gate pre-activations in a single pass.

    The input and hidden-state convolutions each produce 4 * hidden_channels
    feature maps, which are later split into the input, forget, cell, and
    output gates.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        kernel_size: Union[Tuple[int, int], int],
        padding: Union[Tuple[int, int], int],
        stride: Union[Tuple[int, int], int],
        bias: bool,
    ):
        super().__init__()
        self.conv_x = nn.Conv2d(
            in_channels=in_channels,
            out_channels=hidden_channels * 4,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=bias,
        )
        self.conv_h = nn.Conv2d(
            in_channels=hidden_channels,
            out_channels=hidden_channels * 4,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=bias,
        )
        self.bn2d = nn.BatchNorm2d(hidden_channels * 4)

    def forward(self, x: Tensor, hidden_state: Tensor) -> Tensor:
        # Sum the input and recurrent convolutions, then batch-normalize the
        # stacked gate pre-activations.
        gated = self.conv_x(x) + self.conv_h(hidden_state)
        return self.bn2d(gated)


class ConvLSTMCell(nn.Module):
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        kernel_size: Union[Tuple[int, int], int],
        padding: Union[Tuple[int, int], int],
        stride: Union[Tuple[int, int], int],
        bias: bool,
    ):
        super().__init__()
        # A single fused-gate convolution block; no ModuleList is needed for
        # one gate module.
        self.gate = ConvGate(
            in_channels, hidden_channels, kernel_size, padding, stride, bias
        )

    def forward(
        self, x: Tensor, hidden_state: Tensor, cell_state: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Split the fused pre-activations into input, forget, cell, and
        # output gates.
        gated = self.gate(x, hidden_state)
        i_gated, f_gated, c_gated, o_gated = gated.chunk(4, dim=1)

        i_gated = sigmoid(i_gated)
        f_gated = sigmoid(f_gated)
        o_gated = sigmoid(o_gated)

        # Standard LSTM state update with convolutional gates.
        cell_state = f_gated.mul(cell_state) + i_gated.mul(tanh(c_gated))
        hidden_state = o_gated.mul(tanh(cell_state))

        return hidden_state, cell_state
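
# Reference equations implemented by ConvLSTMCell.forward, in the
# peephole-free ConvLSTM formulation (Shi et al., 2015), with "*" denoting
# convolution and "." element-wise multiplication:
#   i_t = sigmoid(W_xi * x_t + W_hi * h_{t-1})
#   f_t = sigmoid(W_xf * x_t + W_hf * h_{t-1})
#   g_t = tanh(W_xg * x_t + W_hg * h_{t-1})
#   o_t = sigmoid(W_xo * x_t + W_ho * h_{t-1})
#   c_t = f_t . c_{t-1} + i_t . g_t
#   h_t = o_t . tanh(c_t)
# Unlike the paper, this implementation also batch-normalizes the fused gate
# pre-activations (see ConvGate).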


class ConvLSTM(nn.Module):
    """Single-layer ConvLSTM with optional bidirectional processing."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        kernel_size: Union[Tuple[int, int], int],
        padding: Union[Tuple[int, int], int],
        stride: Union[Tuple[int, int], int],
        bias: bool,
        batch_first: bool,
        bidirectional: bool,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.bidirectional = bidirectional
        self.batch_first = batch_first

        # One cell for the forward direction; a second, independent cell is
        # appended for the backward direction when bidirectional is set.
        self.conv_lstm_cells = nn.ModuleList(
            [
                ConvLSTMCell(
                    in_channels, hidden_channels, kernel_size, padding, stride, bias
                )
            ]
        )

        if self.bidirectional:
            self.conv_lstm_cells.append(
                ConvLSTMCell(
                    in_channels, hidden_channels, kernel_size, padding, stride, bias
                )
            )

        # Populated from the input tensor on each forward pass.
        self.batch_size = None
        self.seq_len = None
        self.height = None
        self.width = None

    def forward(
        self, x: Tensor, state: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        x = self._check_shape(x)
        hidden_state, cell_state, backward_hidden_state, backward_cell_state = (
            self.init_state(x, state)
        )

        output, hidden_state, cell_state = self._forward(
            self.conv_lstm_cells[0], x, hidden_state, cell_state
        )

        if self.bidirectional:
            # Run the second cell over the time-reversed sequence, then flip
            # its outputs back so output[:, t] lines up with input[:, t].
            x = torch.flip(x, [1])
            backward_output, backward_hidden_state, backward_cell_state = self._forward(
                self.conv_lstm_cells[1], x, backward_hidden_state, backward_cell_state
            )
            backward_output = torch.flip(backward_output, [1])

            # Concatenate directions: outputs along the channel axis, states
            # along the last axis (matching init_state's chunking).
            output = torch.cat([output, backward_output], dim=-3)
            hidden_state = torch.cat([hidden_state, backward_hidden_state], dim=-1)
            cell_state = torch.cat([cell_state, backward_cell_state], dim=-1)
        return output, (hidden_state, cell_state)

    def _forward(self, lstm_cell, x, hidden_state, cell_state):
        # Unroll the cell over the time dimension. The per-step hidden states
        # are stacked without detaching so gradients can flow through the
        # whole sequence.
        outputs = []
        for time_step in range(self.seq_len):
            x_t = x[:, time_step, :, :, :]
            hidden_state, cell_state = lstm_cell(x_t, hidden_state, cell_state)
            outputs.append(hidden_state)
        output = torch.stack(outputs, dim=1)
        return output, hidden_state, cell_state

    def _check_shape(self, x: Tensor) -> Tensor:
        if self.batch_first:
            batch_size, self.seq_len = x.shape[0], x.shape[1]
        else:
            # Sequence-first input: move the time axis behind the batch axis.
            batch_size, self.seq_len = x.shape[1], x.shape[0]
            x = torch.swapaxes(x, 0, 1)

        self.height = x.shape[-2]
        self.width = x.shape[-1]

        dim = len(x.shape)

        if dim == 4:
            # (batch, time, height, width): insert a singleton channel axis
            # to get (batch, time, 1, height, width).
            x = x.unsqueeze(dim=2)
            x = x.contiguous()
        elif dim <= 3:
            raise ValueError(
                f"Expected a 4D or 5D input tensor, got a {dim}D tensor instead."
            )

        return x

    def init_state(
        self, x: Tensor, state: Optional[Tuple[Tensor, Tensor]]
    ) -> Tuple[Union[Tensor, Any], Union[Tensor, Any], Optional[Any], Optional[Any]]:
        backward_hidden_state, backward_cell_state = None, None

        if state is None:
            self.batch_size = x.shape[0]
            hidden_state, cell_state = self._init_state(x.dtype, x.device)

            if self.bidirectional:
                backward_hidden_state, backward_cell_state = self._init_state(
                    x.dtype, x.device
                )
        else:
            if self.bidirectional:
                # A bidirectional state is stored concatenated along the last
                # axis; split it back into forward and backward halves.
                hidden_state, backward_hidden_state = state[0].chunk(2, dim=-1)
                cell_state, backward_cell_state = state[1].chunk(2, dim=-1)
            else:
                hidden_state, cell_state = state

        return hidden_state, cell_state, backward_hidden_state, backward_cell_state

    def _init_state(
        self, dtype: torch.dtype, device: torch.device
    ) -> Tuple[Tensor, Tensor]:
        # Zero states with a singleton batch axis; they broadcast against the
        # input's batch dimension inside ConvLSTMCell. Plain tensors are used
        # rather than registered buffers, since the states are rebuilt from
        # the input's dtype/device on every forward pass.
        hidden_state = torch.zeros(
            (1, self.hidden_channels, self.height, self.width),
            dtype=dtype,
            device=device,
        )
        cell_state = torch.zeros(
            (1, self.hidden_channels, self.height, self.width),
            dtype=dtype,
            device=device,
        )
        return hidden_state, cell_state
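

if __name__ == "__main__":
    # Minimal smoke test; the hyperparameters and shapes below are
    # illustrative values, not part of the module itself.
    conv_lstm = ConvLSTM(
        in_channels=3,
        hidden_channels=8,
        kernel_size=3,
        padding=1,
        stride=1,
        bias=True,
        batch_first=True,
        bidirectional=True,
    )
    # (batch, time, channels, height, width)
    dummy = torch.randn(2, 5, 3, 16, 16)
    out, (h, c) = conv_lstm(dummy)
    # Bidirectional mode doubles the output channels and concatenates the
    # final states along the last axis:
    print(out.shape)  # torch.Size([2, 5, 16, 16, 16])
    print(h.shape)    # torch.Size([2, 8, 16, 32])
    print(c.shape)    # torch.Size([2, 8, 16, 32])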