from typing import List
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
import spacetimeformer as stf
try:
from torch_geometric_temporal.nn import MTGNN
except ImportError:
[docs] class MTGNN:
def __init__(self, *args, **kwargs):
raise ImportError(
"\t Missing `torch_geometric_temporal` package required to use MTGNN\n\
model. This is optional for all other model types and not installed\n\
with `pip install -r requirements.txt` because of CUDA versioning issues.\n\
Please see https://github.com/benedekrozemberczki/pytorch_geometric_temporal/blob/master/docs/source/notes/installation.rst\n\
for installation instructions."
)
[docs]class MTGNN_Forecaster(stf.Forecaster):
def __init__(
self,
d_y: int,
d_x: int,
context_points: int,
target_points: int,
use_gcn_layer: bool = True,
adaptive_adj_mat: bool = True,
gcn_depth: int = 2,
dropout_p: float = 0.2,
node_dim: int = 40,
dilation_exponential: int = 1,
conv_channels: int = 32,
subgraph_size: int = 8,
skip_channels: int = 64,
end_channels: int = 128,
residual_channels: int = 32,
layers: int = 3,
propalpha: float = 0.05,
tanhalpha: float = 3,
kernel_set: List[int] = [2, 3, 6, 7],
kernel_size: int = 7,
learning_rate: float = 1e-3,
l2_coeff: float = 0,
time_emb_dim: int = 0,
loss: str = "mae",
linear_window: int = 0,
):
super().__init__(
l2_coeff=l2_coeff,
learning_rate=learning_rate,
loss=loss,
linear_window=linear_window,
)
subgraph_size = min(subgraph_size, d_y)
self.learning_rate = learning_rate
self.time2vec = stf.Time2Vec(input_dim=d_x, embed_dim=time_emb_dim)
self.model = MTGNN(
gcn_true=use_gcn_layer,
build_adj=adaptive_adj_mat,
gcn_depth=gcn_depth,
num_nodes=d_y,
kernel_set=kernel_set,
kernel_size=kernel_size,
dropout=dropout_p,
subgraph_size=subgraph_size,
node_dim=node_dim,
conv_channels=conv_channels,
residual_channels=residual_channels,
skip_channels=skip_channels,
end_channels=end_channels,
seq_length=context_points,
in_dim=d_x + 1 if time_emb_dim == 0 else time_emb_dim + 1,
out_dim=target_points,
layers=layers,
propalpha=propalpha,
tanhalpha=tanhalpha,
dilation_exponential=dilation_exponential,
layer_norm_affline=True,
)
self.d_y = d_y
@property
def eval_step_forward_kwargs(self):
return {}
@property
def train_step_forward_kwargs(self):
return {}
[docs] def forward_model_pass(self, x_c, y_c, x_t, y_t):
x_c = self.time2vec(x_c)
pred_len = y_t.shape[-2]
output = torch.zeros_like(y_t).to(y_t.device)
# y_c = (batch, len, nodes) --> (batch, 1, nodes, len)
y_c = y_c.transpose(-1, 1).unsqueeze(1)
# x_c = (batch, len, d_x) --> (batch, d_x, nodes, len)
x_c = x_c.transpose(-1, 1).unsqueeze(-2).repeat(1, 1, self.d_y, 1)
ctxt = torch.cat((x_c, y_c), dim=1)
output = self.model.forward(ctxt).squeeze(-1)
return (output,)
[docs] @classmethod
def add_cli(self, parser):
super().add_cli(parser)
parser.add_argument("--gcn_depth", type=int, default=2)
parser.add_argument("--dropout_p", type=float, default=0.3)
parser.add_argument("--node_dim", type=int, default=40)
parser.add_argument("--dilation_exponential", type=int, default=1)
parser.add_argument("--conv_channels", type=int, default=32)
parser.add_argument("--subgraph_size", type=int, default=20)
parser.add_argument("--skip_channels", type=int, default=64)
parser.add_argument("--end_channels", type=int, default=128)
parser.add_argument("--residual_channels", type=int, default=32)
parser.add_argument("--layers", type=int, default=3)
parser.add_argument("--propalpha", type=float, default=0.05)
parser.add_argument("--tanhalpha", type=float, default=3.0)
parser.add_argument("--kernel_size", type=int, default=7)
parser.add_argument("--time_emb_dim", type=int, default=12)