Source code for spacetimeformer.spacetimeformer_model.nn.scalenorm

import torch
from torch import nn


class ScaleNorm(nn.Module):
    """Normalize the last axis by its scaled L2 norm and apply a learnable scalar gain.

    For an input ``x`` the output is ``(x / (clamp(||x||, min=eps) * dim ** -0.5)) * g``,
    where ``||x||`` is the L2 norm over the last dimension and ``g`` is a single
    learnable scalar initialised to 1. This is presumably the ScaleNorm of
    Nguyen & Salazar (2019), "Transformers without Tears" — NOTE(review): confirm.

    Args:
        dim: Size used to compute the ``dim ** -0.5`` scale factor. Presumably the
            size of the last axis of the inputs — TODO confirm with callers.
        eps: Lower bound applied to the raw norm before it is scaled, guarding
            against division by zero.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Single shared gain (shape (1,)), not one per feature.
        self.g = nn.Parameter(torch.ones(1))
        self.scale = dim ** -0.5

    def forward(self, x):
        """Return ``x`` divided by its clamped, scaled last-axis L2 norm, times ``g``."""
        # The clamp acts on the unscaled norm, so the smallest possible
        # denominator is eps * dim ** -0.5 rather than eps itself.
        raw_norm = x.norm(dim=-1, keepdim=True)
        denom = raw_norm.clamp(min=self.eps).mul(self.scale)
        normalized = x / denom
        return normalized * self.g