import copy
import math
import warnings
from functools import partial
from typing import Optional, Union

import attr
import torch
import torch.nn.functional as F

# from pytorch_lightning.utilities import AttributeDict
from lightning.fabric.utilities.data import AttributeDict
from torch.utils.data import DataLoader
import lightning as L

import utils
from batchrenorm import BatchRenorm1d
from lars import LARS
from model_params import ModelParams
# from sklearn.linear_model import LogisticRegression
# from sklearn.cluster import KMeans
# from sklearn.metrics import rand_score, normalized_mutual_info_score

import pandas as pd

def get_mlp_normalization(hparams: ModelParams, prediction=False):
    normalization_str = hparams.mlp_normalization
    if prediction and hparams.prediction_mlp_normalization != "same":
        normalization_str = hparams.prediction_mlp_normalization

    if normalization_str is None:
        return None
    elif normalization_str == "bn":
        return partial(torch.nn.BatchNorm1d, num_features=hparams.mlp_hidden_dim)
    elif normalization_str == "br":
        return partial(BatchRenorm1d, num_features=hparams.mlp_hidden_dim)
    elif normalization_str == "ln":
        return partial(torch.nn.LayerNorm, normalized_shape=[hparams.mlp_hidden_dim])
    elif normalization_str == "gn":
        return partial(torch.nn.GroupNorm, num_channels=hparams.mlp_hidden_dim, num_groups=32)
    else:
        raise NotImplementedError(f"mlp normalization {normalization_str} not implemented")
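
# Usage sketch (how utils.MLP presumably consumes this is an assumption, not shown in
# this file): the function returns a layer *factory* (functools.partial) rather than an
# instance, so the MLP can build a fresh normalization layer for every hidden block:
#
#   norm_factory = get_mlp_normalization(hparams)  # e.g. partial(BatchNorm1d, num_features=mlp_hidden_dim)
#   norm_layer = norm_factory() if norm_factory is not None else torch.nn.Identity()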


# class KMeansLoss:
#
#     def __init__(self, num_clusters, embedding_dim, device):
#         self.num_clusters = num_clusters
#         self.centroids = torch.randn(num_clusters, embedding_dim, device=device)
#         self.device = device
#
#     def update_centroids(self, embeddings, assignments):
#         for i in range(self.num_clusters):
#             assigned_embeddings = embeddings[assignments == i]
#             if len(assigned_embeddings) > 1:  # good if more than singleton
#                 # ewma update for centroids
#                 weight1 = torch.tensor(0.3, device='cpu')
#                 weight2 = torch.tensor(0.7, device='cpu')  # give more weight to new embeddings
#                 self.centroids[i] = self.centroids[i] * weight1 + assigned_embeddings.mean(dim=0).cpu() * weight2
#
#     def set_centroids(self, embeddings, assignments):
#         for i in range(self.num_clusters):
#             assigned_embeddings = embeddings[assignments == i]
#             if len(assigned_embeddings) > 1:  # good if more than singleton
#                 self.centroids[i] = assigned_embeddings.mean(dim=0).cpu()
#
#     def compute_loss(self, embeddings):
#         # move centroids to the same device as the embeddings
#         centroids = self.centroids.to(embeddings.device)
#         distances = torch.cdist(embeddings, centroids, p=2)  # Euclidean distance
#         min_distances, assignments = distances.min(dim=1)
#         loss = min_distances.pow(2).sum()
#         return loss, assignments
#
#     def forward(self, embeddings, step_count):
#         loss, assignments = self.compute_loss(embeddings)
#         detached_embeddings = embeddings.detach()
#         detached_assignments = assignments.detach()
#
#         if step_count < 5:
#             self.set_centroids(detached_embeddings, detached_assignments)
#         if step_count % 2 == 0:
#             self.update_centroids(detached_embeddings, detached_assignments)
#         return loss


class SelfSupervisedMethod(L.LightningModule):
    model: torch.nn.Module
    hparams: AttributeDict
    embedding_dim: Optional[int]

    def __init__(
        self,
        hparams: Union[ModelParams, dict, None] = None,
        **kwargs,
    ):
        super().__init__()

        # disable automatic optimization for lightning 2; the optimizer and LR scheduler
        # are stored here and stepped manually in training_step
        self.automatic_optimization = False
        self.optimizer = None
        self.lr_scheduler = None

        # load hyperparameters: accept a ModelParams instance, a plain dict of
        # ModelParams fields, or nothing (build ModelParams from keyword arguments)
        if hparams is None:
            hparams = self.params(**kwargs)
        elif isinstance(hparams, dict):
            hparams = self.params(**hparams)
        self.save_hyperparameters(attr.asdict(hparams))

        # Create encoder model
        self.model = utils.get_encoder(hparams.encoder_arch)

        # projection_mlp_layers = 3
        self.projection_model = utils.MLP(
            hparams.embedding_dim,
            hparams.dim,
            hparams.mlp_hidden_dim,
            num_layers=hparams.projection_mlp_layers,
            normalization=get_mlp_normalization(hparams),
            weight_standardization=hparams.use_mlp_weight_standardization,
        )

        # by default this is the identity (prediction_mlp_layers = 0)
        self.prediction_model = utils.MLP(
            hparams.dim,
            hparams.dim,
            hparams.mlp_hidden_dim,
            num_layers=hparams.prediction_mlp_layers,
            normalization=get_mlp_normalization(hparams, prediction=True),
            weight_standardization=hparams.use_mlp_weight_standardization,
        )

        # kmeans loss
        # self.kmeans_loss = KMeansLoss(num_clusters=hparams.num_clusters, embedding_dim=hparams.dim, device=self.device)
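
        # Data flow through the modules built above:
        #   image (N, C, H, W) --model (encoder)-->   emb (N, embedding_dim)
        #   emb --projection_model-->                 z   (N, dim)
        #   z   --prediction_model-->                 p   (N, dim)  # identity when prediction_mlp_layers == 0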

    def _get_embeddings(self, x):
        """
        Input:
            x: a batch of paired views with shape (N, 2, C, H, W); x[:, 0] holds the
               query views and x[:, 1] the key views of the same images
        Output:
            emb_q: encoder embeddings of the query views
            q, k: projection (and prediction) embeddings of the query and key views
        """
        bsz, nd, nc, nh, nw = x.shape
        assert nd == 2, "second dimension should be the split image -- dims should be N2CHW"
        im_q = x[:, 0].contiguous()
        im_k = x[:, 1].contiguous()

        # compute query features
        emb_q = self.model(im_q)
        q_projection = self.projection_model(emb_q)
        # by default vicreg uses an identity prediction model
        q = self.prediction_model(q_projection)  # queries: NxC

        # compute key features
        emb_k = self.model(im_k)
        k_projection = self.projection_model(emb_k)
        k = self.prediction_model(k_projection)  # keys: NxC

        return emb_q, q, k
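
    # Sketch of how the N2CHW input is presumably produced (the augmentation pipeline
    # lives outside this file, so the exact transform is an assumption):
    #
    #   view_1, view_2 = augment(img), augment(img)       # two independent random augmentations
    #   x_single = torch.stack([view_1, view_2], dim=0)   # (2, C, H, W) for one image
    #   # the DataLoader's default collate_fn then batches these into (N, 2, C, H, W)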

    def _get_vicreg_loss(self, z_a, z_b, batch_idx):
        assert z_a.shape == z_b.shape and len(z_a.shape) == 2

        # invariance loss
        loss_inv = F.mse_loss(z_a, z_b)

        # variance loss: hinge on the per-dimension std, pushing it up towards 1
        std_z_a = torch.sqrt(z_a.var(dim=0) + self.hparams.variance_loss_epsilon)
        std_z_b = torch.sqrt(z_b.var(dim=0) + self.hparams.variance_loss_epsilon)
        loss_v_a = torch.mean(F.relu(1 - std_z_a))  # differentiable max(0, 1 - std)
        loss_v_b = torch.mean(F.relu(1 - std_z_b))
        loss_var = loss_v_a + loss_v_b

        # covariance loss: penalize off-diagonal entries of the covariance matrix
        N, D = z_a.shape
        z_a = z_a - z_a.mean(dim=0)
        z_b = z_b - z_b.mean(dim=0)
        cov_z_a = ((z_a.T @ z_a) / (N - 1)).square()  # DxD
        cov_z_b = ((z_b.T @ z_b) / (N - 1)).square()  # DxD
        loss_c_a = (cov_z_a.sum() - cov_z_a.diagonal().sum()) / D
        loss_c_b = (cov_z_b.sum() - cov_z_b.diagonal().sum()) / D
        loss_cov = loss_c_a + loss_c_b

        weighted_inv = loss_inv * self.hparams.invariance_loss_weight
        weighted_var = loss_var * self.hparams.variance_loss_weight
        weighted_cov = loss_cov * self.hparams.covariance_loss_weight

        loss = weighted_inv + weighted_var + weighted_cov

        return loss
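
    # For reference, this matches the VICReg objective (Bardes et al., 2022), with the
    # three weights taken from hparams:
    #   L =   invariance_loss_weight * MSE(z_a, z_b)
    #       + variance_loss_weight   * mean_d[ relu(1 - std(z_a)_d) ]        (+ same term for z_b)
    #       + covariance_loss_weight * sum_{i != j} Cov(z_a)_{ij}^2 / D      (+ same term for z_b)
    # where std is computed as sqrt(var + variance_loss_epsilon) for numerical stability.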

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, class_labels = batch  # batch is a tuple; we only need the image pairs here

        emb_q, q, k = self._get_embeddings(x)

        vicreg_loss = self._get_vicreg_loss(q, k, batch_idx)

        total_loss = vicreg_loss.mean() * self.hparams.loss_constant_factor

        # manual optimization (automatic_optimization=False): zero grads, backward,
        # optimizer step, and LR schedule step are all driven here
        self.optimizer.zero_grad()
        self.manual_backward(total_loss)
        self.optimizer.step()
        self.lr_scheduler.step()

        log_data = {
            "step_train_loss": total_loss,
        }

        self.log_dict(log_data, sync_dist=True, prog_bar=True)
        return {"loss": total_loss}

    def configure_optimizers(self):
        # exclude bias and batch-norm parameters from LARS adaptation and weight decay
        regular_parameters = []
        regular_parameter_names = []
        excluded_parameters = []
        excluded_parameter_names = []
        for name, parameter in self.named_parameters():
            if not parameter.requires_grad:
                continue

            # for vicreg:
            # exclude_matching_parameters_from_lars=[".bias", ".bn"]
            if any(x in name for x in self.hparams.exclude_matching_parameters_from_lars):
                excluded_parameters.append(parameter)
                excluded_parameter_names.append(name)
            else:
                regular_parameters.append(parameter)
                regular_parameter_names.append(name)

        param_groups = [
            {
                "params": regular_parameters,
                "names": regular_parameter_names,
                "use_lars": True,
            },
            {
                "params": excluded_parameters,
                "names": excluded_parameter_names,
                "use_lars": False,
                "weight_decay": 0,
            },
        ]
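
        # Note: "use_lars" and "names" are not standard torch.optim param-group keys;
        # torch.optim.Optimizer simply carries unknown group entries along, and the
        # custom LARS optimizer (lars.py) presumably reads "use_lars" to skip adaptive
        # scaling for the excluded group -- that behaviour is an assumption about
        # lars.py, not something defined in this file.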

        if self.hparams.optimizer_name == "sgd":
            optimizer = torch.optim.SGD
        elif self.hparams.optimizer_name == "lars":
            optimizer = partial(LARS, warmup_epochs=self.hparams.lars_warmup_epochs, eta=self.hparams.lars_eta)
        else:
            raise NotImplementedError(f"No such optimizer {self.hparams.optimizer_name}")

        self.optimizer = optimizer(
            param_groups,
            lr=self.hparams.lr,
            momentum=self.hparams.momentum,
            weight_decay=self.hparams.weight_decay,
        )
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            self.hparams.max_epochs,
            eta_min=self.hparams.final_lr_schedule_value,
        )
        # with manual optimization the optimizer and scheduler are stepped in
        # training_step, but returning them lets Lightning track and checkpoint them
        return [self.optimizer], [self.lr_scheduler]

    @classmethod
    def params(cls, **kwargs) -> ModelParams:
        return ModelParams(**kwargs)
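

# Minimal usage sketch, assuming ModelParams() provides usable defaults and the default
# encoder accepts the 3x32x32 stand-in images below; the random tensors are placeholders
# for a real paired-augmentation dataset and only support a quick smoke test.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    params = ModelParams()  # assumption: field defaults are sensible for a short run
    method = SelfSupervisedMethod(params)

    # random stand-in data: 64 samples, each a pair of "views" -> overall shape (N, 2, C, H, W)
    fake_pairs = torch.randn(64, 2, 3, 32, 32)
    fake_labels = torch.zeros(64, dtype=torch.long)
    loader = DataLoader(TensorDataset(fake_pairs, fake_labels), batch_size=32)

    trainer = L.Trainer(max_epochs=1, accelerator="auto", devices=1)
    trainer.fit(method, loader)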