# semes_gaf/self_supervised/vicreg.py
import copy
import math
import warnings
from functools import partial
from typing import Optional
from typing import Union
import attr
import torch
import torch.nn.functional as F
# from pytorch_lightning.utilities import AttributeDict
from lightning.fabric.utilities.data import AttributeDict
from torch.utils.data import DataLoader
import lightning as L
import utils
from batchrenorm import BatchRenorm1d
from lars import LARS
from model_params import ModelParams
# from sklearn.linear_model import LogisticRegression
# from sklearn.cluster import KMeans
# from sklearn.metrics import rand_score, normalized_mutual_info_score
import pandas as pd
def get_mlp_normalization(hparams: "ModelParams", prediction=False):
    """Return a factory for the normalization layer used inside the MLP heads.

    Args:
        hparams: model configuration; reads ``mlp_normalization``,
            ``prediction_mlp_normalization`` and ``mlp_hidden_dim``.
        prediction: when True, use ``prediction_mlp_normalization`` unless it
            is the literal string ``"same"`` (then fall back to the shared
            ``mlp_normalization`` setting).

    Returns:
        A zero-argument-callable ``functools.partial`` building the configured
        normalization module, or ``None`` when no normalization is requested.

    Raises:
        NotImplementedError: for an unrecognized normalization name.
    """
    name = hparams.mlp_normalization
    if prediction and hparams.prediction_mlp_normalization != "same":
        name = hparams.prediction_mlp_normalization
    if name is None:
        return None
    width = hparams.mlp_hidden_dim
    if name == "bn":
        return partial(torch.nn.BatchNorm1d, num_features=width)
    if name == "br":
        # Batch renormalization (project-local implementation).
        return partial(BatchRenorm1d, num_features=width)
    if name == "ln":
        return partial(torch.nn.LayerNorm, normalized_shape=[width])
    if name == "gn":
        return partial(torch.nn.GroupNorm, num_channels=width, num_groups=32)
    raise NotImplementedError(f"mlp normalization {name} not implemented")
# class KMeansLoss:
#
# def __init__(self, num_clusters, embedding_dim, device):
# self.num_clusters = num_clusters
# self.centroids = torch.randn(num_clusters, embedding_dim, device=device)
# self.device=device
#
# def update_centroids(self, embeddings, assignments):
# for i in range(self.num_clusters):
# assigned_embeddings = embeddings[assignments == i]
# if len(assigned_embeddings) > 1: # good if more than singleton
# # implement ewma update for centroids
# weight1 = torch.tensor(0.3, device='cpu')
# weight2 = torch.tensor(0.7, device='cpu') # give more weight to new embeddings
# self.centroids[i] = self.centroids[i] * weight1 + assigned_embeddings.mean(dim=0).cpu() * weight2
#
# def set_centroids(self, embeddings, assignments):
# for i in range(self.num_clusters):
# assigned_embeddings = embeddings[assignments == i]
# if len(assigned_embeddings) > 1: # good if more than singleton
# # implement ewma update for centroids
# self.centroids[i] = assigned_embeddings.mean(dim=0).cpu()
#
#
# def compute_loss(self, embeddings):
# # move centroids to same device as embeddings
# centroids = self.centroids.to(embeddings.device)
#         distances = torch.cdist(embeddings, centroids, p=2)  # NOTE(review): original passed p=self.num_clusters, but cdist's p is the norm order (2 = Euclidean), not the cluster count
# min_distances, assignments = distances.min(dim=1)
# loss = min_distances.pow(2).sum()
# return loss, assignments
#
# def forward(self, embeddings, step_count):
# loss, assignments = self.compute_loss(embeddings)
# detached_embeddings = embeddings.detach()
# detached_assignments = assignments.detach()
#
# if (step_count < 5):
# self.set_centroids(detached_embeddings, detached_assignments)
# if (step_count % 2 == 0):
# self.update_centroids(detached_embeddings, detached_assignments)
# return loss
class SelfSupervisedMethod(L.LightningModule):
    """VICReg-style self-supervised learner.

    Consumes N2CHW batches (two augmented views per image) and trains an
    encoder plus projection/prediction MLP heads with the VICReg loss
    (invariance + variance + covariance terms, see ``_get_vicreg_loss``).
    Optimization is manual: ``automatic_optimization`` is disabled in
    ``__init__`` and the optimizer/scheduler are stepped in ``training_step``.
    """

    # Encoder backbone built by utils.get_encoder in __init__.
    model: torch.nn.Module
    # Populated via save_hyperparameters() in __init__.
    hparams: AttributeDict
    # NOTE(review): never assigned in this file section — confirm where it is set.
    embedding_dim: Optional[int]
