import torch
from torch import nn
class Time2Vec(nn.Module):
    """Learnable Time2Vec-style embedding applied to each scalar input feature.

    Each of the ``input_dim`` scalar features ``x_i`` is mapped to a vector of
    ``embed_dim // input_dim`` values ``x_i * embed_weight[i] + embed_bias``.
    The first component of every such vector is kept linear. The remaining
    components are passed through ``act_function`` (``torch.sin`` by default).
    The per-feature vectors are then concatenated along the last axis.

    Args:
        input_dim: Number of scalar features in the last axis of the input.
        embed_dim: Total output width. It must be divisible by ``input_dim``.
            A value ``<= 0`` disables the module, and ``forward`` then returns
            its input unchanged.
        act_function: Element-wise activation applied to all but the first
            component of each feature's embedding.

    Raises:
        AssertionError: If ``embed_dim`` is not divisible by ``input_dim``
            (the check is an ``assert``, so it is skipped under ``python -O``).
    """

    def __init__(self, input_dim=6, embed_dim=512, act_function=torch.sin):
        assert embed_dim % input_dim == 0, (
            f"embed_dim ({embed_dim}) must be divisible by input_dim ({input_dim})"
        )
        super().__init__()
        self.enabled = embed_dim > 0
        if self.enabled:
            # Width of the embedding produced for each individual input feature.
            self.embed_dim = embed_dim // input_dim
            self.input_dim = input_dim
            # One weight row per input feature: (input_dim, per-feature width).
            self.embed_weight = nn.Parameter(
                torch.randn(self.input_dim, self.embed_dim)
            )
            # NOTE: a single bias vector is shared by all input features.
            self.embed_bias = nn.Parameter(torch.randn(self.embed_dim))
            self.act_function = act_function

    def forward(self, x):
        """Embed ``x`` feature-wise, or return it unchanged when disabled.

        Args:
            x: Tensor whose last axis holds ``input_dim`` features. It is
                presumably shaped ``(bs, sequence_length, input_dim)``.

        Returns:
            ``x`` itself when the module is disabled. Otherwise, the embedding
            with its first two axes kept and everything after them flattened.
            For a 3-D input that is
            ``(bs, sequence_length, input_dim * (embed_dim // input_dim))``.

        Raises:
            ValueError: If the last axis of ``x`` does not have ``input_dim``
                entries.
        """
        if not self.enabled:
            return x
        if x.size(-1) != self.input_dim:
            raise ValueError(
                f"expected last dimension {self.input_dim}, got {x.size(-1)}"
            )
        # Scale each scalar feature by its own weight row via broadcasting:
        # (..., input_dim, 1) * (input_dim, width) -> (..., input_dim, width).
        # This equals diag_embed(x) @ embed_weight, but it never materialises
        # the (..., input_dim, input_dim) diagonal tensor.
        x_affine = x.unsqueeze(-1) * self.embed_weight + self.embed_bias
        # Component 0 stays linear; the rest go through the activation.
        x_linear, x_periodic = torch.split(
            x_affine, [1, self.embed_dim - 1], dim=-1
        )
        x_output = torch.cat([x_linear, self.act_function(x_periodic)], dim=-1)
        # Keep the leading two axes and flatten everything after them.
        return x_output.reshape(x_output.size(0), x_output.size(1), -1)