# Source code for spacetimeformer.spacetimeformer_model.spacetimeformer_model
from typing import Tuple
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torchmetrics
import spacetimeformer as stf
class Spacetimeformer_Forecaster(stf.Forecaster):
    """Forecaster that wraps ``stf.spacetimeformer_model.nn.Spacetimeformer``.

    The training objective is ``forecast_loss + class_loss_imp * class_loss``
    (see ``step``). The classification term is only computed when
    ``embed_method == "spatio-temporal"`` and ``class_loss_imp > 0``.
    Optimisation uses AdamW with ``stf.lr_scheduler.WarmupReduceLROnPlateau``
    monitored on ``"val/forecast_loss"`` (see ``configure_optimizers``).
    """

    def __init__(
        self,
        d_y: int = 1,
        d_x: int = 4,
        start_token_len: int = 64,
        attn_factor: int = 5,
        d_model: int = 512,
        n_heads: int = 8,
        e_layers: int = 2,
        d_layers: int = 2,
        d_ff: int = 2048,
        dropout_emb: float = 0.05,
        dropout_token: float = 0.05,
        dropout_qkv: float = 0.05,
        dropout_ff: float = 0.05,
        dropout_attn_out: float = 0.05,
        global_self_attn: str = "performer",
        local_self_attn: str = "none",
        global_cross_attn: str = "performer",
        local_cross_attn: str = "none",
        performer_kernel: str = "relu",
        embed_method: str = "spatio-temporal",
        performer_relu: bool = True,
        performer_redraw_interval: int = 1000,
        activation: str = "gelu",
        post_norm: bool = False,
        norm: str = "layer",
        init_lr: float = 1e-10,
        base_lr: float = 3e-4,
        warmup_steps: float = 0,
        decay_factor: float = 0.25,
        initial_downsample_convs: int = 0,
        intermediate_downsample_convs: int = 0,
        l2_coeff: float = 0,
        loss: str = "nll",
        linear_window: int = 0,
        class_loss_imp: float = 0.1,
        time_emb_dim: int = 6,
        null_value: float = None,
        verbose=True,
    ):
        """Build the inner Spacetimeformer network and store training settings.

        Routing of the arguments:

        * ``l2_coeff``, ``loss`` and ``linear_window`` go to the
          ``stf.Forecaster`` base class.
        * Architecture arguments are forwarded to
          ``stf.spacetimeformer_model.nn.Spacetimeformer``. Note that
          ``performer_kernel`` is forwarded as ``performer_attn_kernel``.

        Args:
            start_token_len: Number of trailing context steps prepended to
                the decoder input (see ``forward_model_pass``).
            init_lr, base_lr, warmup_steps, decay_factor: Passed to the
                warmup/plateau LR scheduler. ``base_lr`` is also AdamW's lr.
            class_loss_imp: Weight of the classification loss in ``loss``.
            embed_method: Stored on the model. The classification loss is
                only used when this is ``"spatio-temporal"``.
            null_value: Forwarded to the inner network and to
                ``self.set_null_value``.
            performer_relu: Accepted for backward compatibility. Unused here.
            verbose: Controls the summary printed below. It is also forwarded
                as the inner network's ``verbose`` argument.
        """
        super().__init__(l2_coeff=l2_coeff, loss=loss, linear_window=linear_window)
        self.spacetimeformer = stf.spacetimeformer_model.nn.Spacetimeformer(
            d_y=d_y,
            d_x=d_x,
            start_token_len=start_token_len,
            attn_factor=attn_factor,
            d_model=d_model,
            n_heads=n_heads,
            e_layers=e_layers,
            d_layers=d_layers,
            d_ff=d_ff,
            initial_downsample_convs=initial_downsample_convs,
            intermediate_downsample_convs=intermediate_downsample_convs,
            dropout_emb=dropout_emb,
            dropout_attn_out=dropout_attn_out,
            dropout_qkv=dropout_qkv,
            dropout_ff=dropout_ff,
            dropout_token=dropout_token,
            global_self_attn=global_self_attn,
            local_self_attn=local_self_attn,
            global_cross_attn=global_cross_attn,
            local_cross_attn=local_cross_attn,
            activation=activation,
            post_norm=post_norm,
            device=self.device,
            norm=norm,
            embed_method=embed_method,
            performer_attn_kernel=performer_kernel,
            performer_redraw_interval=performer_redraw_interval,
            time_emb_dim=time_emb_dim,
            # Previously hard-coded to True, which ignored the caller's flag.
            verbose=verbose,
            null_value=null_value,
        )
        self.start_token_len = start_token_len
        self.init_lr = init_lr
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.decay_factor = decay_factor
        self.embed_method = embed_method
        self.class_loss_imp = class_loss_imp
        self.set_null_value(null_value)

        if verbose:
            print(" *** Spacetimeformer Summary: *** ")
            print(f"\tModel Dim: {d_model}")
            print(f"\tFF Dim: {d_ff}")
            print(f"\tEnc Layers: {e_layers}")
            print(f"\tDec Layers: {d_layers}")
            print(f"\tEmbed Dropout: {dropout_emb}")
            print(f"\tToken Dropout: {dropout_token}")
            print(f"\tFF Dropout: {dropout_ff}")
            print(f"\tAttn Out Dropout: {dropout_attn_out}")
            print(f"\tQKV Dropout: {dropout_qkv}")
            print(f"\tL2 Coeff: {l2_coeff}")
            print(f"\tWarmup Steps: {warmup_steps}")
            print(f"\tNormalization Scheme: {norm}")
            print(" *** *** ")

    @property
    def train_step_forward_kwargs(self):
        """Extra keyword arguments for the forward call during training."""
        return {"output_attn": False}

    @property
    def eval_step_forward_kwargs(self):
        """Extra keyword arguments for the forward call during evaluation."""
        return {"output_attn": False}

    def step(self, batch: Tuple[torch.Tensor, ...], train: bool):
        """Run one train/eval step and return a dict of statistics.

        Args:
            batch: ``(x_c, y_c, x_t, y_t)``, as unpacked by ``compute_loss``.
            train: If true, use the training forward kwargs and pass
                ``self.time_masked_idx`` as the time mask. Otherwise use the
                eval kwargs and no time mask.

        Returns:
            The dict from ``self._compute_stats`` on the masked predictions
            and targets, extended with ``forecast_loss``, ``class_loss``,
            ``loss`` (``forecast_loss + class_loss_imp * class_loss``) and
            ``acc``.
        """
        kwargs = (
            self.train_step_forward_kwargs if train else self.eval_step_forward_kwargs
        )
        time_mask = self.time_masked_idx if train else None
        forecast_loss, class_loss, acc, output, mask = self.compute_loss(
            batch=batch,
            time_mask=time_mask,
            forward_kwargs=kwargs,
        )
        *_, y_t = batch
        # Masked-out positions are zeroed in both predictions and targets
        # before computing metrics.
        stats = self._compute_stats(mask * output, mask * y_t)
        stats["forecast_loss"] = forecast_loss
        stats["class_loss"] = class_loss
        stats["loss"] = forecast_loss + self.class_loss_imp * class_loss
        stats["acc"] = acc
        return stats

    def classification_loss(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cross_entropy, accuracy)`` for the classification head.

        The number of classes is inferred as ``labels.max() + 1``. The logits
        are flattened to ``(-1, num_classes)`` to line up with the flattened
        labels.

        NOTE(review): inferring the class count from the labels assumes the
        highest class id appears in every batch, and that it equals the
        logits' class dimension. Confirm against the inner network's head.
        """
        labels = labels.reshape(-1).to(logits.device)
        num_classes = int(labels.max()) + 1
        logits = logits.reshape(-1, num_classes)
        class_loss = F.cross_entropy(logits, labels)
        # NOTE(review): newer torchmetrics releases require `task=` /
        # `num_classes=` here. This call matches the older signature.
        acc = torchmetrics.functional.accuracy(
            torch.softmax(logits, dim=1),
            labels,
        )
        return class_loss, acc

    def compute_loss(self, batch, time_mask=None, forward_kwargs=None):
        """Run the model on ``batch`` and compute the forecasting and class losses.

        Args:
            batch: ``(x_c, y_c, x_t, y_t)``.
            time_mask: Forwarded to ``self.forecasting_loss``.
            forward_kwargs: Extra keyword arguments for ``self(...)``.
                ``None`` means no extra arguments.

        Returns:
            A tuple ``(forecast_loss, class_loss, acc, outputs.mean, mask)``.
            When the classification loss is disabled, ``class_loss`` is
            ``0.0`` and ``acc`` is ``-1.0``.
        """
        if forward_kwargs is None:
            forward_kwargs = {}
        x_c, y_c, x_t, y_t = batch
        outputs, (logits, labels) = self(x_c, y_c, x_t, y_t, **forward_kwargs)
        forecast_loss, mask = self.forecasting_loss(
            outputs=outputs, y_t=y_t, time_mask=time_mask
        )
        if self.embed_method == "spatio-temporal" and self.class_loss_imp > 0:
            class_loss, acc = self.classification_loss(logits=logits, labels=labels)
        else:
            class_loss, acc = 0.0, -1.0
        # `outputs` exposes a `.mean` (presumably a predictive distribution;
        # confirm against the inner network).
        return forecast_loss, class_loss, acc, outputs.mean, mask

    def forward_model_pass(self, x_c, y_c, x_t, y_t, output_attn=False):
        """Assemble encoder/decoder inputs and call the inner Spacetimeformer.

        Encoder inputs are the context ``y_c`` with time features ``x_c``.

        The decoder receives the last ``start_token_len`` context values,
        followed by zeros for every target step. The true ``y_t`` values are
        never fed in; only their shape is used. The decoder time features
        are the matching context time features concatenated with ``x_t``.

        Returns:
            ``(output, (logits, labels))``, plus ``attn`` as a third element
            when ``output_attn`` is true.
        """
        if len(y_c.shape) == 2:
            # Add a trailing feature axis to 2-D targets.
            y_c = y_c.unsqueeze(-1)
            y_t = y_t.unsqueeze(-1)
        batch_x = y_c
        batch_x_mark = x_c
        if self.start_token_len > 0:
            batch_y = torch.cat((y_c[:, -self.start_token_len :, :], y_t), dim=1)
            batch_y_mark = torch.cat((x_c[:, -self.start_token_len :, :], x_t), dim=1)
        else:
            batch_y = y_t
            batch_y_mark = x_t
        # Allocate the placeholder directly on the model's device instead of
        # building it on the host and copying it over.
        placeholder = torch.zeros(
            (batch_y.shape[0], y_t.shape[1], batch_y.shape[-1]),
            device=self.device,
        )
        dec_inp = torch.cat(
            [batch_y[:, : self.start_token_len, :], placeholder],
            dim=1,
        ).float()
        output, (logits, labels), attn = self.spacetimeformer(
            x_enc=batch_x,
            x_mark_enc=batch_x_mark,
            x_dec=dec_inp,
            x_mark_dec=batch_y_mark,
            output_attention=output_attn,
        )
        if output_attn:
            return output, (logits, labels), attn
        return output, (logits, labels)

    def configure_optimizers(self):
        """Create AdamW and a warmup + reduce-on-plateau LR scheduler.

        The scheduler steps once per epoch, monitoring ``"val/forecast_loss"``
        with ``patience=2`` and ``factor=self.decay_factor``.
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.base_lr,
            weight_decay=self.l2_coeff,
        )
        scheduler = stf.lr_scheduler.WarmupReduceLROnPlateau(
            optimizer,
            init_lr=self.init_lr,
            peak_lr=self.base_lr,
            warmup_steps=self.warmup_steps,
            patience=2,
            factor=self.decay_factor,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                "monitor": "val/forecast_loss",
                "reduce_on_plateau": True,
            },
        }

    @classmethod
    def add_cli(cls, parser):
        """Register this model's CLI options on ``parser`` after the base class's.

        NOTE(review): several CLI defaults intentionally differ from the
        ``__init__`` defaults (e.g. ``--d_model`` 256 vs 512, ``--norm``
        "batch" vs "layer"). ``--enc_layers``/``--dec_layers`` do not match
        the ``e_layers``/``d_layers`` keyword names, so presumably the caller
        maps them. Verify.
        """
        super().add_cli(parser)
        parser.add_argument(
            "--start_token_len",
            type=int,
            required=True,
            help="Length of decoder start token. Adds this many of the final context points to the start of the target sequence.",
        )
        parser.add_argument(
            "--d_model", type=int, default=256, help="Transformer embedding dimension."
        )
        parser.add_argument(
            "--n_heads", type=int, default=8, help="Number of self-attention heads."
        )
        parser.add_argument(
            "--enc_layers", type=int, default=4, help="Transformer encoder layers."
        )
        parser.add_argument(
            "--dec_layers", type=int, default=3, help="Transformer decoder layers."
        )
        parser.add_argument(
            "--d_ff",
            type=int,
            default=1024,
            help="Dimension of Transformer up-scaling MLP layer. (often 4 * d_model)",
        )
        parser.add_argument(
            "--attn_factor",
            type=int,
            default=5,
            help="ProbSparse attention factor. N/A to other attn mechanisms.",
        )
        parser.add_argument(
            "--dropout_emb",
            type=float,
            default=0.2,
            help="Embedding dropout rate. Drop out elements of the embedding vectors during training.",
        )
        parser.add_argument(
            "--dropout_token",
            type=float,
            default=0.0,
            help="Token dropout rate. Drop out entire input tokens during training.",
        )
        parser.add_argument(
            "--dropout_attn_out",
            type=float,
            default=0.0,
            help="Attention dropout rate. Dropout elements of the attention matrix. Only applicable to attn mechanisms that explicitly compute the attn matrix (e.g. Full).",
        )
        parser.add_argument(
            "--dropout_qkv",
            type=float,
            default=0.0,
            help="Query, Key and Value dropout rate. Dropout elements of these attention vectors during training.",
        )
        parser.add_argument(
            "--dropout_ff",
            type=float,
            default=0.3,
            help="Standard dropout applied to activations of FF networks in the Transformer.",
        )
        parser.add_argument(
            "--global_self_attn",
            type=str,
            default="performer",
            choices=[
                "full",
                "prob",
                "performer",
                "nystromformer",
                "benchmark",
                "none",
            ],
            help="Attention mechanism type.",
        )
        parser.add_argument(
            "--global_cross_attn",
            type=str,
            default="performer",
            choices=[
                "full",
                "performer",
                "benchmark",
                "none",
            ],
            help="Attention mechanism type.",
        )
        parser.add_argument(
            "--local_self_attn",
            type=str,
            default="performer",
            choices=[
                "full",
                "prob",
                "performer",
                "benchmark",
                "none",
            ],
            help="Attention mechanism type.",
        )
        parser.add_argument(
            "--local_cross_attn",
            type=str,
            default="performer",
            choices=[
                "full",
                "performer",
                "benchmark",
                "none",
            ],
            help="Attention mechanism type.",
        )
        parser.add_argument(
            "--activation",
            type=str,
            default="gelu",
            choices=["relu", "gelu"],
            help="Activation function for Transformer encoder and decoder layers.",
        )
        parser.add_argument(
            "--post_norm",
            action="store_true",
            help="Enable post-norm architecture for Transformers. See https://arxiv.org/abs/2002.04745.",
        )
        parser.add_argument(
            "--norm",
            type=str,
            choices=["layer", "batch", "scale", "power", "none"],
            default="batch",
        )
        parser.add_argument(
            "--init_lr", type=float, default=1e-10, help="Initial learning rate."
        )
        parser.add_argument(
            "--base_lr",
            type=float,
            default=5e-4,
            help="Base/peak LR. The LR is annealed to this value from --init_lr over --warmup_steps training steps.",
        )
        parser.add_argument(
            "--warmup_steps", type=int, default=0, help="LR anneal steps."
        )
        parser.add_argument(
            "--decay_factor",
            type=float,
            default=0.25,
            help="Factor to reduce LR on plateau (after warmup period is over).",
        )
        parser.add_argument(
            "--initial_downsample_convs",
            type=int,
            default=0,
            help="Add downsampling Conv1Ds to the encoder embedding layer to reduce context sequence length.",
        )
        parser.add_argument(
            "--class_loss_imp",
            type=float,
            default=0.1,
            help="Coefficient for node classification loss function. Set to 0 to disable this feature. Does not significantly impact forecasting results due to detached gradient.",
        )
        parser.add_argument(
            "--intermediate_downsample_convs",
            type=int,
            default=0,
            help="Add downsampling Conv1Ds between encoder layers.",
        )
        parser.add_argument(
            "--time_emb_dim",
            type=int,
            default=12,
            help="Time embedding dimension. Embed *each dimension of x* with this many learned periodic values.",
        )
        parser.add_argument(
            "--performer_kernel",
            type=str,
            default="relu",
            choices=["softmax", "relu"],
            help="Performer attention kernel. See Performer paper for details.",
        )
        parser.add_argument(
            "--performer_redraw_interval",
            type=int,
            default=125,
            help="Training steps between resampling orthogonal random features for FAVOR+ attention",
        )
        parser.add_argument(
            "--embed_method",
            type=str,
            choices=["spatio-temporal", "temporal"],
            default="spatio-temporal",
            help="Embedding method. spatio-temporal enables long-sequence spatio-temporal transformer mode while temporal recovers default architecture.",
        )