# semes_gaf/self_supervised/utils.py


import glob
import os
import random

import pandas as pd
import torch
import torchvision

import ws_resnet


####################
# Main train utils #
####################
def get_random_file(root_dir):
    """Return the path of a random file found anywhere under ``root_dir``."""
    # Search for all files in the directory and subdirectories
    file_list = glob.glob(os.path.join(root_dir, '**', '*'), recursive=True)
    # Filter out directories from the list
    file_list = [f for f in file_list if os.path.isfile(f)]
    # If no files were found, raise an exception
    if not file_list:
        raise FileNotFoundError("No files found in the specified directory")
    # Select and return a random file path
    return random.choice(file_list)
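

# Illustrative usage (not part of the original module; the directory path
# below is a hypothetical placeholder):
def _demo_get_random_file():
    sample_path = get_random_file("/data/gaf_dataset")
    print(f"randomly selected file: {sample_path}")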


#####################
# Parallelism utils #
#####################
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output
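

# Minimal usage sketch (not from the original file): gather per-rank feature
# batches into one global batch, e.g. to enlarge the negative pool for a
# contrastive loss. Assumes torch.distributed has already been initialized
# (e.g. via torchrun) before this is called.
def _demo_concat_all_gather(local_features: torch.Tensor) -> torch.Tensor:
    # local_features: [batch, dim] on this rank -> [batch * world_size, dim].
    # Note: no gradients flow back to the other ranks' tensors.
    return concat_all_gather(local_features)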


class BatchShuffleDDP:
    @staticmethod
    @torch.no_grad()
    def shuffle(x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]
        num_gpus = batch_size_all // batch_size_this
        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).to(x.device)
        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)
        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)
        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
        return x_gather[idx_this], idx_unshuffle

    @staticmethod
    @torch.no_grad()
    def unshuffle(x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]
        num_gpus = batch_size_all // batch_size_this
        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
        return x_gather[idx_this]
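

# Hedged sketch (not in the original file) of the MoCo-style pattern these
# helpers enable: shuffle the key batch across GPUs before the momentum
# encoder's forward pass so BatchNorm cannot leak intra-batch statistics,
# then restore the original order. `key_encoder` is a hypothetical module.
def _demo_batch_shuffle(key_encoder: torch.nn.Module, images: torch.Tensor) -> torch.Tensor:
    shuffled_images, idx_unshuffle = BatchShuffleDDP.shuffle(images)
    keys = key_encoder(shuffled_images)
    # Put the keys back in the original batch order so they align with the
    # corresponding query embeddings.
    return BatchShuffleDDP.unshuffle(keys, idx_unshuffle)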


###############
# Model utils #
###############
class MLP(torch.nn.Module):
    def __init__(
        self, input_dim, output_dim, hidden_dim, num_layers, weight_standardization=False, normalization=None
    ):
        super().__init__()
        assert num_layers >= 0, "num_layers must be non-negative"
        if normalization is not None:
            assert callable(normalization), "normalization must be callable"
        if num_layers == 0:
            self.net = torch.nn.Identity()
            return
        if num_layers == 1:
            self.net = torch.nn.Linear(input_dim, output_dim)
            return
        linear_net = ws_resnet.Linear if weight_standardization else torch.nn.Linear
        layers = []
        prev_dim = input_dim
        for _ in range(num_layers - 1):
            layers.append(linear_net(prev_dim, hidden_dim))
            if normalization is not None:
                layers.append(normalization())
            layers.append(torch.nn.ReLU())
            prev_dim = hidden_dim
        layers.append(torch.nn.Linear(hidden_dim, output_dim))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
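

# Illustrative sketch (not from the source): a SimCLR/MoCo-style projection
# head. The dimensions are common defaults, not values from this repo.
def _demo_mlp():
    projection_head = MLP(
        input_dim=2048,   # encoder feature size (e.g. ResNet-50)
        output_dim=128,   # embedding size used by the contrastive loss
        hidden_dim=2048,
        num_layers=2,
        normalization=lambda: torch.nn.BatchNorm1d(2048),
    )
    embeddings = projection_head(torch.randn(4, 2048))
    print(embeddings.shape)  # torch.Size([4, 128])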


def get_encoder(name: str, **kwargs) -> torch.nn.Module:
    """
    Gets just the encoder portion of a torchvision model (replaces the final layer with an identity).
    :param name: (str) name of the model
    :param kwargs: kwargs to send to the model
    :return: (torch.nn.Module) the model with its classification head removed
    """
    if name in ws_resnet.__dict__:
        model_creator = ws_resnet.__dict__.get(name)
    elif name in torchvision.models.__dict__:
        model_creator = torchvision.models.__dict__.get(name)
    else:
        raise AttributeError(f"Unknown architecture {name}")
    assert model_creator is not None, f"no torchvision model named {name}"
    model = model_creator(**kwargs)
    if hasattr(model, "fc"):  # in resnet
        model.fc = torch.nn.Identity()
    elif hasattr(model, "classifier"):  # not in resnet
        model.classifier = torch.nn.Identity()
    else:
        raise NotImplementedError(f"Unknown class {model.__class__}")
    return model
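

# Hedged usage sketch (not from the source): "resnet18" resolves through
# ws_resnet or torchvision.models, whichever defines it. The classification
# head is replaced by an identity, so the output is the pooled feature vector.
def _demo_get_encoder():
    encoder = get_encoder("resnet18")
    features = encoder(torch.randn(2, 3, 224, 224))
    print(features.shape)  # torch.Size([2, 512]) for ResNet-18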


####################
# Evaluation utils #
####################
def calculate_accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
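

# Illustrative check (not from the source): top-1/top-5 accuracy on dummy logits.
def _demo_calculate_accuracy():
    logits = torch.randn(8, 10)          # batch of 8, 10 classes
    labels = torch.randint(0, 10, (8,))
    top1, top5 = calculate_accuracy(logits, labels, topk=(1, 5))
    print(f"top-1: {top1.item():.1f}%, top-5: {top5.item():.1f}%")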


def log_softmax_with_factors(logits: torch.Tensor, log_factor: float = 1, neg_factor: float = 1) -> torch.Tensor:
    """
    Log-softmax whose denominator re-weights the negative (off-target) terms by
    ``neg_factor`` and whose log term is scaled by ``log_factor``; with both
    factors at 1 this reduces to the standard log-softmax.
    """
    exp_sum_neg_logits = torch.exp(logits).sum(dim=-1, keepdim=True) - torch.exp(logits)
    log_softmax_result = logits - log_factor * torch.log(torch.exp(logits) + neg_factor * exp_sum_neg_logits)
    return log_softmax_result
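

# Hedged sanity-check sketch (not from the source): with both factors at 1
# the result should match torch.nn.functional.log_softmax up to numerical
# tolerance.
def _demo_log_softmax_with_factors():
    logits = torch.randn(4, 10)
    reference = torch.nn.functional.log_softmax(logits, dim=-1)
    result = log_softmax_with_factors(logits, log_factor=1, neg_factor=1)
    print(torch.allclose(result, reference, atol=1e-6))  # expected: True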