import glob
import os
import random

import torch
import torchvision

import ws_resnet

####################
# Main train utils #
####################

def get_random_file(root_dir):
    """Return the path of a uniformly random file under root_dir, searched recursively."""
    # Search for all files in the directory and subdirectories
    file_list = glob.glob(os.path.join(root_dir, '**', '*'), recursive=True)

    # Filter out directories from the list
    file_list = [f for f in file_list if os.path.isfile(f)]

    # If no files were found, fail loudly rather than returning an invalid path
    if not file_list:
        raise FileNotFoundError("No files found in the specified directory")

    # Select and return a random file path
    return random.choice(file_list)
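
# Hypothetical usage sketch (the directory name below is illustrative only):
#
#     ckpt_path = get_random_file("./checkpoints")
#     state = torch.load(ckpt_path)
#
# Note that glob patterns do not match names starting with ".", so hidden
# files are silently skipped.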

#####################
# Parallelism utils #
#####################

@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    # one placeholder tensor per rank; all_gather fills them in rank order
    tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
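
# Hypothetical gradient-preserving variant (a sketch, not used by this module):
# gather detached copies from every rank, then splice the local tensor back in
# so autograd can still flow through this rank's slice:
#
#     gathered = concat_all_gather(x.detach())
#     rank = torch.distributed.get_rank()
#     n = x.shape[0]
#     gathered = torch.cat([gathered[:rank * n], x, gathered[(rank + 1) * n:]], dim=0)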
class BatchShuffleDDP:
    @staticmethod
    @torch.no_grad()
    def shuffle(x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).to(x.device)

        # broadcast to all gpus so every rank applies the same permutation
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @staticmethod
    @torch.no_grad()
    def unshuffle(x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]
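
# Hypothetical usage sketch (MoCo-style key encoding; encoder_k and im_k are
# illustrative names, not defined in this module):
#
#     im_k, idx_unshuffle = BatchShuffleDDP.shuffle(im_k)
#     k = encoder_k(im_k)
#     k = BatchShuffleDDP.unshuffle(k, idx_unshuffle)
#
# Shuffling before the key forward pass keeps per-GPU BatchNorm statistics
# from leaking information between matching query/key pairs.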

###############
# Model utils #
###############

class MLP(torch.nn.Module):
    def __init__(
        self, input_dim, output_dim, hidden_dim, num_layers, weight_standardization=False, normalization=None
    ):
        super().__init__()
        assert num_layers >= 0, "negative layers?!?"
        if normalization is not None:
            assert callable(normalization), "normalization must be callable"

        # num_layers == 0: the MLP degenerates to a pass-through
        if num_layers == 0:
            self.net = torch.nn.Identity()
            return

        # num_layers == 1: a single linear map, no hidden layers
        if num_layers == 1:
            self.net = torch.nn.Linear(input_dim, output_dim)
            return

        linear_net = ws_resnet.Linear if weight_standardization else torch.nn.Linear

        layers = []
        prev_dim = input_dim
        for _ in range(num_layers - 1):
            layers.append(linear_net(prev_dim, hidden_dim))
            if normalization is not None:
                layers.append(normalization())
            layers.append(torch.nn.ReLU())
            prev_dim = hidden_dim

        # the output layer is a plain Linear: no weight standardization,
        # normalization, or activation is applied after it
        layers.append(torch.nn.Linear(hidden_dim, output_dim))

        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
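
# Hypothetical usage sketch: a 2-layer projection head with BatchNorm on the
# hidden layer (all dimensions below are illustrative only):
#
#     projector = MLP(
#         input_dim=2048,
#         output_dim=128,
#         hidden_dim=4096,
#         num_layers=2,
#         normalization=lambda: torch.nn.BatchNorm1d(4096),
#     )
#     z = projector(torch.randn(32, 2048))  # -> shape (32, 128)
#
# normalization is passed as a zero-argument factory because it is called
# once per hidden layer when the Sequential is built.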

def get_encoder(name: str, **kwargs) -> torch.nn.Module:
    """
    Gets just the encoder portion of a torchvision model (replaces final layer with identity).

    :param name: (str) name of the model
    :param kwargs: kwargs to send to the model
    :return: the model with its final classification layer replaced by torch.nn.Identity
    """
    # prefer the weight-standardized variants in ws_resnet, then fall back to torchvision
    if name in ws_resnet.__dict__:
        model_creator = ws_resnet.__dict__.get(name)
    elif name in torchvision.models.__dict__:
        model_creator = torchvision.models.__dict__.get(name)
    else:
        raise AttributeError(f"Unknown architecture {name}")
    assert model_creator is not None, f"no torchvision model named {name}"

    model = model_creator(**kwargs)
    # strip the classification head, whatever the model family calls it
    if hasattr(model, "fc"):  # in resnet
        model.fc = torch.nn.Identity()
    elif hasattr(model, "classifier"):  # not in resnet
        model.classifier = torch.nn.Identity()
    else:
        raise NotImplementedError(f"Unknown class {model.__class__}")

    return model
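
# Hypothetical usage sketch (uses torchvision's default, randomly initialized
# weights; the input shape is illustrative):
#
#     encoder = get_encoder("resnet50")
#     features = encoder(torch.randn(8, 3, 224, 224))  # -> shape (8, 2048)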

####################
# Evaluation utils #
####################

def calculate_accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # top-maxk class indices per sample, transposed to (maxk, batch_size)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # a sample counts as correct at k if the target appears in its top k predictions
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
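
# Hypothetical usage sketch (random logits and labels, illustrative shapes):
#
#     logits = torch.randn(64, 1000)
#     labels = torch.randint(0, 1000, (64,))
#     top1, top5 = calculate_accuracy(logits, labels, topk=(1, 5))
#     # top1 and top5 are 1-element tensors holding percentages in [0, 100]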

def log_softmax_with_factors(logits: torch.Tensor, log_factor: float = 1, neg_factor: float = 1) -> torch.Tensor:
    """
    Log-softmax with tunable weights on the log term and on the negative
    (off-target) logits. With log_factor == neg_factor == 1 this reduces to
    the standard log-softmax.
    """
    # sum of exp(logits) over all *other* classes, per position
    exp_sum_neg_logits = torch.exp(logits).sum(dim=-1, keepdim=True) - torch.exp(logits)
    softmax_result = logits - log_factor * torch.log(torch.exp(logits) + neg_factor * exp_sum_neg_logits)
    return softmax_result
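
# Hypothetical sanity check of the reduction claimed above (not part of the
# training code):
#
#     x = torch.randn(4, 10)
#     assert torch.allclose(
#         log_softmax_with_factors(x),
#         torch.nn.functional.log_softmax(x, dim=-1),
#         atol=1e-5,
#     )
#
# Note: exp(logits) is computed directly, so very large logits can overflow;
# subtracting logits.max(dim=-1, keepdim=True).values first would be the
# numerically stable variant.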