Source code for spacetimeformer.spacetimeformer_model.nn.data_dropout

import torch
from torch import nn


[docs]class DataDropout(nn.Module): def __init__(self, dropout=None): super().__init__() self.dropout = dropout
[docs] def forward(self, embed): bs, length, d_model = embed.shape if self.training: mask = torch.bernoulli((1.0 - self.dropout) * torch.ones(bs, length, 1)) mask.requires_grad = False mask = mask.to(embed.device) return embed * mask else: return embed
def __repr__(self): return f"DataDropout(p = {self.dropout})"