Feat: semes gaf time series classifier

Richard Wong 2024-08-26 15:52:22 +09:00
commit 14405c7285
19 changed files with 49756 additions and 0 deletions

0
.gitignore vendored Normal file

1
data/.gitignore vendored Normal file

@@ -0,0 +1 @@
testlog*

2
gaf_data/.gitignore vendored Normal file

@@ -0,0 +1,2 @@
test
train

File diff suppressed because one or more lines are too long

145
self_supervised/.gitignore vendored Normal file

@@ -0,0 +1,145 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# PyCharm
.idea/
# misc
stl10_binary.tar.gz
stl10_binary/
tb_logs/
eigen_values.csv
tb_archive/
tb_archive_23-12-16/
output_data/
lightning_logs/
custom/
checkpoint_semes/
archive

21
self_supervised/LICENSE Normal file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2020 Untitled AI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

290
self_supervised/README.md Normal file

@@ -0,0 +1,290 @@
# PyTorch-Lightning Implementation of Self-Supervised Learning Methods
This is a [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) implementation of the following self-supervised representation learning methods:
- [MoCo](https://arxiv.org/abs/1911.05722)
- [MoCo v2](https://arxiv.org/abs/2003.04297)
- [SimCLR](https://arxiv.org/abs/2002.05709)
- [BYOL](https://arxiv.org/abs/2006.07733)
- [EqCo](https://arxiv.org/abs/2010.01929)
- [VICReg](https://arxiv.org/abs/2105.04906)
Supported datasets: ImageNet, STL-10, and CIFAR-10.
During training, the top1/top5 accuracies (out of 1+K examples) are reported where possible. During validation, an `sklearn` linear classifier is trained on half the test set and validated on the other half. The top1 accuracy is logged as `train_class_acc` / `valid_class_acc`.
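As a rough sketch of what that evaluation amounts to (random arrays stand in here for the frozen encoder features, so the numbers themselves are meaningless):
```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Stand-ins for encoder outputs on the test set: (N, D) features and (N,) integer labels.
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(1000, 512))
labels = rng.integers(0, 10, size=1000)

# Fit a linear classifier on one half and score it on the other half;
# the two scores are what train_class_acc / valid_class_acc report.
half = len(labels) // 2
clf = LogisticRegression(max_iter=1000).fit(embeddings[:half], labels[:half])
train_class_acc = clf.score(embeddings[:half], labels[:half])
valid_class_acc = clf.score(embeddings[half:], labels[half:])
```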
## Installing
Make sure you're in a fresh `conda` or `venv` environment, then run:
```bash
git clone https://github.com/untitled-ai/self_supervised
cd self_supervised
pip install -r requirements.txt
```
## Replicating our BYOL blog post
We found some surprising results about the role of batch norm in BYOL. See the blog post [Understanding self-supervised and contrastive learning with "Bootstrap Your Own Latent" (BYOL)](https://untitled-ai.github.io/understanding-self-supervised-contrastive-learning.html) for more details about our experiments.
You can replicate the results of our blog post by running `python train_blog.py`. The cosine similarity between z and z' is reported as `step_neg_cos` (for negative examples) and `step_pos_cos` (for positive examples). Classification accuracy is reported as `valid_class_acc`.
## Getting started with MoCo v2
To get started with training a ResNet-18 with MoCo v2 on STL-10 (the default configuration):
```python
import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import ModelParams
os.environ["DATA_PATH"] = "~/data"
params = ModelParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```
For convenience, you can instead pass these parameters as keyword args, for example with `model = SelfSupervisedMethod(batch_size=128)`.
## VICReg
To train VICReg rather than MoCo v2, use the following parameters:
```python
import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import VICRegParams
os.environ["DATA_PATH"] = "~/data"
params = VICRegParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```
Note that we have not tuned these parameters for STL-10, and the parameters used for ImageNet are slightly different. See the comment on VICRegParams for details.
## BYOL
To train BYOL rather than MoCo v2, use the following parameters:
```python
import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import BYOLParams
os.environ["DATA_PATH"] = "~/data"
params = BYOLParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```
## SimCLR
To train SimCLR rather than MoCo v2, use the following parameters:
```python
import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import SimCLRParams
os.environ["DATA_PATH"] = "~/data"
params = SimCLRParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```
**Note for multi-GPU setups**: this currently only uses negatives on the same GPU, and will not sync negatives across multiple GPUs.
# Evaluating a trained model
To train a linear classifier on the result:
```python
import pytorch_lightning as pl
from linear_classifier import LinearClassifierMethod
linear_model = LinearClassifierMethod.from_moco_checkpoint("example.ckpt")
trainer = pl.Trainer(gpus=1, max_epochs=100)
trainer.fit(linear_model)
```
# Results on STL-10 and ImageNet
Training a ResNet-18 for 320 epochs on STL-10 achieved 85% linear classification accuracy on the test set (1 fold of 5000). This used all default parameters.
Training a ResNet-50 for 200 epochs on ImageNet achieves 65.6% linear classification accuracy on the test set.
This used 8 gpus with `ddp` and parameters:
```python
hparams = ModelParams(
encoder_arch="resnet50",
shuffle_batch_norm=True,
embedding_dim=2048,
mlp_hidden_dim=2048,
dataset_name="imagenet",
batch_size=32,
lr=0.03,
max_epochs=200,
transform_crop_size=224,
num_data_workers=32,
gather_keys_for_queue=True,
)
```
(the `batch_size` differs from the MoCo documentation because of the way PyTorch-Lightning handles multi-GPU
training in `ddp`; the effective batch size is 256). **Note that for ImageNet we suggest using
`val_percent_check=0.1` when calling `pl.Trainer`** to reduce the time spent fitting the sklearn model.
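For example, following the ImageNet `hparams` above (same `Trainer` style as the earlier examples; newer Lightning releases rename this option, e.g. `limit_val_batches`):
```python
import pytorch_lightning as pl
from moco import SelfSupervisedMethod

model = SelfSupervisedMethod(hparams)
trainer = pl.Trainer(gpus=8, max_epochs=200, val_percent_check=0.1)
trainer.fit(model)
```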
# All training options
All possible `hparams` for SelfSupervisedMethod, along with defaults:
```python
class ModelParams:
# encoder model selection
encoder_arch: str = "resnet18"
shuffle_batch_norm: bool = False
embedding_dim: int = 512 # must match embedding dim of encoder
# data-related parameters
dataset_name: str = "stl10"
batch_size: int = 256
# MoCo parameters
K: int = 65536 # number of examples in queue
dim: int = 128
m: float = 0.996
T: float = 0.2
# eqco parameters
eqco_alpha: int = 65536
use_eqco_margin: bool = False
use_negative_examples_from_batch: bool = False
# optimization parameters
lr: float = 0.5
momentum: float = 0.9
weight_decay: float = 1e-4
max_epochs: int = 320
final_lr_schedule_value: float = 0.0
# transform parameters
transform_s: float = 0.5
transform_apply_blur: bool = True
# Change these to make more like BYOL
use_momentum_schedule: bool = False
loss_type: str = "ce"
use_negative_examples_from_queue: bool = True
use_both_augmentations_as_queries: bool = False
optimizer_name: str = "sgd"
lars_warmup_epochs: int = 1
lars_eta: float = 1e-3
exclude_matching_parameters_from_lars: List[str] = [] # set to [".bias", ".bn"] to match paper
loss_constant_factor: float = 1
# Change these to make more like VICReg
use_vicreg_loss: bool = False
use_lagging_model: bool = True
use_unit_sphere_projection: bool = True
invariance_loss_weight: float = 25.0
variance_loss_weight: float = 25.0
covariance_loss_weight: float = 1.0
variance_loss_epsilon: float = 1e-04
# MLP parameters
projection_mlp_layers: int = 2
prediction_mlp_layers: int = 0
mlp_hidden_dim: int = 512
mlp_normalization: Optional[str] = None
prediction_mlp_normalization: Optional[str] = "same" # if same will use mlp_normalization
use_mlp_weight_standardization: bool = False
# data loader parameters
num_data_workers: int = 4
drop_last_batch: bool = True
pin_data_memory: bool = True
gather_keys_for_queue: bool = False
```
A few options require more explanation:
- **encoder_arch** can be any torchvision model, or can be one of the ResNet models with weight standardization defined in
`ws_resnet.py`.
- **dataset_name** can be `imagenet`, `stl10`, or `cifar10`. `os.environ["DATA_PATH"]` will be used as the path to the data. STL-10 and CIFAR-10 will
be downloaded if they do not already exist.
- **loss_type** can be `ce` (cross entropy) with one of the `use_negative_examples_from_*` options enabled, which corresponds to MoCo, or `ip` (inner product)
with both `use_negative_examples_from_*` options set to `False`, which corresponds to BYOL. It can also be `bce`, which is similar to `ip` but applies the
binary cross entropy loss function to the result, or `vic` for the VICReg loss.
- **optimizer_name**, currently just `sgd` or `lars`.
- **exclude_matching_parameters_from_lars** will remove weight decay and LARS learning rate from matching parameters. Set
to `[".bias", ".bn"]` to match BYOL paper implementation.
- **mlp_normalization** can be None for no normalization, `bn` for batch normalization, `ln` for layer norm, `gn` for group
norm, or `br` for [batch renormalization](https://github.com/ludvb/batchrenorm).
- **prediction_mlp_normalization** defaults to `same` to use the same normalization as above, but can be given any of the
above parameters to use a different normalization.
- **shuffle_batch_norm** and **gather_keys_for_queue** are both related to multi-gpu training. **shuffle_batch_norm**
will shuffle the *key* images among GPUs, which is needed for training if batch norm is used. **gather_keys_for_queue**
will gather key projections (z' in the blog post) from all gpus to add to the MoCo queue.
# Training with custom options
You can train using any settings of the above parameters. This configuration represents the settings from BYOL:
```python
hparams = ModelParams(
prediction_mlp_layers=2,
mlp_normalization="bn",
loss_type="ip",
use_negative_examples_from_queue=False,
use_both_augmentations_as_queries=True,
use_momentum_schedule=True,
optimizer_name="lars",
exclude_matching_parameters_from_lars=[".bias", ".bn"],
loss_constant_factor=2
)
```
Or here is our recommended way to modify VICReg for CIFAR-10:
```python
from model_params import VICRegParams
hparams = VICRegParams(
dataset_name="cifar10",
transform_apply_blur=False,
mlp_hidden_dim=2048,
dim=2048,
batch_size=256,
lr=0.3,
final_lr_schedule_value=0,
weight_decay=1e-4,
lars_warmup_epochs=10,
lars_eta=0.02
)
```

88
self_supervised/batchrenorm.py Normal file

@@ -0,0 +1,88 @@
"""
From https://github.com/ludvb/batchrenorm
@article{batchrenomalization,
author = {Sergey Ioffe},
title = {Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models},
journal = {arXiv preprint arXiv:1702.03275},
year = {2017},
}
"""
import torch
__all__ = ["BatchRenorm1d", "BatchRenorm2d", "BatchRenorm3d"]
class BatchRenorm(torch.nn.Module):
def __init__(
self,
num_features: int,
eps: float = 1e-3,
momentum: float = 0.01,
affine: bool = True,
):
super().__init__()
self.register_buffer("running_mean", torch.zeros(num_features, dtype=torch.float))
self.register_buffer("running_std", torch.ones(num_features, dtype=torch.float))
self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
self.weight = torch.nn.Parameter(torch.ones(num_features, dtype=torch.float))
self.bias = torch.nn.Parameter(torch.zeros(num_features, dtype=torch.float))
self.affine = affine
self.eps = eps
self.step = 0
self.momentum = momentum
def _check_input_dim(self, x: torch.Tensor) -> None:
raise NotImplementedError() # pragma: no cover
@property
def rmax(self) -> torch.Tensor:
return (2 / 35000 * self.num_batches_tracked + 25 / 35).clamp_(1.0, 3.0)
@property
def dmax(self) -> torch.Tensor:
return (5 / 20000 * self.num_batches_tracked - 25 / 20).clamp_(0.0, 5.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
self._check_input_dim(x)
if x.dim() > 2:
x = x.transpose(1, -1)
if self.training:
dims = [i for i in range(x.dim() - 1)]
batch_mean = x.mean(dims)
batch_std = x.std(dims, unbiased=False) + self.eps
r = (batch_std.detach() / self.running_std.view_as(batch_std)).clamp_(
1 / self.rmax.item(), self.rmax.item()
)
d = (
(batch_mean.detach() - self.running_mean.view_as(batch_mean)) / self.running_std.view_as(batch_std)
).clamp_(-self.dmax.item(), self.dmax.item())
x = (x - batch_mean) / batch_std * r + d
self.running_mean += self.momentum * (batch_mean.detach() - self.running_mean)
self.running_std += self.momentum * (batch_std.detach() - self.running_std)
self.num_batches_tracked += 1
else:
x = (x - self.running_mean) / self.running_std
if self.affine:
x = self.weight * x + self.bias
if x.dim() > 2:
x = x.transpose(1, -1)
return x
class BatchRenorm1d(BatchRenorm):
def _check_input_dim(self, x: torch.Tensor) -> None:
if x.dim() not in [2, 3]:
raise ValueError("expected 2D or 3D input (got {x.dim()}D input)")
class BatchRenorm2d(BatchRenorm):
def _check_input_dim(self, x: torch.Tensor) -> None:
if x.dim() != 4:
raise ValueError("expected 4D input (got {x.dim()}D input)")
class BatchRenorm3d(BatchRenorm):
def _check_input_dim(self, x: torch.Tensor) -> None:
if x.dim() != 5:
raise ValueError("expected 5D input (got {x.dim()}D input)")

199
self_supervised/dataload.py Normal file

@@ -0,0 +1,199 @@
# for CustomImageDataset
import os
import glob
import random
import pandas as pd
import torch
from torchvision.io import read_image
from pathlib import Path
from PIL import Image
from pyts.image import GramianAngularField, MarkovTransitionField
# for transforms
from torchvision import transforms
import attr
# for CustomDataloader
from torch.utils.data import DataLoader
import numpy as np
# this recreates the ImageFolder function
class CustomImageDataset(torch.utils.data.Dataset):
def __init__(self, img_dir, transform=None, target_transform=None, img_size=400):
# self.img_labels = pd.read_csv(annotations_file)
self.img_labels = self.path_mapper(img_dir)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
self.gaf_function = GramianAngularField(image_size=img_size,
method="difference",
sample_range=(0,1))
self.mtf_function = MarkovTransitionField(image_size=img_size,
n_bins=5)
def path_mapper(self, img_dir):
df = pd.DataFrame()
img_path = Path(img_dir)
# grab all the label folders
dirs = [f for f in img_path.iterdir() if f.is_dir()]
for label_dir in dirs:
for f in label_dir.iterdir():
new_row = {'file': f.name, 'label': label_dir.name}
# we have the file name, and then the label of the folder it belongs to
df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)
return df
def __len__(self):
return len(self.img_labels)
def normalize(self, data):
data_normalize = ((data - data.min()) / (data.max() - data.min()))
return data_normalize
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir,
self.img_labels.iloc[idx,1],
self.img_labels.iloc[idx,0])
data = np.load(img_path)['data']
data = data.reshape((1,-1))
gaf_image = self.normalize(self.gaf_function.transform(data)[0])
mtf_image = gaf_image # to turn off mtf
# mtf_image = self.normalize(self.mtf_function.transform(data)[0])
image = torch.from_numpy((np.stack([gaf_image, mtf_image], axis=0)).astype(np.float32))
# assert image.dtype == torch.float32, "Tensor is not float32!"
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
@attr.s(auto_attribs=True)
class ImageTransforms:
img_size: int = 400
crop_size: tuple[int,int] = (224,224)
normalize_means: list = [0.0, 0.0]
normalize_stds: list = [1.0, 1.0]
def split_transform(self, img) -> torch.Tensor:
transform = self.single_transform()
return torch.stack((transform(img), transform(img)))
def single_transform(self):
transform_list = [
# transforms.ToTensor(),
transforms.RandomResizedCrop(self.crop_size,
scale=(0.3,0.7),
antialias=False),
# transforms.RandomCrop(size=(int(self.img_size * 0.4), int(self.img_size * 0.4)))
transforms.Normalize(mean=self.normalize_means, std=self.normalize_stds)
]
return transforms.Compose(transform_list)
class CustomDataloader():
def __init__(self,
img_dir: str,
img_size: int = 400,
batch_size: int = 64,
num_workers: int = 4,
persistent_workers: bool = True,
shuffle: bool = True
):
self.batch_size = batch_size
self.num_workers = num_workers
self.persistent_workers = persistent_workers
self.shuffle = shuffle
normalize_means, normalize_stds = normalize_params(img_dir).calculate_mean_std()
self.image_transforms = ImageTransforms(img_size=img_size,
crop_size=(224,224),
normalize_means=normalize_means,
normalize_stds=normalize_stds)
self.dataset = CustomImageDataset(img_dir=img_dir,
img_size=img_size,
transform=self.image_transforms.split_transform)
def get_dataloader(self):
return DataLoader(self.dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
persistent_workers=self.persistent_workers,
shuffle=self.shuffle)
class normalize_params():
def __init__(self, root_dir):
self.root_dir = root_dir
self.img_size = len(np.load(self.get_random_file())['data'])
self.gaf_function = GramianAngularField(image_size=self.img_size,
method="difference",
sample_range=(0,1))
self.mtf_function = MarkovTransitionField(image_size=self.img_size,
n_bins = 5)
def get_random_file(self):
# Search for all files in the directory and subdirectories
file_list = glob.glob(os.path.join(self.root_dir, '**', '*'), recursive=True)
# Filter out directories from the list
file_list = [f for f in file_list if os.path.isfile(f)]
# If there are no files found, return None or raise an exception
if not file_list:
raise FileNotFoundError("No files found in the specified directory")
# Select and return a random file path
return random.choice(file_list)
def normalize(self, data):
data_normalize = ((data - data.min()) / (data.max() - data.min()))
return data_normalize
def load_image(self, filepath):
data = np.load(filepath)['data'].astype(np.float32)
data = data.reshape((1,-1))
gaf_image = self.gaf_function.transform(data)[0]
mtf_image = gaf_image
# mtf_image = self.mtf_function.transform(data)[0]
image = (np.stack([gaf_image, mtf_image], axis=0)).astype(np.float32)
return image
def calculate_mean_std(self):
# Initialize lists to store the sum and squared sum of pixel values
mean_1, mean_2 = 0.0, 0.0
std_1, std_2 = 0.0, 0.0
num_pixels = 0
image_dir = self.root_dir
# Iterate through all images in the directory
for dirpath, dirnames, filenames in os.walk(image_dir):
for filename in filenames:
# Full path of the file
file_path = os.path.join(dirpath, filename)
if os.path.isfile(file_path) and file_path.endswith(('npz')):
img_np = self.load_image(file_path)
# img_np = np.array(img) / 255.0 # Normalize to range [0, 1]
num_pixels += img_np.shape[1] * img_np.shape[2]
mean_1 += np.sum(img_np[0, :, :])
mean_2 += np.sum(img_np[1, :, :])
std_1 += np.sum(img_np[0, :, :] ** 2)
std_2 += np.sum(img_np[1, :, :] ** 2)
# Calculate mean
mean_1 /= num_pixels
mean_2 /= num_pixels
# Calculate standard deviation
std_1 = (std_1 / num_pixels - mean_1 ** 2) ** 0.5
std_2 = (std_2 / num_pixels - mean_2 ** 2) ** 0.5
return [mean_1, mean_2], [std_1, std_2]
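A minimal usage sketch for the loader above (the directory is a placeholder; `path_mapper` expects one sub-folder of `.npz` files per label, each holding a 1-D series under the key `data`):
```python
from dataload import CustomDataloader

# Hypothetical layout: gaf_data/train/<label>/<sample>.npz
loader = CustomDataloader(img_dir="gaf_data/train", img_size=400, batch_size=64).get_dataloader()

# split_transform stacks two augmented crops per sample, so batches are N x 2 x C x H x W,
# where C=2 (the GAF channel plus its duplicate standing in for the MTF channel).
images, labels = next(iter(loader))
print(images.shape)  # e.g. torch.Size([64, 2, 2, 224, 224])
```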

113
self_supervised/lars.py Normal file

@@ -0,0 +1,113 @@
"""
Layer-wise adaptive rate scaling for SGD in PyTorch!
Based on https://github.com/noahgolmant/pytorch-lars
"""
import torch
from torch.optim.optimizer import Optimizer
class LARS(Optimizer):
r"""Implements layer-wise adaptive rate scaling for SGD.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): base learning rate (\gamma_0)
momentum (float, optional): momentum factor (default: 0) ("m")
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
("\beta")
eta (float, optional): LARS coefficient
max_epoch: maximum training epoch to determine polynomial LR decay.
Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
Large Batch Training of Convolutional Networks:
https://arxiv.org/abs/1708.03888
Example:
>>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
"""
def __init__(self, params, lr=1.0, momentum=0.9, weight_decay=0.0005, eta=0.001, max_epoch=200, warmup_epochs=1):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if eta < 0.0:
raise ValueError("Invalid LARS coefficient value: {}".format(eta))
self.epoch = 0
defaults = dict(
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
eta=eta,
max_epoch=max_epoch,
warmup_epochs=warmup_epochs,
use_lars=True,
)
super().__init__(params, defaults)
def step(self, epoch=None, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
epoch: current epoch to calculate polynomial LR decay schedule.
if None, uses self.epoch and increments it.
"""
loss = None
if closure is not None:
loss = closure()
if epoch is None:
epoch = self.epoch
self.epoch += 1
for group in self.param_groups:
weight_decay = group["weight_decay"]
momentum = group["momentum"]
eta = group["eta"]
lr = group["lr"]
warmup_epochs = group["warmup_epochs"]
use_lars = group["use_lars"]
group["lars_lrs"] = []
for p in group["params"]:
if p.grad is None:
continue
param_state = self.state[p]
d_p = p.grad.data
weight_norm = torch.norm(p.data)
grad_norm = torch.norm(d_p)
# Global LR computed on polynomial decay schedule
warmup = min((1 + float(epoch)) / warmup_epochs, 1)
global_lr = lr * warmup
# Update the momentum term
if use_lars:
# Compute local learning rate for this layer
local_lr = eta * weight_norm / (grad_norm + weight_decay * weight_norm)
actual_lr = local_lr * global_lr
group["lars_lrs"].append(actual_lr.item())
else:
actual_lr = global_lr
group["lars_lrs"].append(global_lr)
if "momentum_buffer" not in param_state:
buf = param_state["momentum_buffer"] = torch.zeros_like(p.data)
else:
buf = param_state["momentum_buffer"]
buf.mul_(momentum).add_(d_p + weight_decay * p.data, alpha=actual_lr)
p.data.add_(-buf)
return loss

File diff suppressed because one or more lines are too long


@@ -0,0 +1,80 @@
from attr import evolve
# from lightning.pytorch.callbacks import ModelCheckpoint
# from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch import Trainer
from vicreg import SelfSupervisedMethod
from model_params import VICRegParams
# import utils
# from torch.utils.data import DataLoader
from dataload import CustomDataloader
import torch
import numpy as np
from utils import get_random_file
# speedup
torch.set_float32_matmul_precision('medium')
# data parameters
data_params = list()
data_params.append({'train_path' :"/home/richard/Projects/06_research/semes_gaf/gaf_data/train",
'test_path' :"/home/richard/Projects/06_research/semes_gaf/gaf_data/test",
'checkpoint': 'checkpoint_semes/last.ckpt'})
batch_size = 256
num_epochs=20
selector = 0
def main():
configs = {
"vicreg": evolve(VICRegParams(),
encoder_arch = "ws_resnet18", # resnet18, resnet34, resnet50
max_epochs=num_epochs
),
}
train_path = data_params[selector]['train_path']
random_file = get_random_file(train_path)
img_size = len(np.load(random_file)['data'])
for seed in range(1): # number of repeats
for name, config in configs.items():
method = SelfSupervisedMethod(config)
# logger = TensorBoardLogger("tb_logs", name=f"{name}_{seed}")
# Define the checkpoint callback to save the best model only at the end of training
# checkpoint_callback = ModelCheckpoint(
# filename=f'last-v{seed}',
# dirpath='checkpoint_semes', # Directory to save the checkpoints
# every_n_epochs=10,
# # every_n_train_steps=2,
# save_last=True, # Save the last model
# save_top_k=1,
# save_weights_only=True, # Only save the model weights (not the optimizer, etc.)
# # save_on_train_epoch_end=True # Save only at the end of the training epoch
# )
trainer = Trainer(accelerator="gpu",
devices=[1],
max_epochs=num_epochs)
# strategy=DDPStrategy(find_unused_parameters=False)) # to enable multi-gpu, but not necessary for now
print("--------------------------------------")
print(data_params[selector]['checkpoint'])
train_loader = CustomDataloader(img_dir=data_params[selector]['train_path'],
img_size=img_size,
batch_size=batch_size,
).get_dataloader()
trainer.fit(model=method, train_dataloaders=train_loader)
trainer.save_checkpoint(data_params[selector]['checkpoint'])
if __name__ == "__main__":
main()

63
self_supervised/model_params.py Normal file

@@ -0,0 +1,63 @@
from functools import partial
from typing import List
from typing import Optional
import attr
@attr.s(auto_attribs=True)
class ModelParams:
# encoder model selection
encoder_arch: str = "ws_resnet18" # resnet18, resnet34, resnet50
# note that embedding_dim is 512 * expansion parameter in ws_resnet
embedding_dim: int = 512 * 1 # must match embedding dim of encoder
# projection size
dim: int = 64
# optimization parameters
optimizer_name: str = 'lars'
lr: float = 0.5
momentum: float = 0.9
weight_decay: float = 1e-4
max_epochs: int = 10
final_lr_schedule_value: float = 0.0
lars_warmup_epochs: int = 1
lars_eta: float = 1e-3
exclude_matching_parameters_from_lars: List[str] = [] # set to [".bias", ".bn"] to match paper
# loss parameters
loss_constant_factor: float = 1
invariance_loss_weight: float = 25.0
variance_loss_weight: float = 25.0
covariance_loss_weight: float = 1.0
variance_loss_epsilon: float = 1e-04
kmeans_weight: float = 1e-03
# MLP parameters
projection_mlp_layers: int = 2
prediction_mlp_layers: int = 0 # by default prediction mlp is identity
mlp_hidden_dim: int = 512
mlp_normalization: Optional[str] = None
prediction_mlp_normalization: Optional[str] = "same" # if same will use mlp_normalization
use_mlp_weight_standardization: bool = False
# Differences between these parameters and those used in the paper (on image net):
# max_epochs=1000,
# lr=1.6,
# batch_size=2048,
# weight_decay=1e-6,
# mlp_hidden_dim=8192,
# dim=8192,
VICRegParams = partial(
ModelParams,
exclude_matching_parameters_from_lars=[".bias", ".bn"],
projection_mlp_layers=3,
final_lr_schedule_value=0.002,
mlp_normalization="bn",
lars_warmup_epochs=2,
kmeans_weight=1e-03,
)

132
self_supervised/requirements.txt Normal file

@@ -0,0 +1,132 @@
aiohappyeyeballs==2.3.5
aiohttp==3.10.3
aiosignal==1.3.1
anyio==4.4.0
arrow==1.3.0
asttokens==2.4.1
attrs==24.2.0
beautifulsoup4==4.12.3
blessed==1.20.0
boto3==1.34.159
botocore==1.34.159
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
comm==0.2.2
contourpy==1.2.1
croniter==1.3.15
cycler==0.12.1
dateutils==0.6.12
debugpy==1.8.5
decorator==5.1.1
deepdiff==7.0.1
editor==1.6.6
executing==2.0.1
fastapi==0.88.0
fastjsonschema==2.20.0
filelock==3.15.4
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2023.12.2
h11==0.14.0
idna==3.7
inquirer==3.4.0
ipykernel==6.29.5
ipython==8.26.0
itsdangerous==2.2.0
jedi==0.19.1
Jinja2==3.1.4
jmespath==1.0.1
joblib==1.4.2
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter_client==8.6.2
jupyter_core==5.7.2
kiwisolver==1.4.5
lightning==1.9.5
lightning-cloud==0.5.70
lightning-utilities==0.11.6
llvmlite==0.43.0
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.2
matplotlib-inline==0.1.7
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.3
numba==0.60.0
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.20
nvidia-nvtx-cu12==12.1.105
ordered-set==4.1.0
packaging==24.1
pandas==2.2.2
parso==0.8.4
pexpect==4.9.0
pillow==10.4.0
platformdirs==4.2.2
plotly==5.23.0
prompt_toolkit==3.0.47
protobuf==5.27.3
psutil==6.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pydantic==1.10.17
Pygments==2.18.0
PyJWT==2.9.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
python-multipart==0.0.9
pytorch-lightning==1.9.5
pyts==0.13.0
pytz==2024.1
PyYAML==6.0.2
pyzmq==26.1.0
readchar==4.2.0
referencing==0.35.1
requests==2.32.3
rich==13.7.1
rpds-py==0.20.0
runs==1.2.2
s3transfer==0.10.2
scikit-learn==1.5.1
scipy==1.14.0
six==1.16.0
sniffio==1.3.1
soupsieve==2.5
stack-data==0.6.3
starlette==0.22.0
starsessions==1.3.0
sympy==1.13.2
tenacity==9.0.0
threadpoolctl==3.5.0
torch==2.4.0
torchmetrics==1.4.1
torchvision==0.19.0
tornado==6.4.1
tqdm==4.66.5
traitlets==5.14.3
triton==3.0.0
types-python-dateutil==2.9.0.20240316
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.2
uvicorn==0.30.5
wcwidth==0.2.13
websocket-client==1.8.0
websockets==11.0.3
xmod==1.8.1
yarl==1.9.4


@@ -0,0 +1,388 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torchvision import transforms\n",
"from PIL import Image\n",
"import glob, os\n",
"import numpy as np\n",
"\n",
"import utils\n",
"from moco import SelfSupervisedMethod\n",
"# from model_params import EigRegParams\n",
"from model_params import VICRegParams\n",
"\n",
"from attr import evolve\n",
"\n",
"import pandas as pd\n",
"from sklearn.decomposition import PCA\n",
"\n",
"from sklearn.cluster import KMeans\n",
"from sklearn.metrics import rand_score, normalized_mutual_info_score\n",
"\n",
"# data parameters\n",
"data_params = list()\n",
"# 0 Beef\n",
"data_params.append({'resize': 471, \n",
" 'batch_size':30,\n",
" 'num_clusters': 5,\n",
" 'train_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/Beef/train\",\n",
" 'test_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/Beef/test\",\n",
" 'checkpoint': 'checkpoint_beef'})\n",
"# 1 dist.phal.outl.agegroup\n",
"data_params.append({'resize': 81, \n",
" 'batch_size':139,\n",
" 'num_clusters': 3,\n",
" 'train_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/DistalPhalanxOutlineAgeGroup/train\",\n",
" 'test_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/DistalPhalanxOutlineAgeGroup/test\",\n",
" 'checkpoint': 'checkpoint_dist_agegroup'})\n",
"# 2 ECG200\n",
"data_params.append({'resize': 97, \n",
" 'batch_size':100,\n",
" 'num_clusters': 2,\n",
" 'train_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ECG200/train\",\n",
" 'test_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ECG200/test\",\n",
" 'checkpoint': 'checkpoint_ecg200'})\n",
"# 3 ECGFiveDays\n",
"data_params.append({'resize': 137, \n",
" 'batch_size':23,\n",
" 'num_clusters': 2,\n",
" 'train_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ECGFiveDays/train\",\n",
" 'test_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ECGFiveDays/test\",\n",
" 'checkpoint': 'checkpoint_ecg5days'})\n",
"# 4 Meat\n",
"data_params.append({'resize': 449, \n",
" 'batch_size':60,\n",
" 'num_clusters': 3,\n",
" 'train_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/Meat/train\",\n",
" 'test_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/Meat/test\",\n",
" 'checkpoint': 'checkpoint_meat'})\n",
"# 5 mote strain\n",
"data_params.append({'resize': 85, \n",
" 'batch_size': 20,\n",
" 'num_clusters': 2,\n",
" 'train_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/MoteStrain/train\",\n",
" 'test_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/MoteStrain/test\",\n",
" 'checkpoint': 'checkpoint_motestrain'})\n",
"# 6 osuleaf\n",
"data_params.append({'resize': 428, \n",
" 'batch_size': 64, # 200\n",
" 'num_clusters': 6,\n",
" 'train_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/OSULeaf/train\",\n",
" 'test_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/OSULeaf/test\",\n",
" 'checkpoint': 'checkpoint_osuleaf'})\n",
"# 7 plane\n",
"data_params.append({'resize': 145, \n",
" 'batch_size': 105,\n",
" 'num_clusters': 7,\n",
" 'train_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/Plane/train\",\n",
" 'test_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/Plane/test\",\n",
" 'checkpoint': 'checkpoint_plane'})\n",
"# 8 proximal_agegroup\n",
"data_params.append({'resize': 81, \n",
" 'batch_size': 205,\n",
" 'num_clusters': 3,\n",
" 'train_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ProximalPhalanxOutlineAgeGroup/train\",\n",
" 'test_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ProximalPhalanxOutlineAgeGroup/test\",\n",
" 'checkpoint': 'checkpoint_prox_agegroup'})\n",
"# 9 proximal_tw\n",
"data_params.append({'resize': 81, \n",
" 'batch_size': 100, # 400\n",
" 'num_clusters': 6,\n",
" 'train_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ProximalPhalanxTW/train\",\n",
" 'test_path' :\"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ProximalPhalanxTW/test\",\n",
" 'checkpoint': 'checkpoint_prox_tw'})\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def calculate_mean_std(image_dir):\n",
" # Initialize lists to store the sum and squared sum of pixel values\n",
" mean_1, mean_2 = 0.0, 0.0\n",
" std_1, std_2 = 0.0, 0.0\n",
" num_pixels = 0\n",
"\n",
" # Iterate through all images in the directory\n",
" for dirpath, dirnames, filenames in os.walk(image_dir):\n",
" for filename in filenames:\n",
" # Full path of the file\n",
" file_path = os.path.join(dirpath, filename)\n",
"\n",
" # for img_name in os.listdir(image_dir):\n",
" # img_path = os.path.join(image_dir, img_name)\n",
" if os.path.isfile(file_path) and file_path.endswith(('png', 'jpg', 'jpeg', 'bmp', 'tiff')):\n",
" with Image.open(file_path) as img:\n",
" # img = img.convert('RGB') # Ensure image is in RGB format\n",
" img_np = np.array(img) / 255.0 # Normalize to range [0, 1]\n",
" \n",
" num_pixels += img_np.shape[0] * img_np.shape[1]\n",
" \n",
" mean_1 += np.sum(img_np[:, :, 0])\n",
" mean_2 += np.sum(img_np[:, :, 1])\n",
" \n",
" std_1 += np.sum(img_np[:, :, 0] ** 2)\n",
" std_2 += np.sum(img_np[:, :, 1] ** 2)\n",
"\n",
" # Calculate mean\n",
" mean_1 /= num_pixels\n",
" mean_2 /= num_pixels\n",
"\n",
" # Calculate standard deviation\n",
" std_1 = (std_1 / num_pixels - mean_1 ** 2) ** 0.5\n",
" std_2 = (std_2 / num_pixels - mean_2 ** 2) ** 0.5\n",
"\n",
" return [mean_1, mean_2], [std_1, std_2]\n",
"\n",
"def list_directories(path):\n",
" entries = os.listdir(path)\n",
" directories = [ entry for entry in entries if os.path.isdir(os.path.join(path, entry))]\n",
" return directories\n",
"\n",
"\n",
"def inference(method, classes, transform, path):\n",
" batch_size = 32\n",
" image_tensors = []\n",
" result = []\n",
" labels = []\n",
"\n",
" # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
" # device = torch.device('cpu')\n",
" device = torch.device('cuda:2')\n",
" method.model.to(device)\n",
" method.projection_model.to(device)\n",
"\n",
"\n",
" for key in classes:\n",
" image_dir = path + '/' + key \n",
" for img_name in os.listdir(image_dir):\n",
" image_path = os.path.join(image_dir, img_name)\n",
" image = Image.open(image_path)\n",
" # image = image.convert('RGB')\n",
"\n",
" # Preprocess the image\n",
" input_tensor = transform(image).unsqueeze(0) # Add batch dimension\n",
" image_tensors.append(input_tensor)\n",
"\n",
" # perform batching\n",
" if len(image_tensors) == batch_size:\n",
" batch_tensor = torch.cat(image_tensors).to(device)\n",
" # Use the pre-trained model to extract features\n",
" with torch.no_grad():\n",
" emb = method.model(batch_tensor)\n",
" projection = method.projection_model(emb)\n",
" # projection = method.model(input_tensor)\n",
" result.extend(projection.cpu())\n",
" # reset back to 0\n",
" image_tensors = []\n",
"\n",
"\n",
" labels.append(int(key))\n",
"\n",
" if len(image_tensors) > 0:\n",
" batch_tensor = torch.cat(image_tensors).to(device)\n",
" # Use the pre-trained model to extract features\n",
" with torch.no_grad():\n",
" emb = method.model(batch_tensor)\n",
" projection = method.projection_model(emb)\n",
" # projection = method.model(input_tensor)\n",
" result.extend(projection.cpu())\n",
"\n",
" return result, labels\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Number of runs\n",
"num_runs = 10\n",
"# Number of results/metrics per run\n",
"num_results_per_run = 10\n",
"# Create a 2D NumPy array to store the results\n",
"ri_results = np.zeros((num_runs, num_results_per_run))\n",
"nmi_results = np.zeros((num_runs, num_results_per_run))\n",
"\n",
"\n",
"start = 0\n",
"end = 9\n",
"for run_num in range(num_runs):\n",
" for selector in range(start,end+1):\n",
"\n",
" config = evolve(VICRegParams(), \n",
" encoder_arch = \"ws_resnet18\", # resnet18, resnet34, resnet50\n",
" dataset_name=\"custom\", \n",
" train_path=data_params[selector]['train_path'],\n",
" test_path=data_params[selector]['test_path'],\n",
" kmeans_weight=0, # it doens't matter since this is not used in the model\n",
" num_clusters=data_params[selector]['num_clusters'])\n",
" method = SelfSupervisedMethod(config)\n",
" # Initialize your ResNet model\n",
" checkpoint = data_params[selector]['checkpoint']\n",
" # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=49-step=50.ckpt'\n",
" # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=99-step=100.ckpt'\n",
" # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=149-step=150.ckpt'\n",
" # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=199-step=200.ckpt'\n",
" # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=299-step=300.ckpt'\n",
" # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=399-step=400.ckpt'\n",
" # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=499-step=500.ckpt'\n",
" path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/last-v{run_num}.ckpt'\n",
" method = method.load_from_checkpoint(path)\n",
" # Set the model to evaluation mode\n",
" method.eval()\n",
"\n",
"\n",
"\n",
" # Define transform\n",
" path = data_params[selector]['test_path']\n",
" normalize_means, normalize_stds = calculate_mean_std(path)\n",
" # image_size = data_params[selector]['resize']\n",
" # crop_size = int(0.4 * image_size)\n",
" transform = transforms.Compose([\n",
" # transforms.Resize((image_size, image_size)),\n",
"\n",
" # transforms.CenterCrop(size=(crop_size, crop_size)),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=normalize_means, std=normalize_stds),\n",
" ])\n",
"\n",
"\n",
"\n",
" # get all the classes\n",
" classes = list_directories(path)\n",
"\n",
" result, labels = inference(method, classes, transform, path)\n",
"\n",
" data = np.array(result)\n",
" # pca = PCA(n_components=2)\n",
" # reduced_data = pca.fit_transform(data)\n",
"\n",
" # Choose the number of clusters, say 3\n",
" kmeans = KMeans(n_clusters=data_params[selector]['num_clusters'], random_state=42, n_init=10)\n",
" clusters = kmeans.fit_predict(data)\n",
"\n",
"\n",
" # print(data_params[selector]['checkpoint'])\n",
" # print(\"Rand Index: \", rand_score(clusters, labels))\n",
" # print(\"NMI: \", normalized_mutual_info_score(clusters, labels))\n",
" rand_index = rand_score(clusters, labels)\n",
" nmi = normalized_mutual_info_score(clusters, labels)\n",
"\n",
" ri_results[run_num,selector] = rand_index\n",
" nmi_results[run_num, selector] = nmi"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"RI mean: 0.6374712643678162\n",
"RI std: 0.03729597247277404\n",
"NMI mean: 0.2602497793608286\n",
"NMI std: 0.02985710765468422\n",
"1\n",
"RI mean: 0.6458659159628819\n",
"RI std: 0.01396857176725652\n",
"NMI mean: 0.30136574266475413\n",
"NMI std: 0.03170806441677236\n",
"2\n",
"RI mean: 0.6408888888888888\n",
"RI std: 0.021672013185081274\n",
"NMI mean: 0.20592658284184645\n",
"NMI std: 0.048248741710938625\n",
"3\n",
"RI mean: 0.5765183804661967\n",
"RI std: 0.03417303145285537\n",
"NMI mean: 0.11716054993167128\n",
"NMI std: 0.05173307476499848\n",
"4\n",
"RI mean: 0.7636723163841806\n",
"RI std: 0.0838674066635877\n",
"NMI mean: 0.6087294263576666\n",
"NMI std: 0.10910741463199608\n",
"5\n",
"RI mean: 0.6088675385570139\n",
"RI std: 0.041236238284731705\n",
"NMI mean: 0.17859344790373669\n",
"NMI std: 0.08743358257833596\n",
"6\n",
"RI mean: 0.7343060937553582\n",
"RI std: 0.020174409290336055\n",
"NMI mean: 0.22234048756150684\n",
"NMI std: 0.029705953611425088\n",
"7\n",
"RI mean: 0.9384065934065934\n",
"RI std: 0.019608200939834567\n",
"NMI mean: 0.8406540399495203\n",
"NMI std: 0.03016596675386891\n",
"8\n",
"RI mean: 0.7505643232902918\n",
"RI std: 0.022756611198669806\n",
"NMI mean: 0.48619057929963566\n",
"NMI std: 0.01338604938860034\n",
"9\n",
"RI mean: 0.832415112386418\n",
"RI std: 0.019248159640681852\n",
"NMI mean: 0.5497811436968876\n",
"NMI std: 0.023003346601586715\n"
]
}
],
"source": [
"for data_select in range(10):\n",
"\tprint(data_select)\n",
"\tprint(\"RI mean: \", np.mean(ri_results[:,data_select]))\n",
"\tprint(\"RI std: \", np.std(ri_results[:,data_select]))\n",
"\tprint(\"NMI mean: \", np.mean(nmi_results[:,data_select]))\n",
"\tprint(\"NMI std: \", np.std(nmi_results[:,data_select]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -0,0 +1,93 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"4\n",
"NVIDIA GeForce RTX 4070\n"
]
}
],
"source": [
"import torch\n",
"\n",
"print(torch.cuda.is_available()) # Should return True if CUDA is available\n",
"print(torch.cuda.device_count()) # Should return the number of GPUs available\n",
"print(torch.cuda.get_device_name(0)) # Should return the name of the first GPU, if available\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_2643050/1394991573.py:4: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" checkpoint = torch.load(checkpoint_path)\n"
]
}
],
"source": [
"import torch\n",
"\n",
"checkpoint_path = 'checkpoint_semes/last.ckpt'\n",
"checkpoint = torch.load(checkpoint_path)\n",
"print(checkpoint.keys()) # Check what is saved in the checkpoint"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"checkpoint['hparams_name']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

199
self_supervised/utils.py Normal file

@@ -0,0 +1,199 @@
# import os
# import random
# from typing import Any
# from typing import Callable
# from typing import Optional
import pandas as pd
import glob
import os
import random
# import attr
import torch
import torchvision
import ws_resnet
# from model_params import ModelParams
################
# main train utils
def get_random_file(root_dir):
# Search for all files in the directory and subdirectories
file_list = glob.glob(os.path.join(root_dir, '**', '*'), recursive=True)
# Filter out directories from the list
file_list = [f for f in file_list if os.path.isfile(f)]
# If there are no files found, return None or raise an exception
if not file_list:
raise FileNotFoundError("No files found in the specified directory")
# Select and return a random file path
return random.choice(file_list)
#####################
# Parallelism utils #
#####################
@torch.no_grad()
def concat_all_gather(tensor):
"""
Performs all_gather operation on the provided tensors.
*** Warning ***: torch.distributed.all_gather has no gradient.
"""
tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
output = torch.cat(tensors_gather, dim=0)
return output
class BatchShuffleDDP:
@staticmethod
@torch.no_grad()
def shuffle(x):
"""
Batch shuffle, for making use of BatchNorm.
*** Only support DistributedDataParallel (DDP) model. ***
"""
# gather from all gpus
batch_size_this = x.shape[0]
x_gather = concat_all_gather(x)
batch_size_all = x_gather.shape[0]
num_gpus = batch_size_all // batch_size_this
# random shuffle index
idx_shuffle = torch.randperm(batch_size_all).to(x.device)
# broadcast to all gpus
torch.distributed.broadcast(idx_shuffle, src=0)
# index for restoring
idx_unshuffle = torch.argsort(idx_shuffle)
# shuffled index for this gpu
gpu_idx = torch.distributed.get_rank()
idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
return x_gather[idx_this], idx_unshuffle
@staticmethod
@torch.no_grad()
def unshuffle(x, idx_unshuffle):
"""
Undo batch shuffle.
*** Only support DistributedDataParallel (DDP) model. ***
"""
# gather from all gpus
batch_size_this = x.shape[0]
x_gather = concat_all_gather(x)
batch_size_all = x_gather.shape[0]
num_gpus = batch_size_all // batch_size_this
# restored index for this gpu
gpu_idx = torch.distributed.get_rank()
idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
return x_gather[idx_this]
###############
# Model utils #
###############
class MLP(torch.nn.Module):
def __init__(
self, input_dim, output_dim, hidden_dim, num_layers, weight_standardization=False, normalization=None
):
super().__init__()
assert num_layers >= 0, "negative layers?!?"
if normalization is not None:
assert callable(normalization), "normalization must be callable"
if num_layers == 0:
self.net = torch.nn.Identity()
return
if num_layers == 1:
self.net = torch.nn.Linear(input_dim, output_dim)
return
linear_net = ws_resnet.Linear if weight_standardization else torch.nn.Linear
layers = []
prev_dim = input_dim
for _ in range(num_layers - 1):
layers.append(linear_net(prev_dim, hidden_dim))
if normalization is not None:
layers.append(normalization())
layers.append(torch.nn.ReLU())
prev_dim = hidden_dim
layers.append(torch.nn.Linear(hidden_dim, output_dim))
self.net = torch.nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
def get_encoder(name: str, **kwargs) -> torch.nn.Module:
"""
Gets just the encoder portion of a torchvision model (replaces final layer with identity)
:param name: (str) name of the model
:param kwargs: kwargs to send to the model
:return:
"""
if name in ws_resnet.__dict__:
model_creator = ws_resnet.__dict__.get(name)
elif name in torchvision.models.__dict__:
model_creator = torchvision.models.__dict__.get(name)
else:
raise AttributeError(f"Unknown architecture {name}")
assert model_creator is not None, f"no torchvision model named {name}"
model = model_creator(**kwargs)
if hasattr(model, "fc"): # in resnet
model.fc = torch.nn.Identity()
elif hasattr(model, "classifier"): # not in resnet
model.classifier = torch.nn.Identity()
else:
raise NotImplementedError(f"Unknown class {model.__class__}")
return model
####################
# Evaluation utils #
####################
def calculate_accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def log_softmax_with_factors(logits: torch.Tensor, log_factor: float = 1, neg_factor: float = 1) -> torch.Tensor:
exp_sum_neg_logits = torch.exp(logits).sum(dim=-1, keepdim=True) - torch.exp(logits)
softmax_result = logits - log_factor * torch.log(torch.exp(logits) + neg_factor * exp_sum_neg_logits)
return softmax_result
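A small self-contained sketch of the `MLP` and `calculate_accuracy` helpers above (random tensors stand in for real projections and logits):
```python
import torch
from utils import MLP, calculate_accuracy

# Projection head: 512-d embeddings -> 64-d projections through one hidden layer.
projector = MLP(input_dim=512, output_dim=64, hidden_dim=512, num_layers=2)
z = projector(torch.randn(8, 512))
print(z.shape)  # torch.Size([8, 64])

# Top-1 / top-5 accuracy (in percent) of logits against integer targets.
logits, targets = torch.randn(8, 10), torch.randint(0, 10, (8,))
top1, top5 = calculate_accuracy(logits, targets, topk=(1, 5))
print(top1.item(), top5.item())
```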

285
self_supervised/vicreg.py Normal file

@@ -0,0 +1,285 @@
import copy
import math
import warnings
from functools import partial
from typing import Optional
from typing import Union
import attr
import torch
import torch.nn.functional as F
# from pytorch_lightning.utilities import AttributeDict
from lightning.fabric.utilities.data import AttributeDict
from torch.utils.data import DataLoader
import lightning as L
import utils
from batchrenorm import BatchRenorm1d
from lars import LARS
from model_params import ModelParams
# from sklearn.linear_model import LogisticRegression
# from sklearn.cluster import KMeans
# from sklearn.metrics import rand_score, normalized_mutual_info_score
import pandas as pd
def get_mlp_normalization(hparams: ModelParams, prediction=False):
normalization_str = hparams.mlp_normalization
if prediction and hparams.prediction_mlp_normalization != "same":
normalization_str = hparams.prediction_mlp_normalization
if normalization_str is None:
return None
elif normalization_str == "bn":
return partial(torch.nn.BatchNorm1d, num_features=hparams.mlp_hidden_dim)
elif normalization_str == "br":
return partial(BatchRenorm1d, num_features=hparams.mlp_hidden_dim)
elif normalization_str == "ln":
return partial(torch.nn.LayerNorm, normalized_shape=[hparams.mlp_hidden_dim])
elif normalization_str == "gn":
return partial(torch.nn.GroupNorm, num_channels=hparams.mlp_hidden_dim, num_groups=32)
else:
raise NotImplementedError(f"mlp normalization {normalization_str} not implemented")
# class KMeansLoss:
#
# def __init__(self, num_clusters, embedding_dim, device):
# self.num_clusters = num_clusters
# self.centroids = torch.randn(num_clusters, embedding_dim, device=device)
# self.device=device
#
# def update_centroids(self, embeddings, assignments):
# for i in range(self.num_clusters):
# assigned_embeddings = embeddings[assignments == i]
# if len(assigned_embeddings) > 1: # good if more than singleton
# # implement ewma update for centroids
# weight1 = torch.tensor(0.3, device='cpu')
# weight2 = torch.tensor(0.7, device='cpu') # give more weight to new embeddings
# self.centroids[i] = self.centroids[i] * weight1 + assigned_embeddings.mean(dim=0).cpu() * weight2
#
# def set_centroids(self, embeddings, assignments):
# for i in range(self.num_clusters):
# assigned_embeddings = embeddings[assignments == i]
# if len(assigned_embeddings) > 1: # good if more than singleton
# # implement ewma update for centroids
# self.centroids[i] = assigned_embeddings.mean(dim=0).cpu()
#
#
# def compute_loss(self, embeddings):
# # move centroids to same device as embeddings
# centroids = self.centroids.to(embeddings.device)
# distances = torch.cdist(embeddings, centroids, p=self.num_clusters)
# min_distances, assignments = distances.min(dim=1)
# loss = min_distances.pow(2).sum()
# return loss, assignments
#
# def forward(self, embeddings, step_count):
# loss, assignments = self.compute_loss(embeddings)
# detached_embeddings = embeddings.detach()
# detached_assignments = assignments.detach()
#
# if (step_count < 5):
# self.set_centroids(detached_embeddings, detached_assignments)
# if (step_count % 2 == 0):
# self.update_centroids(detached_embeddings, detached_assignments)
# return loss
class SelfSupervisedMethod(L.LightningModule):
model: torch.nn.Module
hparams: AttributeDict
embedding_dim: Optional[int]
def __init__(
self,
hparams: Union[ModelParams, dict, None] = None,
**kwargs,
):
super().__init__()
# disable automatic optimization for lightning2
self.automatic_optimization = False
self.optimizer = None
self.lr_scheduler = None
# load from arguments
if hparams is None:
hparams = self.params(**kwargs)
# if it is already an attributedict, then use it directly
if hparams is not None:
self.save_hyperparameters(attr.asdict(hparams))
# Create encoder model
self.model = utils.get_encoder(hparams.encoder_arch)
# projection_mlp_layers = 3
self.projection_model = utils.MLP(
hparams.embedding_dim,
hparams.dim,
hparams.mlp_hidden_dim,
num_layers=hparams.projection_mlp_layers,
normalization=get_mlp_normalization(hparams),
weight_standardization=hparams.use_mlp_weight_standardization,
)
# by default it is identity
# prediction_mlp_layers = 0
self.prediction_model = utils.MLP(
hparams.dim,
hparams.dim,
hparams.mlp_hidden_dim,
num_layers=hparams.prediction_mlp_layers,
normalization=get_mlp_normalization(hparams, prediction=True),
weight_standardization=hparams.use_mlp_weight_standardization,
)
# kmeans loss
# self.kmeans_loss = KMeansLoss(num_clusters=hparams.num_clusters, embedding_dim=hparams.dim, device=self.device)
def _get_embeddings(self, x):
"""
Input:
im_q: a batch of query images
im_k: a batch of key images
Output:
embeddings, targets
"""
bsz, nd, nc, nh, nw = x.shape
assert nd == 2, "second dimension should be the split image -- dims should be N2CHW"
im_q = x[:, 0].contiguous()
im_k = x[:, 1].contiguous()
# compute query features
emb_q = self.model(im_q)
q_projection = self.projection_model(emb_q)
# by default VICReg uses an identity prediction model
q = self.prediction_model(q_projection) # queries: NxC
emb_k = self.model(im_k)
k_projection = self.projection_model(emb_k)
k = self.prediction_model(k_projection) # keys: NxC
# q and k are the prediction-head outputs (identity head by default, so effectively the projection embeddings)
return emb_q, q, k
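# Hedged shape sketch (illustrative only; `augment` is a hypothetical transform, not defined here):
# the dataloader is expected to yield the two augmented views stacked along dim=1, e.g.
#
#   view_a, view_b = augment(img), augment(img)        # each C x H x W
#   x = torch.stack([view_a, view_b], dim=1)           # N x 2 x C x H x W after batching
#   emb_q, q, k = model._get_embeddings(x)             # q, k: N x hparams.dim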
def _get_vicreg_loss(self, z_a, z_b, batch_idx):
assert z_a.shape == z_b.shape and len(z_a.shape) == 2
# invariance loss
loss_inv = F.mse_loss(z_a, z_b)
# variance loss
std_z_a = torch.sqrt(z_a.var(dim=0) + self.hparams.variance_loss_epsilon)
std_z_b = torch.sqrt(z_b.var(dim=0) + self.hparams.variance_loss_epsilon)
loss_v_a = torch.mean(F.relu(1 - std_z_a)) # differentiable max
loss_v_b = torch.mean(F.relu(1 - std_z_b))
loss_var = loss_v_a + loss_v_b
# covariance loss
N, D = z_a.shape
z_a = z_a - z_a.mean(dim=0)
z_b = z_b - z_b.mean(dim=0)
cov_z_a = ((z_a.T @ z_a) / (N - 1)).square() # DxD
cov_z_b = ((z_b.T @ z_b) / (N - 1)).square() # DxD
loss_c_a = (cov_z_a.sum() - cov_z_a.diagonal().sum()) / D
loss_c_b = (cov_z_b.sum() - cov_z_b.diagonal().sum()) / D
loss_cov = loss_c_a + loss_c_b
weighted_inv = loss_inv * self.hparams.invariance_loss_weight
weighted_var = loss_var * self.hparams.variance_loss_weight
weighted_cov = loss_cov * self.hparams.covariance_loss_weight
loss = weighted_inv + weighted_var + weighted_cov
return loss
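# Hedged sanity sketch (illustrative only): the covariance term above is the sum of the squared
# off-diagonal entries of the D x D covariance matrix, divided by D, e.g.
#
#   z = torch.randn(8, 4)
#   z = z - z.mean(dim=0)
#   cov_sq = ((z.T @ z) / (z.shape[0] - 1)).square()
#   off_diag = (cov_sq.sum() - cov_sq.diagonal().sum()) / z.shape[1]   # same form as loss_c_a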
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, class_labels = batch # batch is (images, labels); the labels are unused for self-supervised training
emb_q, q, k = self._get_embeddings(x)
vicreg_loss = self._get_vicreg_loss(q, k, batch_idx)
total_loss = vicreg_loss.mean() * self.hparams.loss_constant_factor
# here lies the manual optimizing code
self.optimizer.zero_grad()
self.manual_backward(total_loss)
self.optimizer.step()
self.lr_scheduler.step()
log_data = {
"step_train_loss": total_loss,
}
self.log_dict(log_data, sync_dist=True, prog_bar=True)
return {"loss": total_loss}
def configure_optimizers(self):
# exclude bias and batch norm from LARS and weight decay
regular_parameters = []
regular_parameter_names = []
excluded_parameters = []
excluded_parameter_names = []
for name, parameter in self.named_parameters():
if parameter.requires_grad is False:
continue
# for vicreg
# exclude_matching_parameters_from_lars=[".bias", ".bn"],
if any(x in name for x in self.hparams.exclude_matching_parameters_from_lars):
excluded_parameters.append(parameter)
excluded_parameter_names.append(name)
else:
regular_parameters.append(parameter)
regular_parameter_names.append(name)
param_groups = [
{
"params": regular_parameters,
"names": regular_parameter_names,
"use_lars": True
},
{
"params": excluded_parameters,
"names": excluded_parameter_names,
"use_lars": False,
"weight_decay": 0,
},
]
if self.hparams.optimizer_name == "sgd":
optimizer = torch.optim.SGD
elif self.hparams.optimizer_name == "lars":
optimizer = partial(LARS, warmup_epochs=self.hparams.lars_warmup_epochs, eta=self.hparams.lars_eta)
else:
raise NotImplementedError(f"No such optimizer {self.hparams.optimizer_name}")
self.optimizer = optimizer(
param_groups,
lr=self.hparams.lr,
momentum=self.hparams.momentum,
weight_decay=self.hparams.weight_decay,
)
self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.optimizer,
self.hparams.max_epochs,
eta_min=self.hparams.final_lr_schedule_value,
)
return None # optimizer and scheduler are created here but stepped manually in training_step
@classmethod
def params(cls, **kwargs) -> ModelParams:
return ModelParams(**kwargs)
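# Hedged usage sketch (assumes a LightningDataModule named `dm` is defined elsewhere in the repo;
# the ModelParams values shown are illustrative, not the project's actual configuration):
#
#   model = SelfSupervisedMethod(ModelParams(encoder_arch="ws_resnet18", optimizer_name="lars"))
#   trainer = L.Trainer(max_epochs=model.hparams.max_epochs)
#   trainer.fit(model, dm)
#
# configure_optimizers() returns None on purpose: with automatic_optimization = False, the
# optimizer and cosine scheduler stored on `self` are stepped manually inside training_step.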

View File

@ -0,0 +1,244 @@
"""
From https://github.com/joe-siyuan-qiao/pytorch-classification
@article{weightstandardization,
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille},
title = {Weight Standardization},
journal = {arXiv preprint arXiv:1903.10520},
year = {2019},
}
"""
import torch.nn as nn
import torch
from torch.nn import functional as F
class Conv2d(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
def forward(self, x):
# return super(Conv2d, self).forward(x)
weight = self.weight
weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
weight = weight - weight_mean
std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
weight = weight / std.expand_as(weight)
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class Linear(nn.Linear):
def forward(self, x):
weight = self.weight
weight_mean = weight.mean(dim=1, keepdim=True)
weight = weight - weight_mean
std = weight.std(dim=1, keepdim=True) + 1e-5
weight = weight / std.expand_as(weight)
return F.linear(x, weight, self.bias)
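# Hedged sanity sketch (illustrative only): after standardization, the effective weight of every
# output unit has roughly zero mean and unit standard deviation, mirroring Linear.forward above:
#
#   lin = Linear(16, 8)
#   w = lin.weight - lin.weight.mean(dim=1, keepdim=True)
#   w = w / (w.std(dim=1, keepdim=True) + 1e-5)
#   # w.mean(dim=1) is ~0 and w.std(dim=1) is ~1 for each of the 8 rows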
def BatchNorm2d(num_features):
# keeps the BatchNorm2d name for drop-in use, but actually builds GroupNorm with 32 groups,
# since weight standardization is paired with group normalization rather than batch norm
return nn.GroupNorm(num_channels=num_features, num_groups=32)
__all__ = ["ws_resnet18", "ws_resnet34", "ws_resnet50", "ws_resnet101", "ws_resnet152"]
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4 # bottleneck expansion factor: the block outputs planes * 4 channels
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = conv1x1(inplanes, planes)
self.bn1 = BatchNorm2d(planes)
self.conv2 = conv3x3(planes, planes, stride)
self.bn2 = BatchNorm2d(planes)
self.conv3 = conv1x1(planes, planes * self.expansion)
self.bn3 = BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
# print(f"Input shape: {x.shape}")
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
# print(f"After conv1: {out.shape}")
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
# print(f"After conv2: {out.shape}")
out = self.conv3(out)
out = self.bn3(out)
# print(f"After conv3: {out.shape}")
if self.downsample is not None:
identity = self.downsample(x)
# print(f"After downsample: {identity.shape}")
# assert out.shape == identity.shape, f"Shape mismatch: out {out.shape}, identity {identity.shape}"
out += identity
out = self.relu(out)
# print(f"Output shape: {out.shape}")
return out
chan_size = 2 # input channel count; the GAF images here are 2-channel rather than 3-channel RGB
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
super(ResNet, self).__init__()
self.inplanes = 64
self.conv1 = Conv2d(in_channels=chan_size, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): # BatchNorm2d() above actually builds GroupNorm, so match it here as well
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if (stride != 1) or (self.inplanes != planes * block.expansion):
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def ws_resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
return model
def ws_resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
return model
def ws_resnet50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
return model
def ws_resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
return model
def ws_resnet152(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
return model
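# Hedged usage sketch (illustrative; assumes 2-channel GAF inputs of spatial size 64x64 and that the
# encoder wrapper elsewhere in the repo drops or replaces the final fc head):
#
#   net = ws_resnet18()
#   x = torch.randn(4, chan_size, 64, 64)               # N x 2 x H x W
#   logits = net(x)                                      # -> torch.Size([4, 1000])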