import torch
from torch import nn
class ScaleNorm(nn.Module):
    """Normalize the last dimension by its L2 norm and apply a learned scalar gain.

    Each vector along the last dimension of the input is divided by
    ``max(||x||_2, eps) * dim ** -0.5`` and then multiplied by the single
    learnable parameter ``g`` (initialised to 1).

    Args:
        dim: Value used to derive the fixed factor ``dim ** -0.5``.
            Presumably the size of the last dimension of the inputs; the
            module itself does not check this.
        eps: Lower bound applied to the L2 norm before it is rescaled, which
            keeps the division finite for all-zero vectors.
    """

    def __init__(self, dim: int, eps: float = 1e-5) -> None:
        super().__init__()
        # Fixed (non-learned) factor applied to the norm.
        self.scale = dim ** -0.5
        # One learnable gain shared by every element of the input.
        self.g = nn.Parameter(torch.ones(1))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` rescaled along its last dimension.

        Args:
            x: Input tensor; the norm is taken over ``dim=-1``.

        Returns:
            A tensor equal to ``x / (clamp(||x||_2, min=eps) * scale) * g``,
            where the norm keeps its reduced dimension so it broadcasts
            against ``x``.
        """
        # The clamp is applied to the raw norm *before* multiplying by
        # ``scale``, so the smallest possible divisor is ``eps * scale``.
        n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps) * self.scale
        x = x / n * self.g
        return x