# Source code for spacetimeformer.time2vec

import torch
from torch import nn


class Time2Vec(nn.Module):
    """Learned Time2Vec-style embedding of multi-feature time inputs.

    Each of the ``input_dim`` input features is embedded independently into
    ``embed_dim // input_dim`` channels. Channel 0 of each feature's embedding
    is a purely affine term. The remaining channels are passed through
    ``act_function`` (``torch.sin`` by default). The per-feature embeddings
    are concatenated along the last axis.

    Args:
        input_dim: Number of features in the last axis of the input.
        embed_dim: Total output width. Must be a multiple of ``input_dim``.
            A value ``<= 0`` disables the layer, and ``forward`` then returns
            its input unchanged.
        act_function: Elementwise activation applied to all but the first
            channel of each feature's embedding.
    """

    def __init__(self, input_dim=6, embed_dim=512, act_function=torch.sin):
        # NOTE(review): ``assert`` is stripped under ``python -O``. It is kept
        # so existing callers still see AssertionError on a bad config.
        assert embed_dim % input_dim == 0, (
            "embed_dim must be a multiple of input_dim"
        )
        super().__init__()
        self.enabled = embed_dim > 0
        if self.enabled:
            # Per-feature embedding width (NOT the total width passed in).
            self.embed_dim = embed_dim // input_dim
            self.input_dim = input_dim
            # Row i scales input feature i before the bias is added.
            self.embed_weight = nn.parameter.Parameter(
                torch.randn(self.input_dim, self.embed_dim)
            )
            # One bias vector shared by all input features.
            self.embed_bias = nn.parameter.Parameter(torch.randn(self.embed_dim))
            self.act_function = act_function

    def forward(self, x):
        """Embed ``x``.

        Args:
            x: Tensor whose last axis has size ``input_dim``, e.g. shape
                ``(bs, sequence_length, input_dim)``. Any number of leading
                axes is accepted.

        Returns:
            When enabled: a tensor with the same leading axes as ``x`` and a
            last axis of size ``input_dim * (embed_dim // input_dim)``.
            When disabled: ``x`` itself.

        Raises:
            RuntimeError: If the last axis of ``x`` is not ``input_dim``.
        """
        if not self.enabled:
            return x
        # Without this check, broadcasting would silently accept a last axis
        # of size 1. The original diag_embed/matmul path rejected that input.
        if x.size(-1) != self.input_dim:
            raise RuntimeError(
                f"Time2Vec expected last dimension of size {self.input_dim}, "
                f"got {x.size(-1)}"
            )
        # Same values as ``diag_embed(x) @ embed_weight + embed_bias``, but
        # without materialising the (input_dim x input_dim) diagonal matrix:
        # element [..., i, j] is x[..., i] * embed_weight[i, j] + embed_bias[j].
        # x_affine.shape = (..., input_dim, embed_dim // input_dim)
        x_affine = x.unsqueeze(-1) * self.embed_weight + self.embed_bias
        # Keep channel 0 affine; apply the activation to the rest. The slice
        # [..., 1:] is empty when each feature has a single channel.
        x_output = torch.cat(
            [x_affine[..., :1], self.act_function(x_affine[..., 1:])], dim=-1
        )
        # Merge (input_dim, per-feature width) into one trailing axis.
        return x_output.flatten(start_dim=-2)