semes_gaf/self_supervised/batchrenorm.py

89 lines
3.2 KiB
Python

"""
Batch Renormalization layers, taken from https://github.com/ludvb/batchrenorm

Reference:
@article{batchrenormalization,
author = {Sergey Ioffe},
title = {Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models},
journal = {arXiv preprint arXiv:1702.03275},
year = {2017},
}
"""
import torch
# Public API: only the rank-specific layers are listed; the BatchRenorm base
# class is intentionally left out of __all__.
__all__ = ["BatchRenorm1d", "BatchRenorm2d", "BatchRenorm3d"]
class BatchRenorm(torch.nn.Module):
    """Batch Renormalization (Ioffe, 2017) with channels on input dim 1.

    In training mode the input is normalised with the batch mean/std and then
    corrected towards the running statistics with ``r`` (clipped to
    ``[1 / rmax, rmax]``) and ``d`` (clipped to ``[-dmax, dmax]``); afterwards
    the running statistics are updated and ``num_batches_tracked`` is
    incremented. In eval mode only the running statistics are used.

    This base class is abstract: subclasses implement ``_check_input_dim``.

    Args:
        num_features: number of channels (size of input dim 1).
        eps: added to the batch standard deviation before dividing by it.
        momentum: rate of the running-statistics update,
            ``running += momentum * (batch - running)``.
        affine: if ``True``, apply the per-channel ``weight`` and ``bias``
            after normalisation. Both parameters are registered regardless
            of this flag; they are simply unused when it is ``False``.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-3,
        momentum: float = 0.01,
        affine: bool = True,
    ):
        super().__init__()
        self.register_buffer(
            "running_mean", torch.zeros(num_features, dtype=torch.float)
        )
        self.register_buffer(
            "running_std", torch.ones(num_features, dtype=torch.float)
        )
        self.register_buffer(
            "num_batches_tracked", torch.tensor(0, dtype=torch.long)
        )
        self.weight = torch.nn.Parameter(
            torch.ones(num_features, dtype=torch.float)
        )
        self.bias = torch.nn.Parameter(
            torch.zeros(num_features, dtype=torch.float)
        )
        self.affine = affine
        self.eps = eps
        # NOTE(review): ``step`` is never read in this module; kept only so
        # external code that touches the attribute keeps working.
        self.step = 0
        self.momentum = momentum

    def _check_input_dim(self, x: torch.Tensor) -> None:
        """Validate the rank of ``x``; must be overridden by subclasses."""
        raise NotImplementedError()  # pragma: no cover

    @property
    def rmax(self) -> torch.Tensor:
        """Upper clip bound for ``r`` as a 0-d tensor.

        Equals 1.0 for the first 5000 tracked batches, then rises linearly to
        3.0 at 40000 batches and stays there.
        """
        return (2 / 35000 * self.num_batches_tracked + 25 / 35).clamp_(1.0, 3.0)

    @property
    def dmax(self) -> torch.Tensor:
        """Clip bound for ``|d|`` as a 0-d tensor.

        Equals 0.0 (no shift) for the first 5000 tracked batches, then rises
        linearly to 5.0 at 25000 batches and stays there.
        """
        return (5 / 20000 * self.num_batches_tracked - 25 / 20).clamp_(0.0, 5.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` (channels on dim 1); output has the same shape.

        Raises:
            Whatever the subclass's ``_check_input_dim`` raises for a bad rank.
        """
        self._check_input_dim(x)
        # Move channels to the last axis so the (C,) buffers and parameters
        # broadcast; a 2D (N, C) input already has channels last.
        if x.dim() > 2:
            x = x.transpose(1, -1)
        if self.training:
            # Reduce over every axis except the (now last) channel axis.
            dims = list(range(x.dim() - 1))
            batch_mean = x.mean(dims)
            batch_std = x.std(dims, unbiased=False) + self.eps
            running_mean = self.running_mean.view_as(batch_mean)
            running_std = self.running_std.view_as(batch_std)
            # Read each clip bound once: every ``.item()`` forces a host sync.
            rmax = self.rmax.item()
            dmax = self.dmax.item()
            # r and d come from detached statistics, so no gradient flows
            # through them.
            r = (batch_std.detach() / running_std).clamp_(1 / rmax, rmax)
            d = ((batch_mean.detach() - running_mean) / running_std).clamp_(
                -dmax, dmax
            )
            x = (x - batch_mean) / batch_std * r + d
            self.running_mean += self.momentum * (
                batch_mean.detach() - self.running_mean
            )
            self.running_std += self.momentum * (
                batch_std.detach() - self.running_std
            )
            self.num_batches_tracked += 1
        else:
            x = (x - self.running_mean) / self.running_std
        if self.affine:
            x = self.weight * x + self.bias
        # Undo the channel-last transpose so the output matches the input.
        if x.dim() > 2:
            x = x.transpose(1, -1)
        return x
class BatchRenorm1d(BatchRenorm):
    """Batch renormalization for 2D or 3D input with channels on dim 1."""

    def _check_input_dim(self, x: torch.Tensor) -> None:
        """Raise ``ValueError`` unless ``x`` is 2D or 3D."""
        if x.dim() not in (2, 3):
            # f-prefix is required so the actual rank is reported.
            raise ValueError(f"expected 2D or 3D input (got {x.dim()}D input)")
class BatchRenorm2d(BatchRenorm):
    """Batch renormalization for 4D input with channels on dim 1."""

    def _check_input_dim(self, x: torch.Tensor) -> None:
        """Raise ``ValueError`` unless ``x`` is 4D."""
        if x.dim() != 4:
            # f-prefix is required so the actual rank is reported.
            raise ValueError(f"expected 4D input (got {x.dim()}D input)")
class BatchRenorm3d(BatchRenorm):
    """Batch renormalization for 5D input with channels on dim 1."""

    def _check_input_dim(self, x: torch.Tensor) -> None:
        """Raise ``ValueError`` unless ``x`` is 5D."""
        if x.dim() != 5:
            # f-prefix is required so the actual rank is reported.
            raise ValueError(f"expected 5D input (got {x.dim()}D input)")