import random
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
import spacetimeformer as stf
class LSTM_Encoder(nn.Module):
    """Stacked LSTM that summarises a context sequence as its final state.

    Args:
        input_dim: Number of features per timestep fed to the LSTM.
        hidden_dim: Size of the LSTM hidden/cell state.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout probability handed to ``nn.LSTM``. ``nn.LSTM`` only
            applies dropout between stacked layers, so the value is forwarded
            only when ``n_layers > 1``.
    """

    def __init__(
        self,
        input_dim: int = 1,
        hidden_dim: int = 256,
        n_layers: int = 2,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        # With a single layer a non-zero dropout has no effect and only makes
        # nn.LSTM emit a UserWarning, so pass 0.0 in that case.
        inter_layer_dropout = dropout if n_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            n_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
        )

    def forward(self, x_context: torch.Tensor):
        """Run the LSTM over the context and return its final state.

        Args:
            x_context: Batch-first input sequence (the LSTM is built with
                ``batch_first=True``).

        Returns:
            Tuple ``(hidden, cell)``: the final hidden and cell states
            returned by ``nn.LSTM``. The per-timestep outputs are discarded.
        """
        _, (hidden, cell) = self.lstm(x_context)
        return hidden, cell
class LSTM_Decoder(nn.Module):
    """Stacked LSTM followed by a linear head that maps states to predictions.

    Args:
        output_dim: Number of features produced by the linear head.
        input_dim: Number of features per timestep fed to the LSTM.
        hidden_dim: Size of the LSTM hidden/cell state.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout probability handed to ``nn.LSTM``. ``nn.LSTM`` only
            applies dropout between stacked layers, so the value is forwarded
            only when ``n_layers > 1``.
    """

    def __init__(
        self,
        output_dim: int = 1,
        input_dim: int = 1,
        hidden_dim: int = 256,
        n_layers: int = 2,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        # With a single layer a non-zero dropout has no effect and only makes
        # nn.LSTM emit a UserWarning, so pass 0.0 in that case.
        inter_layer_dropout = dropout if n_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            n_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x_t, hidden, cell):
        """Advance the decoder from the given recurrent state.

        Args:
            x_t: Batch-first decoder input (the LSTM is built with
                ``batch_first=True``).
            hidden: Initial hidden state for the LSTM.
            cell: Initial cell state for the LSTM.

        Returns:
            Tuple ``(y_t1, hidden, cell)``: the linear head applied to the
            LSTM outputs, and the updated hidden and cell states.
        """
        output, (hidden, cell) = self.lstm(x_t, (hidden, cell))
        y_t1 = self.fc(output)
        return y_t1, hidden, cell
class LSTM_Seq2Seq(nn.Module):
    """Encoder-decoder LSTM that predicts the target sequence step by step.

    The context (time features concatenated with values) is encoded into an
    initial recurrent state; the decoder is then unrolled once per target
    timestep, each step being fed the target time features together with
    either the ground-truth value (teacher forcing) or its own previous
    prediction.

    Args:
        t2v: Optional module applied to the time features of both context and
            target before they are merged with the values. ``None`` skips it.
        encoder: Module returning ``(hidden, cell)`` for the merged context.
        decoder: Module called as ``decoder(step_input, hidden, cell)`` and
            returning ``(prediction, hidden, cell)``.
    """

    def __init__(
        self,
        t2v: "stf.Time2Vec",
        encoder: "LSTM_Encoder",
        decoder: "LSTM_Decoder",
    ):
        super().__init__()
        self.t2v = t2v
        self.encoder = encoder
        self.decoder = decoder

    def _merge(self, x, y):
        """Concatenate ``x`` and ``y`` along the last (feature) dimension."""
        return torch.cat((x, y), dim=-1)

    def forward(
        self,
        x_context,
        y_context,
        x_target,
        y_target,
        teacher_forcing_prob,
    ):
        """Predict ``y_target.shape[1]`` steps following the context.

        Args:
            x_context: Time features of the context, indexed ``[batch, time, ...]``.
            y_context: Values of the context, indexed ``[batch, time, ...]``.
            x_target: Time features of the target window.
            y_target: Ground-truth target values; its first three dimensions
                give ``(batch, pred_len, y_dim)`` of the result. Its values
                are only read when teacher forcing is drawn.
            teacher_forcing_prob: Probability, drawn independently at every
                decoding step (one draw for the whole batch), of feeding the
                ground-truth value instead of the model's own prediction.

        Returns:
            Tensor of shape ``(batch, pred_len, y_dim)`` holding the
            predictions, on the same device as ``y_target``.
        """
        if self.t2v is not None:
            x_context = self.t2v(x_context)
            x_target = self.t2v(x_target)

        pred_len = y_target.shape[1]
        batch_size = y_target.shape[0]
        y_dim = y_target.shape[2]

        # Allocate the result buffer directly on the target device instead of
        # building it on the CPU and copying it over on every forward pass.
        outputs = torch.full(
            (batch_size, pred_len, y_dim),
            -1.0,
            dtype=torch.get_default_dtype(),
            device=y_target.device,
        )

        merged_context = self._merge(x_context, y_context)
        hidden, cell = self.encoder(merged_context)

        # The first decoder input is the last observed context step.
        decoder_input = self._merge(x_context[:, -1], y_context[:, -1]).unsqueeze(1)

        for t in range(0, pred_len):
            output, hidden, cell = self.decoder(decoder_input, hidden, cell)
            outputs[:, t] = output.squeeze(1)

            # Teacher forcing: feed the true value with the given probability,
            # otherwise feed back the prediction just made.
            decoder_y = (
                y_target[:, t].unsqueeze(1)
                if random.random() < teacher_forcing_prob
                else output
            )
            decoder_input = self._merge(x_target[:, t].unsqueeze(1), decoder_y)

        return outputs
