# Source code for spacetimeformer.spacetimeformer_model.nn.powernorm

"""
PowerNorm code from https://github.com/sIncerass/powernorm/blob/master/fairseq/modules/norms/mask_powernorm.py
"""

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

# Public API of this module: only the masked PowerNorm layer is exported by
# `from ... import *`; the helpers and PowerFunction stay module-internal.
__all__ = ["MaskPowerNorm"]


def _sum_ft(tensor):
    """sum over the first and last dimention"""
    return tensor.sum(dim=0).sum(dim=-1)


class GroupScaling1D(nn.Module):
    r"""Scale each channel group by the root of its mean second moment.

    The channel axis of a (T, B, C) input is split into ``group_num``
    contiguous groups of ``C // group_num`` channels, and every element is
    divided by ``sqrt(mean(x * x over its group) + eps)``.

    Args:
        eps: Constant added to the second moment before the square root.
        group_num: Number of channel groups. ``C`` is expected to be divisible
            by it; otherwise the reshape in :meth:`forward` raises.
    """

    def __init__(self, eps=1e-5, group_num=4):
        super().__init__()
        self.eps = eps
        self.group_num = group_num

    def extra_repr(self):
        return f"eps={self.eps}, group={self.group_num}"

    def forward(self, input):
        """Divide ``input`` by the per-group root-mean-square.

        Args:
            input: Tensor whose first three dimensions are (T, B, C).

        Returns:
            ``input`` divided element-wise by its group's
            ``sqrt(second_moment + eps)``.
        """
        seq_len, batch, channels = input.shape[0], input.shape[1], input.shape[2]
        group_size = channels // self.group_num
        grouped = input.reshape(seq_len, batch, self.group_num, group_size)

        # Mean of squares inside each group, kept as a trailing singleton axis.
        group_moment = (grouped * grouped).mean(dim=3, keepdim=True)
        # Spread every group's moment back over its channels -> (T, B, C).
        moment2 = group_moment.expand(-1, -1, -1, group_size).reshape(
            seq_len, batch, channels
        )
        return input / torch.sqrt(moment2 + self.eps)


def _unsqueeze_ft(tensor):
    """add new dimensions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)


class PowerFunction(torch.autograd.Function):
    """Autograd function for the (masked) PowerNorm forward/backward pass.

    Forward divides ``x`` by the square root of a per-channel quadratic mean
    and applies a per-channel affine transform. Backward replaces the exact
    normalization gradient with a correction driven by the ``ema_gz`` buffer,
    which it updates in place.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        running_phi,
        eps,
        afwd,
        abkw,
        ema_gz,
        debug,
        warmup_iters,
        current_iter,
        mask_x,
    ):
        """Normalize ``x`` and update ``running_phi`` in place.

        Args:
            x: 4-D input of shape (N, C, H, W).
            weight: Per-channel scale with C elements.
            bias: Per-channel shift with C elements.
            running_phi: Running second-moment buffer, expected to broadcast
                as (1, C, 1, 1); overwritten in place.
            eps: Constant added before every square root.
            afwd: Momentum of the exponential moving average of
                ``running_phi``.
            abkw: Momentum of the backward correction (stored for backward).
            ema_gz: Buffer used by backward for its gradient correction;
                updated in place during backward.
            debug: Unused; kept for interface compatibility.
            warmup_iters: While ``current_iter <= warmup_iters`` the batch
                statistic from ``mask_x`` is used instead of ``running_phi``.
            current_iter: One-element tensor holding the iteration count.
                NOTE(review): must be >= 1 while ``current_iter <
                warmup_iters`` (it is a divisor in the warm-up average).
            mask_x: (M, C) features whose per-channel mean square is the
                batch statistic (e.g. only non-padded positions).

        Returns:
            ``weight * z + bias`` with the same shape as ``x``.
        """
        ctx.eps = eps
        ctx.abkw = abkw
        current_iter = current_iter.item()
        _, C, _, _ = x.size()

        # Per-channel second moment over the rows supplied in mask_x.
        var = (mask_x * mask_x).mean(dim=0).reshape(1, C, 1, 1)
        if current_iter <= warmup_iters:
            z = x / (var + eps).sqrt()
        else:
            z = x / (running_phi + eps).sqrt()
        ctx.save_for_backward(z, var, weight, ema_gz)

        batch_phi = var.mean(dim=0, keepdim=True)
        if current_iter < warmup_iters:
            # Cumulative average of the batch statistic during warm-up.
            running_phi.copy_(
                running_phi * (current_iter - 1) / current_iter
                + batch_phi / current_iter
            )
        # NOTE: the EMA step runs on every call, warm-up included.
        running_phi.copy_(afwd * running_phi + (1 - afwd) * batch_phi)

        return weight.reshape(1, C, 1, 1) * z + bias.reshape(1, C, 1, 1)

    @staticmethod
    def backward(ctx, grad_output):
        """Approximate gradients for ``x``, ``weight`` and ``bias``.

        Returns:
            A 12-tuple: gradients for ``x``, ``weight`` and ``bias`` followed
            by ``None`` for each of the nine non-differentiable inputs.
        """
        eps = ctx.eps
        abkw = ctx.abkw
        _, C, _, _ = grad_output.size()
        z, var, weight, ema_gz = ctx.saved_tensors

        g = grad_output * weight.reshape(1, C, 1, 1)
        # Correct g with the accumulated ema_gz statistic rather than the
        # exact batch mean of g * z, then fold this step into ema_gz.
        approx_grad_g = g - (1 - abkw) * ema_gz * z
        ema_gz.add_(
            (approx_grad_g * z)
            .mean(dim=3, keepdim=True)
            .mean(dim=2, keepdim=True)
            .mean(dim=0, keepdim=True)
        )

        gx = 1.0 / torch.sqrt(var + eps) * approx_grad_g
        grad_weight = (grad_output * z).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)
        # Exactly one entry per forward input (12 in total).
        return (gx, grad_weight, grad_bias) + (None,) * 9


