Source code for spacetimeformer.spacetimeformer_model.nn.embed
import torch
import torch.nn as nn
import torch.nn.functional as F
import spacetimeformer as stf
from .encoder import VariableDownsample
[docs]class SpacetimeformerEmbedding(nn.Module):
def __init__(
self,
d_y,
d_x,
d_model=256,
time_emb_dim=6,
method="spatio-temporal",
downsample_convs=1,
start_token_len=0,
null_value=None,
):
super().__init__()
assert method in ["spatio-temporal", "temporal"]
self.method = method
# account for added local position indicator "relative time"
d_x += 1
self.x_emb = stf.Time2Vec(d_x, embed_dim=time_emb_dim * d_x)
if self.method == "temporal":
y_emb_inp_dim = d_y + (time_emb_dim * d_x)
else:
y_emb_inp_dim = 1 + (time_emb_dim * d_x)
self.y_emb = nn.Linear(y_emb_inp_dim, d_model)
if self.method == "spatio-temporal":
self.var_emb = nn.Embedding(num_embeddings=d_y, embedding_dim=d_model)
self.start_token_len = start_token_len
self.given_emb = nn.Embedding(num_embeddings=2, embedding_dim=d_model)
self.downsize_convs = nn.ModuleList(
[VariableDownsample(d_y, d_model) for _ in range(downsample_convs)]
)
self._benchmark_embed_enc = None
self._benchmark_embed_dec = None
self.d_model = d_model
self.null_value = null_value
def __call__(self, y, x, is_encoder=True):
if self.method == "spatio-temporal":
val_time_emb, space_emb, var_idxs = self.spatio_temporal_embed(
y, x, is_encoder
)
else:
val_time_emb, space_emb = self.temporal_embed(y, x, is_encoder)
var_idxs = None
return val_time_emb, space_emb, var_idxs
[docs] def temporal_embed(self, y, x, is_encoder=True):
bs, length, d_y = y.shape
local_pos = (
torch.arange(length).view(1, -1, 1).repeat(bs, 1, 1).to(x.device) / length
)
if not self.TIME:
x = torch.zeros_like(x)
x = torch.cat((x, local_pos), dim=-1)
t2v_emb = self.x_emb(x)
emb_inp = torch.cat((y, t2v_emb), dim=-1)
emb = self.y_emb(emb_inp)
# "given" embedding
given = torch.ones((bs, length)).long().to(x.device)
if not is_encoder and self.GIVEN:
given[:, self.start_token_len :] = 0
given_emb = self.given_emb(given)
emb += given_emb
if is_encoder:
# shorten the sequence
for i, conv in enumerate(self.downsize_convs):
emb = conv(emb)
return emb, torch.zeros_like(emb)
SPACE = True
TIME = True
VAL = True
GIVEN = True
[docs] def spatio_temporal_embed(self, y, x, is_encoder=True):
bs, length, d_y = y.shape
# val + time embedding
y = torch.cat(y.chunk(d_y, dim=-1), dim=1)
local_pos = (
torch.arange(length).view(1, -1, 1).repeat(bs, 1, 1).to(x.device) / length
)
x = torch.cat((x, local_pos), dim=-1)
if not self.TIME:
x = torch.zeros_like(x)
if not self.VAL:
y = torch.zeros_like(y)
t2v_emb = self.x_emb(x).repeat(1, d_y, 1)
val_time_inp = torch.cat((y, t2v_emb), dim=-1)
val_time_emb = self.y_emb(val_time_inp)
# "given" embedding
if self.GIVEN:
given = torch.ones((bs, length, d_y)).long().to(x.device) # start as T
if not is_encoder:
# mask missing values that need prediction...
given[:, self.start_token_len :, :] = 0
given = torch.cat(given.chunk(d_y, dim=-1), dim=1).squeeze(-1)
if self.null_value is not None:
# mask null values
null_mask = (y != self.null_value).squeeze(-1)
given *= null_mask
given_emb = self.given_emb(given)
val_time_emb += given_emb
if is_encoder:
for conv in self.downsize_convs:
val_time_emb = conv(val_time_emb)
length //= 2
# var embedding
var_idx = torch.Tensor([[i for j in range(length)] for i in range(d_y)])
var_idx = var_idx.long().to(x.device).view(-1).unsqueeze(0).repeat(bs, 1)
var_idx_true = var_idx.clone()
if not self.SPACE:
var_idx = torch.zeros_like(var_idx)
var_emb = self.var_emb(var_idx)
return val_time_emb, var_emb, var_idx_true