# Source code for spacetimeformer.spacetimeformer_model.nn.model

from functools import partial
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as pyd

from ..utils.masking import TriangularCausalMask, ProbMask
from .encoder import Encoder, EncoderLayer, VariableDownsample, Normalization
from .decoder import Decoder, DecoderLayer
from .attn import (
    FullAttention,
    ProbAttention,
    AttentionLayer,
    LocalAttentionLayer,
    PerformerAttention,
    BenchmarkAttention,
    NystromSelfAttention,
)
from .embed import SpacetimeformerEmbedding

# Install an "ignore" filter for every UserWarning as soon as this module is
# imported.
# NOTE(review): the filter is process-wide, so it also hides UserWarnings
# raised by unrelated code — confirm that this is intended.
warnings.filterwarnings("ignore", category=UserWarning)


class Spacetimeformer(nn.Module):
    """Encoder-decoder model that predicts a Normal distribution per target.

    Encoder and decoder inputs are embedded by one shared
    ``SpacetimeformerEmbedding``. The decoder output is projected by a linear
    ``forecaster`` head into means and raw scale parameters, which are turned
    into a ``torch.distributions.Normal``. A second linear ``classifier`` head
    is applied to the detached encoder/decoder outputs whenever the embedding
    returns variable indices.
    """

    def __init__(
        self,
        d_y: int = 1,
        d_x: int = 4,
        start_token_len: int = 64,
        attn_factor: int = 5,
        d_model: int = 512,
        n_heads: int = 8,
        e_layers: int = 2,
        d_layers: int = 2,
        d_ff: int = 512,
        time_emb_dim: int = 6,
        dropout_emb: float = 0.05,
        dropout_token: float = 0.05,
        dropout_attn_out: float = 0.05,
        dropout_ff: float = 0.05,
        dropout_qkv: float = 0.05,
        global_self_attn: str = "performer",
        local_self_attn: str = "none",
        global_cross_attn: str = "performer",
        local_cross_attn: str = "none",
        performer_attn_kernel: str = "relu",
        performer_redraw_interval: int = 250,
        embed_method: str = "spatio-temporal",
        activation: str = "gelu",
        post_norm: bool = True,
        norm: str = "layer",
        initial_downsample_convs: int = 0,
        intermediate_downsample_convs: int = 0,
        device=torch.device("cuda:0"),
        null_value: float = None,
        verbose: bool = True,
    ):
        """Build the embedding, encoder, decoder and the two output heads.

        Args:
            d_y: Forwarded to the embedding, the attention factories and
                ``VariableDownsample``. Also the number of chunks used when
                folding spatio-temporal output and the width of the
                classifier head.
            d_x: Forwarded to the embedding.
            start_token_len: Number of leading decoder positions dropped from
                the forecast; also forwarded to the embedding.
            attn_factor: ``factor`` given to ``ProbAttention``.
            d_model: Model width; input size of both linear heads.
            n_heads: Number of attention heads.
            e_layers: Number of ``EncoderLayer`` instances.
            d_layers: Number of ``DecoderLayer`` instances.
            d_ff: Forwarded to every encoder/decoder layer.
            time_emb_dim: Forwarded to the embedding.
            dropout_emb: ``emb_dropout`` of the encoder and decoder.
            dropout_token: ``data_dropout`` of the encoder and decoder.
            dropout_attn_out: ``attention_dropout`` of the full, prob and
                nystromformer attention variants.
            dropout_ff: Forwarded to every encoder/decoder layer.
            dropout_qkv: Forwarded to the attention layers.
            global_self_attn: One of "full", "prob", "performer",
                "nystromformer", "benchmark" or "none".
            local_self_attn: One of "prob", "full", "performer", "benchmark"
                or "none".
            global_cross_attn: Same choices as ``global_self_attn``.
            local_cross_attn: Same choices as ``local_self_attn``.
            performer_attn_kernel: ``kernel`` given to ``PerformerAttention``.
            performer_redraw_interval: ``feature_redraw_interval`` given to
                ``PerformerAttention``.
            embed_method: Forwarded to the embedding. "spatio-temporal"
                selects a 2-wide forecaster head whose output is folded per
                variable; any other value uses a ``2 * d_y``-wide head.
                "temporal" requires both local attentions to be "none".
            activation: Forwarded to every encoder/decoder layer.
            post_norm: Forwarded to every layer. When false, a final
                ``Normalization`` layer is handed to the encoder and decoder.
            norm: Normalization method name forwarded to the layers.
            initial_downsample_convs: ``downsample_convs`` of the embedding.
            intermediate_downsample_convs: Number of ``VariableDownsample``
                layers in the encoder; must not exceed ``e_layers - 1`` when
                ``e_layers`` is non-zero.
            device: Accepted for interface compatibility; not used here.
            null_value: Forwarded to the embedding.
            verbose: When true, print a summary of the chosen configuration.
        """
        super().__init__()
        if e_layers:
            assert intermediate_downsample_convs <= e_layers - 1
        if embed_method == "temporal":
            assert (
                local_self_attn == "none"
            ), "local attention not compatible with Temporal-only embedding"
            assert (
                local_cross_attn == "none"
            ), "Local Attention not compatible with Temporal-only embedding"

        self.start_token_len = start_token_len
        self.embed_method = embed_method

        # Embedding shared by the encoder and decoder inputs.
        self.embedding = SpacetimeformerEmbedding(
            d_y=d_y,
            d_x=d_x,
            d_model=d_model,
            time_emb_dim=time_emb_dim,
            downsample_convs=initial_downsample_convs,
            method=embed_method,
            start_token_len=start_token_len,
            null_value=null_value,
        )

        # Zero-argument factories: every layer below gets its own attention
        # module instance by calling them.
        attn_kwargs = {
            "d_model": d_model,
            "n_heads": n_heads,
            "dropout_qkv": dropout_qkv,
            "d_y": d_y,
            "dropout_attn_out": dropout_attn_out,
            "attn_factor": attn_factor,
            "performer_attn_kernel": performer_attn_kernel,
            "performer_redraw_interval": performer_redraw_interval,
        }
        GlobalSelfAttn = self._global_attn_switch(global_self_attn, **attn_kwargs)
        GlobalCrossAttn = self._global_attn_switch(global_cross_attn, **attn_kwargs)
        LocalSelfAttn = self._local_attn_switch(local_self_attn, **attn_kwargs)
        LocalCrossAttn = self._local_attn_switch(local_cross_attn, **attn_kwargs)

        # NOTE: submodules are created in the same order as before so that
        # seeded parameter initialisation is unaffected.
        self.encoder = Encoder(
            attn_layers=[
                EncoderLayer(
                    global_attention=GlobalSelfAttn(),
                    local_attention=LocalSelfAttn(),
                    d_model=d_model,
                    d_ff=d_ff,
                    dropout_ff=dropout_ff,
                    activation=activation,
                    post_norm=post_norm,
                    norm=norm,
                )
                for _ in range(e_layers)
            ],
            conv_layers=[
                VariableDownsample(d_y=d_y, d_model=d_model)
                for _ in range(intermediate_downsample_convs)
            ],
            norm_layer=Normalization(method=norm, d_model=d_model)
            if not post_norm
            else None,
            emb_dropout=dropout_emb,
            data_dropout=dropout_token,
        )

        self.decoder = Decoder(
            layers=[
                DecoderLayer(
                    global_self_attention=GlobalSelfAttn(),
                    local_self_attention=LocalSelfAttn(),
                    global_cross_attention=GlobalCrossAttn(),
                    local_cross_attention=LocalCrossAttn(),
                    d_model=d_model,
                    d_ff=d_ff,
                    dropout_ff=dropout_ff,
                    activation=activation,
                    post_norm=post_norm,
                    norm=norm,
                )
                for _ in range(d_layers)
            ],
            norm_layer=Normalization(method=norm, d_model=d_model)
            if not post_norm
            else None,
            emb_dropout=dropout_emb,
            data_dropout=dropout_token,
        )

        # Only build the summary strings when they will actually be printed.
        # The decoder-layer lines index layers[0], which would raise
        # IndexError for a model configured with no decoder layers.
        if verbose:
            self._print_summary(
                time_emb_dim=time_emb_dim, has_decoder_layers=d_layers > 0
            )

        out_dim = 2 if self.embed_method == "spatio-temporal" else 2 * d_y
        self.forecaster = nn.Linear(d_model, out_dim, bias=True)
        self.classifier = nn.Linear(d_model, d_y, bias=True)
        self.d_y = d_y

    def _print_summary(self, time_emb_dim, has_decoder_layers):
        """Print the attention modules and embedding settings in use."""
        if has_decoder_layers:
            first_layer = self.decoder.layers[0]
            print(f"GlobalSelfAttn: {first_layer.global_self_attention}")
            print(f"GlobalCrossAttn: {first_layer.global_cross_attention}")
            print(f"LocalSelfAttn: {first_layer.local_self_attention}")
            print(f"LocalCrossAttn: {first_layer.local_cross_attention}")
        print(f"Using Embedding: {self.embed_method}")
        print(f"Time Emb Dim: {time_emb_dim}")
        print(f"Space Embedding: {self.embedding.SPACE}")
        print(f"Time Embedding: {self.embedding.TIME}")
        print(f"Val Embedding: {self.embedding.VAL}")
        print(f"Given Embedding: {self.embedding.GIVEN}")

    @staticmethod
    def _positive_stds(log_stds):
        """Map raw scale outputs to strictly positive standard deviations.

        Computes ``1e-3 + (1 - 1e-3) * softplus(log_stds)`` (the stabilization
        trick from the Neural Processes papers). ``F.softplus`` is used
        instead of ``log(1 + exp(x))`` so that large inputs do not overflow
        to ``inf``.
        """
        return 1e-3 + (1.0 - 1e-3) * F.softplus(log_stds)

    def _fold_spatio_temporal(self, dec_out):
        """Regroup a flattened spatio-temporal output into per-variable columns.

        ``dec_out`` is split into ``self.d_y`` chunks along dim 1. Each chunk
        is halved along the last dim into a mean part and a raw-scale part;
        the parts are concatenated along the last dim and the first
        ``self.start_token_len`` positions of dim 1 are dropped.

        Returns:
            Tuple ``(means, log_stds)``.
        """
        means, log_stds = [], []
        for per_variable in dec_out.chunk(self.d_y, dim=1):
            mean, log_std = per_variable.chunk(2, dim=-1)
            means.append(mean)
            log_stds.append(log_std)
        means = torch.cat(means, dim=-1)[:, self.start_token_len :, :]
        log_stds = torch.cat(log_stds, dim=-1)[:, self.start_token_len :, :]
        return means, log_stds

    def forward(
        self,
        x_enc,
        x_mark_enc,
        x_dec,
        x_mark_dec,
        enc_self_mask=None,
        dec_self_mask=None,
        dec_enc_mask=None,
        output_attention=False,
    ):
        """Run the embedding, encoder, decoder and output heads.

        Args:
            x_enc, x_mark_enc: Encoder inputs passed to the embedding.
            x_dec, x_mark_dec: Decoder inputs passed to the embedding.
            enc_self_mask: Forwarded to the encoder as ``attn_mask``.
            dec_self_mask: Forwarded to the decoder as ``x_mask``.
            dec_enc_mask: Forwarded to the decoder as ``cross_mask``.
            output_attention: Forwarded to the encoder as ``output_attn``.

        Returns:
            ``(pred_distrib, (classifier_out, var_idxs), attns)`` where
            ``pred_distrib`` is a ``torch.distributions.Normal``, ``attns`` is
            whatever the encoder returns as its second output, and the middle
            pair is ``(None, None)`` when the embedding yields no variable
            indices.
        """
        enc_vt_emb, enc_s_emb, enc_var_idx = self.embedding(
            x_enc, x_mark_enc, is_encoder=True
        )
        enc_out, attns = self.encoder(
            val_time_emb=enc_vt_emb,
            space_emb=enc_s_emb,
            attn_mask=enc_self_mask,
            output_attn=output_attention,
        )

        dec_vt_emb, dec_s_emb, dec_var_idx = self.embedding(
            x_dec, x_mark_dec, is_encoder=False
        )
        dec_out = self.decoder(
            val_time_emb=dec_vt_emb,
            space_emb=dec_s_emb,
            cross=enc_out,
            x_mask=dec_self_mask,
            cross_mask=dec_enc_mask,
        )

        forecast_out = self.forecaster(dec_out)
        if self.embed_method == "spatio-temporal":
            means, log_stds = self._fold_spatio_temporal(forecast_out)
        else:
            forecast_out = forecast_out[:, self.start_token_len :, :]
            means, log_stds = forecast_out.chunk(2, dim=-1)

        stds = self._positive_stds(log_stds)
        pred_distrib = pyd.Normal(means, stds)

        if dec_var_idx is not None and enc_var_idx is not None:
            # Detaching the inputs means the transformer layers are not
            # optimized for classification accuracy (the linear classifier
            # still is). This is a probe of how much unique spatial
            # information survives the global attention layers.
            classifier_dec_out = self.classifier(dec_out.detach())
            classifier_enc_out = self.classifier(enc_out.detach())
            classifier_out = torch.cat(
                (classifier_enc_out, classifier_dec_out), dim=1
            )
            var_idxs = torch.cat((enc_var_idx, dec_var_idx), dim=1)
        else:
            classifier_out, var_idxs = None, None

        return pred_distrib, (classifier_out, var_idxs), attns

    def _global_attn_switch(
        self,
        global_attn_str: str,
        d_model: int,
        n_heads: int,
        d_y: int,
        dropout_qkv: float,
        dropout_attn_out: float,
        attn_factor: int,
        performer_attn_kernel: str,
        performer_redraw_interval: int,
    ):
        """Return a zero-argument factory for the named global attention.

        ``d_y`` is accepted so the same keyword set can be passed to both
        switch methods; it is not used here.

        Raises:
            ValueError: If ``global_attn_str`` is not a recognized name.
        """
        if global_attn_str == "full":
            # Standard full (n^2) attention.
            return partial(
                AttentionLayer,
                attention=partial(FullAttention, attention_dropout=dropout_attn_out),
                d_model=d_model,
                n_heads=n_heads,
                mix=False,
                dropout_qkv=dropout_qkv,
            )
        if global_attn_str == "prob":
            # ProbAttention-based layer.
            return partial(
                AttentionLayer,
                attention=partial(
                    ProbAttention,
                    factor=attn_factor,
                    attention_dropout=dropout_attn_out,
                ),
                d_model=d_model,
                n_heads=n_heads,
                mix=False,
                dropout_qkv=dropout_qkv,
            )
        if global_attn_str == "performer":
            # Performer linear attention.
            return partial(
                AttentionLayer,
                attention=partial(
                    PerformerAttention,
                    dim_heads=(d_model // n_heads),
                    kernel=performer_attn_kernel,
                    feature_redraw_interval=performer_redraw_interval,
                ),
                d_model=d_model,
                n_heads=n_heads,
                mix=False,
                dropout_qkv=dropout_qkv,
            )
        if global_attn_str == "nystromformer":
            return partial(
                NystromSelfAttention,
                d_model=d_model,
                n_heads=n_heads,
                attention_dropout=dropout_attn_out,
            )
        if global_attn_str == "benchmark":
            return BenchmarkAttention
        if global_attn_str == "none":
            # Factory that yields no attention module at all.
            return lambda: None
        raise ValueError(f"Unrecognized Global Attention '{global_attn_str}'")

    def _local_attn_switch(
        self,
        local_attn_str: str,
        d_y: int,
        d_model: int,
        n_heads: int,
        dropout_qkv: float,
        dropout_attn_out: float,
        attn_factor: int,
        performer_attn_kernel: str,
        performer_redraw_interval: int,
    ):
        """Return a zero-argument factory for the named local attention.

        Raises:
            ValueError: If ``local_attn_str`` is not a recognized name.
        """
        if local_attn_str == "prob":
            # Prob local attention.
            return partial(
                LocalAttentionLayer,
                attention=partial(
                    ProbAttention,
                    factor=attn_factor,
                    attention_dropout=dropout_attn_out,
                ),
                d_model=d_model,
                n_heads=n_heads,
                dropout_qkv=dropout_qkv,
                d_y=d_y,
            )
        if local_attn_str == "full":
            return partial(
                LocalAttentionLayer,
                attention=partial(FullAttention, attention_dropout=dropout_attn_out),
                d_model=d_model,
                n_heads=n_heads,
                dropout_qkv=dropout_qkv,
                d_y=d_y,
            )
        if local_attn_str == "performer":
            # Performer local attention.
            return partial(
                LocalAttentionLayer,
                attention=partial(
                    PerformerAttention,
                    dim_heads=(d_model // n_heads),
                    kernel=performer_attn_kernel,
                    feature_redraw_interval=performer_redraw_interval,
                ),
                d_model=d_model,
                n_heads=n_heads,
                dropout_qkv=dropout_qkv,
                d_y=d_y,
            )
        if local_attn_str == "benchmark":
            return BenchmarkAttention
        if local_attn_str == "none":
            # Ablation of local attention: the factory yields None.
            return lambda: None
        raise ValueError(f"Unrecognized Local Attention '{local_attn_str}'")