"""
From https://github.com/ludvb/batchrenorm

@article{batchrenomalization,
    author = {Sergey Ioffe},
    title = {Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models},
    journal = {arXiv preprint arXiv:1702.03275},
    year = {2017},
}
"""

import torch
__all__ = ["BatchRenorm1d", "BatchRenorm2d", "BatchRenorm3d"]


class BatchRenorm(torch.nn.Module):
    """Batch renormalization (Ioffe, 2017; arXiv:1702.03275).

    Normalizes with batch statistics that are corrected toward the running
    statistics through the clamped factors ``r`` and ``d``, shrinking the
    train/eval discrepancy of plain batch normalization.

    Args:
        num_features: size of the channel dimension.
        eps: value added to the batch standard deviation for stability.
        momentum: exponential-moving-average factor for the running stats.
        affine: when True, apply the learnable scale/shift after normalizing.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-3,
        momentum: float = 0.01,
        affine: bool = True,
    ):
        super().__init__()
        # Running statistics, updated in-place during training.
        self.register_buffer("running_mean", torch.zeros(num_features, dtype=torch.float))
        self.register_buffer("running_std", torch.ones(num_features, dtype=torch.float))
        # Training-batch counter that drives the rmax/dmax schedules.
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        # Scale/shift parameters are always created; they are only applied
        # when ``affine`` is True (keeps the state dict layout stable).
        self.weight = torch.nn.Parameter(torch.ones(num_features, dtype=torch.float))
        self.bias = torch.nn.Parameter(torch.zeros(num_features, dtype=torch.float))
        self.affine = affine
        self.eps = eps
        self.step = 0
        self.momentum = momentum

    def _check_input_dim(self, x: torch.Tensor) -> None:
        # Subclasses constrain the accepted input rank.
        raise NotImplementedError()  # pragma: no cover

    @property
    def rmax(self) -> torch.Tensor:
        # Schedule ramping from 1 to 3 as batches accumulate; ``r`` is
        # clamped to [1/rmax, rmax] in forward().
        return (2 / 35000 * self.num_batches_tracked + 25 / 35).clamp_(1.0, 3.0)

    @property
    def dmax(self) -> torch.Tensor:
        # Schedule ramping from 0 to 5; ``d`` is clamped to [-dmax, dmax].
        return (5 / 20000 * self.num_batches_tracked - 25 / 20).clamp_(0.0, 5.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over all dims except the channel dimension."""
        self._check_input_dim(x)
        channels_not_last = x.dim() > 2
        if channels_not_last:
            # Put channels last so per-channel statistics broadcast cleanly.
            x = x.transpose(1, -1)
        if self.training:
            reduce_dims = list(range(x.dim() - 1))
            mean = x.mean(reduce_dims)
            std = x.std(reduce_dims, unbiased=False) + self.eps
            running_mean = self.running_mean.view_as(mean)
            running_std = self.running_std.view_as(std)
            # Renorm corrections; detached so gradients flow only through
            # the batch statistics used in the normalization itself.
            rmax = self.rmax.item()
            r = (std.detach() / running_std).clamp_(1 / rmax, rmax)
            dmax = self.dmax.item()
            d = ((mean.detach() - running_mean) / running_std).clamp_(-dmax, dmax)
            x = (x - mean) / std * r + d
            # EMA update of running stats, then advance the schedule.
            self.running_mean += self.momentum * (mean.detach() - self.running_mean)
            self.running_std += self.momentum * (std.detach() - self.running_std)
            self.num_batches_tracked += 1
        else:
            x = (x - self.running_mean) / self.running_std
        if self.affine:
            x = self.weight * x + self.bias
        if channels_not_last:
            x = x.transpose(1, -1)
        return x


class BatchRenorm1d(BatchRenorm):
    """Batch renormalization over 2D ``(N, C)`` or 3D ``(N, C, L)`` inputs."""

    def _check_input_dim(self, x: torch.Tensor) -> None:
        """Raise ValueError unless ``x`` is 2D or 3D."""
        if x.dim() not in [2, 3]:
            # Bug fix: the message was a plain string, so the literal text
            # "{x.dim()}D" was printed; it needs the f-string prefix.
            raise ValueError(f"expected 2D or 3D input (got {x.dim()}D input)")


class BatchRenorm2d(BatchRenorm):
    """Batch renormalization over 4D ``(N, C, H, W)`` inputs."""

    def _check_input_dim(self, x: torch.Tensor) -> None:
        """Raise ValueError unless ``x`` is 4D."""
        if x.dim() != 4:
            # Bug fix: the message was a plain string, so the literal text
            # "{x.dim()}D" was printed; it needs the f-string prefix.
            raise ValueError(f"expected 4D input (got {x.dim()}D input)")


class BatchRenorm3d(BatchRenorm):
    """Batch renormalization over 5D ``(N, C, D, H, W)`` inputs."""

    def _check_input_dim(self, x: torch.Tensor) -> None:
        """Raise ValueError unless ``x`` is 5D."""
        if x.dim() != 5:
            # Bug fix: the message was a plain string, so the literal text
            # "{x.dim()}D" was printed; it needs the f-string prefix.
            raise ValueError(f"expected 5D input (got {x.dim()}D input)")