class LSTM_Forecaster(stf.Forecaster):
    """``stf.Forecaster`` wrapper around the LSTM encoder-decoder model.

    Builds a ``stf.Time2Vec`` time embedding, an ``LSTM_Encoder`` and an
    ``LSTM_Decoder``, and combines them in an ``LSTM_Seq2Seq`` stored as
    ``self.model``.

    Args:
        d_x: Number of time features; passed to ``Time2Vec`` as ``input_dim``.
        d_y: Number of target variables; used as the decoder output size.
        time_emb_dim: ``Time2Vec`` embedding size. When greater than zero it
            replaces ``d_x`` in the LSTM input size.
        n_layers: Number of stacked LSTM layers in encoder and decoder.
        hidden_dim: LSTM hidden size in encoder and decoder.
        dropout_p: Dropout probability for encoder and decoder.
        learning_rate: Forwarded to ``stf.Forecaster``.
        teacher_forcing_prob: Teacher-forcing probability used for training
            steps; evaluation steps always use ``0.0``.
        l2_coeff: Forwarded to ``stf.Forecaster``.
        loss: Forwarded to ``stf.Forecaster``.
        linear_window: Forwarded to ``stf.Forecaster``.
    """

    def __init__(
        self,
        d_x: int = 6,
        d_y: int = 1,
        time_emb_dim: int = 0,
        n_layers: int = 2,
        hidden_dim: int = 32,
        dropout_p: float = 0.2,
        # training
        learning_rate: float = 1e-3,
        teacher_forcing_prob: float = 0.5,
        l2_coeff: float = 0,
        loss: str = "mse",
        linear_window: int = 0,
    ):
        super().__init__(
            l2_coeff=l2_coeff,
            learning_rate=learning_rate,
            loss=loss,
            linear_window=linear_window,
        )
        self.t2v = stf.Time2Vec(input_dim=d_x, embed_dim=time_emb_dim)
        # Each LSTM step sees the time features (embedded when time_emb_dim > 0,
        # raw otherwise) concatenated with the d_y values.
        input_dim = (time_emb_dim if time_emb_dim > 0 else d_x) + d_y

        self.encoder = LSTM_Encoder(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            dropout=dropout_p,
        )
        self.decoder = LSTM_Decoder(
            output_dim=d_y,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            dropout=dropout_p,
        )
        self.model = LSTM_Seq2Seq(self.t2v, self.encoder, self.decoder).to(self.device)

        self.teacher_forcing_prob = teacher_forcing_prob

    @property
    def train_step_forward_kwargs(self):
        """Extra kwargs for ``forward_model_pass`` during training steps."""
        return {"force": self.teacher_forcing_prob}

    @property
    def eval_step_forward_kwargs(self):
        """Extra kwargs for ``forward_model_pass`` during evaluation steps."""
        return {"force": 0.0}

    def forward_model_pass(self, x_c, y_c, x_t, y_t, force=None):
        """Run the seq2seq model and return its predictions as a 1-tuple.

        Args:
            x_c: Context time features.
            y_c: Context values.
            x_t: Target time features.
            y_t: Target values.
            force: Teacher-forcing probability for this pass; required.

        Returns:
            ``(preds,)`` where ``preds`` is the output of ``self.model``.

        Raises:
            ValueError: If ``force`` is ``None``.
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if force is None:
            raise ValueError("`force` (teacher forcing probability) must be provided")
        # Call the module itself (not ``.forward``) so registered hooks run.
        preds = self.model(x_c, y_c, x_t, y_t, teacher_forcing_prob=force)
        return (preds,)

    @classmethod
    def add_cli(cls, parser):
        """Register the LSTM-specific command-line options on ``parser``.

        The parent class's options are added first via ``super().add_cli``.
        """
        super().add_cli(parser)
        # NOTE: the CLI defaults below (128 / 0.3 / 12) differ from the
        # constructor defaults (hidden_dim=32, dropout_p=0.2, time_emb_dim=0).
        parser.add_argument(
            "--hidden_dim",
            type=int,
            default=128,
            help="Hidden dimension for LSTM network.",
        )
        parser.add_argument(
            "--n_layers",
            type=int,
            default=2,
            help="Number of stacked LSTM layers",
        )
        parser.add_argument(
            "--dropout_p",
            type=float,
            default=0.3,
            help="Dropout fraction for LSTM.",
        )
        parser.add_argument(
            "--time_emb_dim",
            type=int,
            default=12,
            help="Embedding dimension for Time2Vec encoding. Set to zero to disable T2V.",
        )