""" From https://github.com/ludvb/batchrenorm @article{batchrenomalization, author = {Sergey Ioffe}, title = {Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models}, journal = {arXiv preprint arXiv:1702.03275}, year = {2017}, } """ import torch __all__ = ["BatchRenorm1d", "BatchRenorm2d", "BatchRenorm3d"] class BatchRenorm(torch.nn.Module): def __init__( self, num_features: int, eps: float = 1e-3, momentum: float = 0.01, affine: bool = True, ): super().__init__() self.register_buffer("running_mean", torch.zeros(num_features, dtype=torch.float)) self.register_buffer("running_std", torch.ones(num_features, dtype=torch.float)) self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long)) self.weight = torch.nn.Parameter(torch.ones(num_features, dtype=torch.float)) self.bias = torch.nn.Parameter(torch.zeros(num_features, dtype=torch.float)) self.affine = affine self.eps = eps self.step = 0 self.momentum = momentum def _check_input_dim(self, x: torch.Tensor) -> None: raise NotImplementedError() # pragma: no cover @property def rmax(self) -> torch.Tensor: return (2 / 35000 * self.num_batches_tracked + 25 / 35).clamp_(1.0, 3.0) @property def dmax(self) -> torch.Tensor: return (5 / 20000 * self.num_batches_tracked - 25 / 20).clamp_(0.0, 5.0) def forward(self, x: torch.Tensor) -> torch.Tensor: self._check_input_dim(x) if x.dim() > 2: x = x.transpose(1, -1) if self.training: dims = [i for i in range(x.dim() - 1)] batch_mean = x.mean(dims) batch_std = x.std(dims, unbiased=False) + self.eps r = (batch_std.detach() / self.running_std.view_as(batch_std)).clamp_( 1 / self.rmax.item(), self.rmax.item() ) d = ( (batch_mean.detach() - self.running_mean.view_as(batch_mean)) / self.running_std.view_as(batch_std) ).clamp_(-self.dmax.item(), self.dmax.item()) x = (x - batch_mean) / batch_std * r + d self.running_mean += self.momentum * (batch_mean.detach() - self.running_mean) self.running_std += self.momentum * (batch_std.detach() - self.running_std) self.num_batches_tracked += 1 else: x = (x - self.running_mean) / self.running_std if self.affine: x = self.weight * x + self.bias if x.dim() > 2: x = x.transpose(1, -1) return x class BatchRenorm1d(BatchRenorm): def _check_input_dim(self, x: torch.Tensor) -> None: if x.dim() not in [2, 3]: raise ValueError("expected 2D or 3D input (got {x.dim()}D input)") class BatchRenorm2d(BatchRenorm): def _check_input_dim(self, x: torch.Tensor) -> None: if x.dim() != 4: raise ValueError("expected 4D input (got {x.dim()}D input)") class BatchRenorm3d(BatchRenorm): def _check_input_dim(self, x: torch.Tensor) -> None: if x.dim() != 5: raise ValueError("expected 5D input (got {x.dim()}D input)")