def __init__(
self,
hparams: Union[ModelParams, dict, None] = None,
**kwargs,
):
super().__init__()
# disable automatic optimization for lightning2
self.automatic_optimization = False
self.optimizer = None
self.lr_scheduler = None
# load from arguments
if hparams is None:
hparams = self.params(**kwargs)
# if it is already an attributedict, then use it directly
if hparams is not None:
self.save_hyperparameters(attr.asdict(hparams))
# Create encoder model
self.model = utils.get_encoder(hparams.encoder_arch)
# projection_mlp_layers = 3
self.projection_model = utils.MLP(
hparams.embedding_dim,
hparams.dim,
hparams.mlp_hidden_dim,
num_layers=hparams.projection_mlp_layers,
normalization=get_mlp_normalization(hparams),
weight_standardization=hparams.use_mlp_weight_standardization,
)
# by default it is identity
# prediction_mlp_layers = 0
self.prediction_model = utils.MLP(
hparams.dim,
hparams.dim,
hparams.mlp_hidden_dim,
num_layers=hparams.prediction_mlp_layers,
normalization=get_mlp_normalization(hparams, prediction=True),
weight_standardization=hparams.use_mlp_weight_standardization,
)
# kmeans loss
# self.kmeans_loss = KMeansLoss(num_clusters=hparams.num_clusters, embedding_dim=hparams.dim, device=self.device)
def _get_embeddings(self, x):
"""
Input:
im_q: a batch of query images
im_k: a batch of key images
Output:
embeddings, targets
"""
bsz, nd, nc, nh, nw = x.shape
assert nd == 2, "second dimension should be the split image -- dims should be N2CHW"
im_q = x[:, 0].contiguous()
im_k = x[:, 1].contiguous()
# compute query features
emb_q = self.model(im_q)
q_projection = self.projection_model(emb_q)
# by default vicreg gives an identity for prediction model
q = self.prediction_model(q_projection) # queries: NxC
emb_k = self.model(im_k)
k_projection = self.projection_model(emb_k)
k = self.prediction_model(k_projection) # queries: NxC
# q and k are the projection embeddings
return emb_q, q, k
def _get_vicreg_loss(self, z_a, z_b, batch_idx):
assert z_a.shape == z_b.shape and len(z_a.shape) == 2
# invariance loss
loss_inv = F.mse_loss(z_a, z_b)
# variance loss
std_z_a = torch.sqrt(z_a.var(dim=0) + self.hparams.variance_loss_epsilon)
std_z_b = torch.sqrt(z_b.var(dim=0) + self.hparams.variance_loss_epsilon)
loss_v_a = torch.mean(F.relu(1 - std_z_a)) # differentiable max
loss_v_b = torch.mean(F.relu(1 - std_z_b))
loss_var = loss_v_a + loss_v_b
# covariance loss
N, D = z_a.shape
z_a = z_a - z_a.mean(dim=0)
z_b = z_b - z_b.mean(dim=0)
cov_z_a = ((z_a.T @ z_a) / (N - 1)).square() # DxD
cov_z_b = ((z_b.T @ z_b) / (N - 1)).square() # DxD
loss_c_a = (cov_z_a.sum() - cov_z_a.diagonal().sum()) / D
loss_c_b = (cov_z_b.sum() - cov_z_b.diagonal().sum()) / D
loss_cov = loss_c_a + loss_c_b
weighted_inv = loss_inv * self.hparams.invariance_loss_weight
weighted_var = loss_var * self.hparams.variance_loss_weight
weighted_cov = loss_cov * self.hparams.covariance_loss_weight
loss = weighted_inv + weighted_var + weighted_cov
return loss
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, class_labels = batch # batch is a tuple, we just want the image
emb_q, q, k = self._get_embeddings(x)
vicreg_loss = self._get_vicreg_loss(q, k, batch_idx)
total_loss = vicreg_loss.mean() * self.hparams.loss_constant_factor
# here lies the manual optimizing code
self.optimizer.zero_grad()
self.manual_backward(total_loss)
self.optimizer.step()
self.lr_scheduler.step()
log_data = {
"step_train_loss": total_loss,
}
self.log_dict(log_data, sync_dist=True, prog_bar=True)
return {"loss": total_loss}
def configure_optimizers(self):
# exclude bias and batch norm from LARS and weight decay
regular_parameters = []
regular_parameter_names = []
excluded_parameters = []
excluded_parameter_names = []
for name, parameter in self.named_parameters():
if parameter.requires_grad is False:
continue
# for vicreg
# exclude_matching_parameters_from_lars=[".bias", ".bn"],
if any(x in name for x in self.hparams.exclude_matching_parameters_from_lars):
excluded_parameters.append(parameter)
excluded_parameter_names.append(name)
else:
regular_parameters.append(parameter)
regular_parameter_names.append(name)
param_groups = [
{
"params": regular_parameters,
"names": regular_parameter_names,
"use_lars": True
},
{
"params": excluded_parameters,
"names": excluded_parameter_names,
"use_lars": False,
"weight_decay": 0,
},
]
if self.hparams.optimizer_name == "sgd":
optimizer = torch.optim.SGD
elif self.hparams.optimizer_name == "lars":
optimizer = partial(LARS, warmup_epochs=self.hparams.lars_warmup_epochs, eta=self.hparams.lars_eta)
else:
raise NotImplementedError(f"No such optimizer {self.hparams.optimizer_name}")
self.optimizer = optimizer(
param_groups,
lr=self.hparams.lr,
momentum=self.hparams.momentum,
weight_decay=self.hparams.weight_decay,
)
self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.optimizer,
self.hparams.max_epochs,
eta_min=self.hparams.final_lr_schedule_value,
)
return None # [encoding_optimizer], [self.lr_scheduler]
    @classmethod
    def params(cls, **kwargs) -> ModelParams:
        """Build a ``ModelParams`` configuration from keyword overrides."""
        return ModelParams(**kwargs)