Source code for spacetimeformer.linear_model.linear_model

import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

import spacetimeformer as stf

from .linear_ar import LinearModel


class Linear_Forecaster(stf.Forecaster):
    """Forecaster that wraps a ``LinearModel`` and rolls it forward one step at a time.

    Each predicted step is fed back into the input window for the next step,
    so predictions are produced autoregressively over the whole target span.
    """

    def __init__(
        self,
        context_points: int,
        learning_rate: float = 1e-3,
        l2_coeff: float = 0,
        loss: str = "mse",
        linear_window: int = 0,
    ):
        """Build the forecaster and its underlying linear model.

        Args:
            context_points: Passed to ``LinearModel``; presumably the length of
                the input window the model expects (TODO confirm against
                ``linear_ar.LinearModel``).
            learning_rate: Forwarded unchanged to ``stf.Forecaster``.
            l2_coeff: Forwarded unchanged to ``stf.Forecaster``.
            loss: Forwarded unchanged to ``stf.Forecaster``.
            linear_window: Forwarded unchanged to ``stf.Forecaster``.
        """
        super().__init__(
            l2_coeff=l2_coeff,
            learning_rate=learning_rate,
            loss=loss,
            linear_window=linear_window,
        )
        self.model = LinearModel(context_points)

    @property
    def eval_step_forward_kwargs(self):
        """Extra keyword arguments for the evaluation-step forward call (none)."""
        return {}

    @property
    def train_step_forward_kwargs(self):
        """Extra keyword arguments for the training-step forward call (none)."""
        return {}

    def forward_model_pass(self, x_c, y_c, x_t, y_t):
        """Predict every target step autoregressively from the context values.

        Only ``y_c`` (the observed context values) and the *shape* of ``y_t``
        are used; ``x_c`` and ``x_t`` are accepted for interface compatibility
        but ignored. ``y_t``'s values are never read.

        Tensors are assumed to be laid out as (batch, length, dim) -- the code
        concatenates on dim -2 and writes step ``i`` via ``[:, i]``, which only
        agree for rank-3 inputs (TODO confirm with callers).

        Returns:
            A 1-tuple holding a tensor shaped, typed and placed like ``y_t``,
            whose step ``i`` is the model's prediction given the preceding
            ``y_c.shape[-2]`` steps (real context followed by earlier predictions).
        """
        context_len = y_c.shape[-2]
        pred_len = y_t.shape[-2]
        # zeros_like already allocates on y_t's device and dtype.
        output = torch.zeros_like(y_t)
        for i in range(pred_len):
            # Fixed-length sliding window: the most recent `context_len` steps of
            # [context, predictions so far]. For i <= context_len this equals
            # cat(y_c[:, i:], output[:, :i]); for longer horizons it keeps the
            # window length constant instead of letting it grow to length i.
            window = torch.cat((y_c, output[:, :i]), dim=-2)[:, -context_len:]
            # Call the module (not .forward) so nn.Module hooks still run.
            output[:, i] = self.model(window).squeeze(1)
        return (output,)

    @classmethod
    def add_cli(cls, parser):
        """Register command-line options; adds nothing beyond the base class."""
        super().add_cli(parser)