# Source code for spacetimeformer.linear_model.linear_ar

import torch
from torch import nn


class LinearModel(nn.Module):
    """Channel-shared linear map from the last ``context_points`` steps to one step.

    A single ``nn.Linear(context_points, 1)`` is applied independently to every
    channel of the input. All channels therefore share the same weights and bias.
    """

    def __init__(self, context_points: int):
        """
        Args:
            context_points: Number of trailing timesteps read from the input
                sequence; also the input width of the linear layer.
        """
        super().__init__()
        self.window = context_points
        self.linear = nn.Linear(context_points, 1)

    def forward(self, y_c):
        """Apply the shared linear layer to the last ``self.window`` steps of each channel.

        Args:
            y_c: Tensor of shape ``(batch, length, d_y)`` with
                ``length >= self.window``.

        Returns:
            Tensor of shape ``(batch, 1, d_y)``, where ``out[b, 0, c]`` is the
            linear layer applied to ``y_c[b, -self.window:, c]``.

        Raises:
            ValueError: If ``y_c`` is not 3-D, or if it has fewer than
                ``self.window`` timesteps.
        """
        if y_c.dim() != 3:
            raise ValueError(
                f"expected y_c of shape (batch, length, d_y), got {tuple(y_c.shape)}"
            )
        length = y_c.shape[1]
        if length < self.window:
            # Without this check the slice below silently returns fewer than
            # `window` steps and nn.Linear fails with an opaque matmul error.
            raise ValueError(
                f"y_c has {length} timesteps but the model needs at least "
                f"{self.window} (context_points)"
            )
        # (batch, window, d_y) -> (batch, d_y, window): each channel's history
        # becomes the last axis, which nn.Linear maps to a single value.
        history = y_c[:, -self.window :, :].transpose(1, 2)
        # (batch, d_y, 1) -> (batch, 1, d_y). Only size-1 axes move here, so the
        # result is still contiguous.
        return self.linear(history).transpose(1, 2)