import glob
import os
import random

import pandas as pd
import torch
import torchvision

import ws_resnet

####################
# Main train utils #
####################


def get_random_file(root_dir):
    # Search for all files in the directory and subdirectories
    file_list = glob.glob(os.path.join(root_dir, '**', '*'), recursive=True)
    # Filter out directories from the list
    file_list = [f for f in file_list if os.path.isfile(f)]
    # If no files were found, raise an error
    if not file_list:
        raise FileNotFoundError("No files found in the specified directory")
    # Select and return a random file path
    return random.choice(file_list)


#####################
# Parallelism utils #
#####################


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensor.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


class BatchShuffleDDP:
    @staticmethod
    @torch.no_grad()
    def shuffle(x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).to(x.device)

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @staticmethod
    @torch.no_grad()
    def unshuffle(x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]
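
# Illustrative usage sketch (not called anywhere in this module): wrapping a
# momentum/key encoder forward pass with BatchShuffleDDP so that per-GPU
# BatchNorm statistics cannot leak which samples form a positive pair.
# `key_encoder` is a hypothetical stand-in for whatever module the caller uses,
# and this assumes torch.distributed has already been initialized.
@torch.no_grad()
def _example_shuffled_forward(key_encoder: torch.nn.Module, images: torch.Tensor) -> torch.Tensor:
    shuffled_images, idx_unshuffle = BatchShuffleDDP.shuffle(images)
    keys = key_encoder(shuffled_images)
    return BatchShuffleDDP.unshuffle(keys, idx_unshuffle)
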
###############
# Model utils #
###############


class MLP(torch.nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_dim,
        num_layers,
        weight_standardization=False,
        normalization=None,
    ):
        super().__init__()
        assert num_layers >= 0, "num_layers must be non-negative"
        if normalization is not None:
            assert callable(normalization), "normalization must be callable"

        if num_layers == 0:
            self.net = torch.nn.Identity()
            return

        if num_layers == 1:
            self.net = torch.nn.Linear(input_dim, output_dim)
            return

        linear_net = ws_resnet.Linear if weight_standardization else torch.nn.Linear

        layers = []
        prev_dim = input_dim
        # hidden layers: (optionally weight-standardized) linear -> optional normalization -> ReLU
        for _ in range(num_layers - 1):
            layers.append(linear_net(prev_dim, hidden_dim))
            if normalization is not None:
                layers.append(normalization())
            layers.append(torch.nn.ReLU())
            prev_dim = hidden_dim
        # final projection is a plain Linear with no normalization or activation
        layers.append(torch.nn.Linear(hidden_dim, output_dim))

        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def get_encoder(name: str, **kwargs) -> torch.nn.Module:
    """
    Gets just the encoder portion of a torchvision model (replaces final layer with identity).

    :param name: (str) name of the model
    :param kwargs: kwargs to send to the model
    :return: the model with its final classification layer replaced by torch.nn.Identity
    """
    if name in ws_resnet.__dict__:
        model_creator = ws_resnet.__dict__.get(name)
    elif name in torchvision.models.__dict__:
        model_creator = torchvision.models.__dict__.get(name)
    else:
        raise AttributeError(f"Unknown architecture {name}")

    assert model_creator is not None, f"no torchvision model named {name}"
    model = model_creator(**kwargs)

    if hasattr(model, "fc"):  # in resnet
        model.fc = torch.nn.Identity()
    elif hasattr(model, "classifier"):  # not in resnet
        model.classifier = torch.nn.Identity()
    else:
        raise NotImplementedError(f"Unknown class {model.__class__}")

    return model


####################
# Evaluation utils #
####################


def calculate_accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def log_softmax_with_factors(logits: torch.Tensor, log_factor: float = 1, neg_factor: float = 1) -> torch.Tensor:
    """
    Log-softmax variant in which the log term is scaled by `log_factor` and the summed
    exponentials of the other (negative) logits are scaled by `neg_factor`. With both
    factors equal to 1 this reduces to the standard log-softmax over the last dimension.
    """
    exp_sum_neg_logits = torch.exp(logits).sum(dim=-1, keepdim=True) - torch.exp(logits)
    softmax_result = logits - log_factor * torch.log(torch.exp(logits) + neg_factor * exp_sum_neg_logits)
    return softmax_result
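
# Illustrative smoke test (a minimal sketch, assuming CPU execution and no
# distributed initialization); it only exercises the model and evaluation
# utilities above and is not part of the training pipeline.
if __name__ == "__main__":
    # A two-layer projection head: Linear -> BatchNorm1d -> ReLU -> Linear.
    projector = MLP(
        input_dim=512,
        output_dim=128,
        hidden_dim=2048,
        num_layers=2,
        normalization=lambda: torch.nn.BatchNorm1d(2048),
    )
    embeddings = projector(torch.randn(8, 512))
    assert embeddings.shape == (8, 128)

    # Top-1 / top-5 accuracy on random logits over 10 classes.
    logits = torch.randn(8, 10)
    targets = torch.randint(0, 10, (8,))
    top1, top5 = calculate_accuracy(logits, targets, topk=(1, 5))
    print(f"top1={top1.item():.1f}%  top5={top5.item():.1f}%")

    # With both factors left at their defaults, the factored log-softmax matches
    # torch.nn.functional.log_softmax.
    assert torch.allclose(
        log_softmax_with_factors(logits),
        torch.nn.functional.log_softmax(logits, dim=-1),
        atol=1e-5,
    )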