class MaskPowerNorm(nn.Module):
    """
    An implementation of masked batch normalization, used for testing the
    numerical stability.

    Inputs are laid out as (T, B, C). An optional boolean ``pad_mask`` of
    shape (B, T) marks padded positions with ``True``; those positions are
    left out of the batch second-moment statistic.

    Args:
        num_features: Number of channels ``C``.
        eps: Constant added to the second moment before the square root.
        alpha_fwd: Momentum of the running second-moment EMA
            (``running_phi``).
        alpha_bkw: Passed to ``PowerFunction`` as ``abkw`` for its backward
            correction.
        affine: Only stored (and shown in ``extra_repr``); ``weight`` and
            ``bias`` are always created and applied.
        warmup_iters: Passed to ``PowerFunction`` to choose between batch
            statistics and ``running_phi``.
        group_num: Number of channel groups for the ``GroupScaling1D``
            pre-scaling step.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        alpha_fwd=0.9,
        alpha_bkw=0.9,
        affine=True,
        warmup_iters=10000,
        group_num=1,
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.register_parameter("weight", nn.Parameter(torch.ones(num_features)))
        self.register_parameter("bias", nn.Parameter(torch.zeros(num_features)))
        self.register_buffer("running_phi", torch.ones(1, num_features, 1, 1))
        self.register_buffer("ema_gz", torch.zeros(1, num_features, 1, 1))
        # Training-step counter, incremented once per training forward call.
        self.register_buffer("iters", torch.zeros(1).type(torch.LongTensor))
        self.afwd = alpha_fwd
        self.abkw = alpha_bkw
        self.debug = False
        self.warmup_iters = warmup_iters
        self.gp = GroupScaling1D(group_num=group_num)
        self.group_num = group_num

    def extra_repr(self):
        return (
            "{num_features}, eps={eps}, alpha_fwd={afwd}, alpha_bkw={abkw}, "
            "affine={affine}, warmup={warmup_iters}, group_num={group_num}".format(
                **self.__dict__
            )
        )

    def forward(self, input, pad_mask=None, is_encoder=False):
        """Normalize ``input`` and return a tensor of the same shape.

        Args:
            input: Tensor of shape (T, B, C). A 2-D input gets a leading axis
                of size 1 added before normalization and removed afterwards.
            pad_mask: Optional boolean tensor of shape (B, T) with ``True`` at
                padded positions, which are excluded from the batch statistic.
            is_encoder: Unused; kept for interface compatibility.

        Returns:
            Tensor with the same shape as ``input``.
        """
        squeeze_output = len(input.shape) == 2
        if squeeze_output:
            input = input.unsqueeze(0)
        T, B, C = input.shape  # also enforces a 3-D (T, B, C) input
        input = self.gp(input)

        # Rows feeding the batch second moment: every (t, b) position, or only
        # the non-padded ones when a mask is supplied. Result: (rows, C).
        if pad_mask is None:
            mask_input = input.clone()
        else:
            # pad_mask is B x T; transpose it to index the T x B x C input.
            bn_mask = (~pad_mask).transpose(0, 1)
            mask_input = input[bn_mask, :]
        mask_input = mask_input.reshape(-1, self.num_features)

        # T x B x C -> B x C x T x 1, the N x C x H x W layout PowerFunction uses.
        input = input.permute(1, 2, 0).contiguous()
        input_shape = input.size()
        input = input.reshape(input.size(0), self.num_features, -1).unsqueeze(-1)

        if self.training:
            self.iters.copy_(self.iters + 1)
            output = PowerFunction.apply(
                input,
                self.weight,
                self.bias,
                self.running_phi,
                self.eps,
                self.afwd,
                self.abkw,
                self.ema_gz,
                self.debug,
                self.warmup_iters,
                self.iters,
                mask_input,
            )
        else:
            # Inference: normalize with the running statistic only.
            _, C, _, _ = input.size()
            output = input / (self.running_phi + self.eps).sqrt()
            output = self.weight.reshape(1, C, 1, 1) * output + self.bias.reshape(
                1, C, 1, 1
            )

        # B x C x T x 1 -> B x C x T -> T x B x C.
        output = output.reshape(input_shape).permute(2, 0, 1).contiguous()
        if squeeze_output:
            output = output.squeeze(0)
        return output