Feat: semes gaf time series classifier

commit 14405c7285

@@ -0,0 +1 @@
testlog*

@@ -0,0 +1,2 @@
test
train

File diff suppressed because one or more lines are too long
@@ -0,0 +1,145 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# PyCharm
.idea/

# misc
stl10_binary.tar.gz
stl10_binary/
tb_logs/
eigen_values.csv
tb_archive/
tb_archive_23-12-16/
output_data/
lightning_logs/
custom/
checkpoint_semes/
archive
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 Untitled AI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,290 @@
# PyTorch-Lightning Implementation of Self-Supervised Learning Methods

This is a [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) implementation of the following self-supervised representation learning methods:
- [MoCo](https://arxiv.org/abs/1911.05722)
- [MoCo v2](https://arxiv.org/abs/2003.04297)
- [SimCLR](https://arxiv.org/abs/2002.05709)
- [BYOL](https://arxiv.org/abs/2006.07733)
- [EqCo](https://arxiv.org/abs/2010.01929)
- [VICReg](https://arxiv.org/abs/2105.04906)

Supported datasets: ImageNet, STL-10, and CIFAR-10.

During training, the top1/top5 accuracies (out of 1+K examples) are reported where possible. During validation, an `sklearn` linear classifier is trained on half the test set and validated on the other half. The top1 accuracy is logged as `train_class_acc` / `valid_class_acc`.
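For intuition, the linear evaluation amounts to roughly the following sketch (placeholder arrays and a plain `LogisticRegression` stand in for the repo's actual evaluation code):

```python
# Hedged sketch of the linear-probe evaluation: fit a linear classifier on
# frozen embeddings from half of the test set and score it on the other half.
# The embeddings/labels below are placeholders, not real encoder outputs.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

embeddings = np.random.randn(1000, 512)        # stand-in for frozen encoder outputs
labels = np.random.randint(0, 10, size=1000)   # stand-in for class labels

X_train, X_valid, y_train, y_valid = train_test_split(embeddings, labels, test_size=0.5)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print("train_class_acc:", clf.score(X_train, y_train))
print("valid_class_acc:", clf.score(X_valid, y_valid))
```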
## Installing

Make sure you're in a fresh `conda` or `venv` environment, then run:

```bash
git clone https://github.com/untitled-ai/self_supervised
cd self_supervised
pip install -r requirements.txt
```

## Replicating our BYOL blog post

We found some surprising results about the role of batch norm in BYOL. See the blog post [Understanding self-supervised and contrastive learning with "Bootstrap Your Own Latent" (BYOL)](https://untitled-ai.github.io/understanding-self-supervised-contrastive-learning.html) for more details about our experiments.

You can replicate the results of our blog post by running `python train_blog.py`. The cosine similarity between z and z' is reported as `step_neg_cos` (for negative examples) and `step_pos_cos` (for positive examples). Classification accuracy is reported as `valid_class_acc`.

## Getting started with MoCo v2

To get started with training a ResNet-18 with MoCo v2 on STL-10 (the default configuration):

```python
import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import ModelParams

os.environ["DATA_PATH"] = "~/data"

params = ModelParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```

For convenience, you can instead pass these parameters as keyword args, for example with `model = SelfSupervisedMethod(batch_size=128)`.

## VICReg

To train VICReg rather than MoCo v2, use the following parameters:

```python
import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import VICRegParams

os.environ["DATA_PATH"] = "~/data"

params = VICRegParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```

Note that we have not tuned these parameters for STL-10, and the parameters used for ImageNet are slightly different. See the comment on VICRegParams for details.

## BYOL

To train BYOL rather than MoCo v2, use the following parameters:

```python
import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import BYOLParams

os.environ["DATA_PATH"] = "~/data"

params = BYOLParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```

## SimCLR

To train SimCLR rather than MoCo v2, use the following parameters:

```python
import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import SimCLRParams

os.environ["DATA_PATH"] = "~/data"

params = SimCLRParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```

**Note for multi-GPU setups**: this currently only uses negatives on the same GPU, and will not sync negatives across multiple GPUs.


# Evaluating a trained model

To train a linear classifier on the result:

```python
import pytorch_lightning as pl
from linear_classifier import LinearClassifierMethod

linear_model = LinearClassifierMethod.from_moco_checkpoint("example.ckpt")
trainer = pl.Trainer(gpus=1, max_epochs=100)

trainer.fit(linear_model)
```

# Results on STL-10 and ImageNet

Training a ResNet-18 for 320 epochs on STL-10 achieved 85% linear classification accuracy on the test set (1 fold of 5000). This used all default parameters.

Training a ResNet-50 for 200 epochs on ImageNet achieves 65.6% linear classification accuracy on the test set. This used 8 GPUs with `ddp` and the parameters:

```python
hparams = ModelParams(
    encoder_arch="resnet50",
    shuffle_batch_norm=True,
    embedding_dim=2048,
    mlp_hidden_dim=2048,
    dataset_name="imagenet",
    batch_size=32,
    lr=0.03,
    max_epochs=200,
    transform_crop_size=224,
    num_data_workers=32,
    gather_keys_for_queue=True,
)
```

(The per-GPU `batch_size` here differs from the MoCo documentation because of the way PyTorch-Lightning handles multi-GPU training in `ddp`: each of the 8 GPUs gets its own batch of 32, so the effective batch size is 8 × 32 = 256.) **Note that for ImageNet we suggest using `val_percent_check=0.1` when calling `pl.Trainer`** to reduce the time spent fitting the sklearn model.


# All training options

All possible `hparams` for SelfSupervisedMethod, along with defaults:

```python
class ModelParams:
    # encoder model selection
    encoder_arch: str = "resnet18"
    shuffle_batch_norm: bool = False
    embedding_dim: int = 512  # must match embedding dim of encoder

    # data-related parameters
    dataset_name: str = "stl10"
    batch_size: int = 256

    # MoCo parameters
    K: int = 65536  # number of examples in queue
    dim: int = 128
    m: float = 0.996
    T: float = 0.2

    # eqco parameters
    eqco_alpha: int = 65536
    use_eqco_margin: bool = False
    use_negative_examples_from_batch: bool = False

    # optimization parameters
    lr: float = 0.5
    momentum: float = 0.9
    weight_decay: float = 1e-4
    max_epochs: int = 320
    final_lr_schedule_value: float = 0.0

    # transform parameters
    transform_s: float = 0.5
    transform_apply_blur: bool = True

    # Change these to make more like BYOL
    use_momentum_schedule: bool = False
    loss_type: str = "ce"
    use_negative_examples_from_queue: bool = True
    use_both_augmentations_as_queries: bool = False
    optimizer_name: str = "sgd"
    lars_warmup_epochs: int = 1
    lars_eta: float = 1e-3
    exclude_matching_parameters_from_lars: List[str] = []  # set to [".bias", ".bn"] to match paper
    loss_constant_factor: float = 1

    # Change these to make more like VICReg
    use_vicreg_loss: bool = False
    use_lagging_model: bool = True
    use_unit_sphere_projection: bool = True
    invariance_loss_weight: float = 25.0
    variance_loss_weight: float = 25.0
    covariance_loss_weight: float = 1.0
    variance_loss_epsilon: float = 1e-04

    # MLP parameters
    projection_mlp_layers: int = 2
    prediction_mlp_layers: int = 0
    mlp_hidden_dim: int = 512

    mlp_normalization: Optional[str] = None
    prediction_mlp_normalization: Optional[str] = "same"  # if same will use mlp_normalization
    use_mlp_weight_standardization: bool = False

    # data loader parameters
    num_data_workers: int = 4
    drop_last_batch: bool = True
    pin_data_memory: bool = True
    gather_keys_for_queue: bool = False
```

A few options require more explanation:

- **encoder_arch** can be any torchvision model, or can be one of the ResNet models with weight standardization defined in `ws_resnet.py`.

- **dataset_name** can be `imagenet`, `stl10`, or `cifar10`. `os.environ["DATA_PATH"]` will be used as the path to the data. STL-10 and CIFAR-10 will be downloaded if they do not already exist.

- **loss_type** can be `ce` (cross entropy), with one of the `use_negative_examples` options set to `True`, to correspond to MoCo, or `ip` (inner product), with both `use_negative_examples` options set to `False`, to correspond to BYOL. It can also be `bce`, which is similar to `ip` but applies the binary cross entropy loss function to the result, or `vic` for the VICReg loss. A sketch of the two main variants follows this list.

- **optimizer_name**: currently just `sgd` or `lars`.

- **exclude_matching_parameters_from_lars** will remove weight decay and LARS learning rate from matching parameters. Set to `[".bias", ".bn"]` to match the BYOL paper implementation.

- **mlp_normalization** can be None for no normalization, `bn` for batch normalization, `ln` for layer norm, `gn` for group norm, or `br` for [batch renormalization](https://github.com/ludvb/batchrenorm).

- **prediction_mlp_normalization** defaults to `same` to use the same normalization as above, but can be given any of the above parameters to use a different normalization.

- **shuffle_batch_norm** and **gather_keys_for_queue** are both related to multi-GPU training. **shuffle_batch_norm** will shuffle the *key* images among GPUs, which is needed for training if batch norm is used. **gather_keys_for_queue** will gather key projections (z' in the blog post) from all GPUs to add to the MoCo queue.
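To make the main `loss_type` settings concrete, here is a hedged sketch of the two variants (illustrative only; tensor names, shapes, and the exact reduction are assumptions, not the repo's code):

```python
# Sketch of loss_type="ce" (MoCo-style) vs loss_type="ip" (BYOL-style),
# assuming L2-normalized projections q, k of shape [N, dim] and a negative
# queue of shape [dim, K]. Illustrative, not the repo's implementation.
import torch
import torch.nn.functional as F

def contrastive_ce_loss(q, k, queue, temperature=0.2):
    """Cross entropy over one positive and K queue negatives."""
    l_pos = torch.einsum("nc,nc->n", q, k).unsqueeze(-1)  # [N, 1] positive logits
    l_neg = torch.einsum("nc,ck->nk", q, queue)           # [N, K] negative logits
    logits = torch.cat([l_pos, l_neg], dim=1) / temperature
    labels = torch.zeros(q.shape[0], dtype=torch.long)    # positive is class 0
    return F.cross_entropy(logits, labels)

def inner_product_loss(q, k):
    """No negatives: just maximize agreement between the two views."""
    return -(q * k).sum(dim=1).mean()
```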
# Training with custom options

You can train using any settings of the above parameters. This configuration represents the settings from BYOL:

```python
hparams = ModelParams(
    prediction_mlp_layers=2,
    mlp_normalization="bn",
    loss_type="ip",
    use_negative_examples_from_queue=False,
    use_both_augmentations_as_queries=True,
    use_momentum_schedule=True,
    optimizer_name="lars",
    exclude_matching_parameters_from_lars=[".bias", ".bn"],
    loss_constant_factor=2
)
```

Or here is our recommended way to modify VICReg for CIFAR-10:

```python
from model_params import VICRegParams

hparams = VICRegParams(
    dataset_name="cifar10",
    transform_apply_blur=False,
    mlp_hidden_dim=2048,
    dim=2048,
    batch_size=256,
    lr=0.3,
    final_lr_schedule_value=0,
    weight_decay=1e-4,
    lars_warmup_epochs=10,
    lars_eta=0.02
)
```
@@ -0,0 +1,88 @@
"""
From https://github.com/ludvb/batchrenorm
@article{batchrenomalization,
    author = {Sergey Ioffe},
    title = {Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models},
    journal = {arXiv preprint arXiv:1702.03275},
    year = {2017},
}
"""

import torch

__all__ = ["BatchRenorm1d", "BatchRenorm2d", "BatchRenorm3d"]


class BatchRenorm(torch.nn.Module):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-3,
        momentum: float = 0.01,
        affine: bool = True,
    ):
        super().__init__()
        self.register_buffer("running_mean", torch.zeros(num_features, dtype=torch.float))
        self.register_buffer("running_std", torch.ones(num_features, dtype=torch.float))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.weight = torch.nn.Parameter(torch.ones(num_features, dtype=torch.float))
        self.bias = torch.nn.Parameter(torch.zeros(num_features, dtype=torch.float))
        self.affine = affine
        self.eps = eps
        self.step = 0
        self.momentum = momentum

    def _check_input_dim(self, x: torch.Tensor) -> None:
        raise NotImplementedError()  # pragma: no cover

    @property
    def rmax(self) -> torch.Tensor:
        # r_max ramps linearly from 1 to 3 as training progresses, per the paper
        return (2 / 35000 * self.num_batches_tracked + 25 / 35).clamp_(1.0, 3.0)

    @property
    def dmax(self) -> torch.Tensor:
        # d_max ramps linearly from 0 to 5 as training progresses, per the paper
        return (5 / 20000 * self.num_batches_tracked - 25 / 20).clamp_(0.0, 5.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self._check_input_dim(x)
        if x.dim() > 2:
            x = x.transpose(1, -1)
        if self.training:
            dims = [i for i in range(x.dim() - 1)]
            batch_mean = x.mean(dims)
            batch_std = x.std(dims, unbiased=False) + self.eps
            r = (batch_std.detach() / self.running_std.view_as(batch_std)).clamp_(
                1 / self.rmax.item(), self.rmax.item()
            )
            d = (
                (batch_mean.detach() - self.running_mean.view_as(batch_mean)) / self.running_std.view_as(batch_std)
            ).clamp_(-self.dmax.item(), self.dmax.item())
            x = (x - batch_mean) / batch_std * r + d
            self.running_mean += self.momentum * (batch_mean.detach() - self.running_mean)
            self.running_std += self.momentum * (batch_std.detach() - self.running_std)
            self.num_batches_tracked += 1
        else:
            x = (x - self.running_mean) / self.running_std
        if self.affine:
            x = self.weight * x + self.bias
        if x.dim() > 2:
            x = x.transpose(1, -1)
        return x


class BatchRenorm1d(BatchRenorm):
    def _check_input_dim(self, x: torch.Tensor) -> None:
        if x.dim() not in [2, 3]:
            raise ValueError(f"expected 2D or 3D input (got {x.dim()}D input)")


class BatchRenorm2d(BatchRenorm):
    def _check_input_dim(self, x: torch.Tensor) -> None:
        if x.dim() != 4:
            raise ValueError(f"expected 4D input (got {x.dim()}D input)")


class BatchRenorm3d(BatchRenorm):
    def _check_input_dim(self, x: torch.Tensor) -> None:
        if x.dim() != 5:
            raise ValueError(f"expected 5D input (got {x.dim()}D input)")
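A quick usage sketch for the module above (assuming the file is saved as `batchrenorm.py`; shapes and values are illustrative):

```python
# BatchRenorm2d as a drop-in replacement for torch.nn.BatchNorm2d.
import torch
from batchrenorm import BatchRenorm2d  # assumes the file above is batchrenorm.py

bn = BatchRenorm2d(num_features=16)
x = torch.randn(8, 16, 32, 32)   # [batch, channels, height, width]

bn.train()
y = bn(x)        # batch statistics, corrected by the r and d terms
bn.eval()
y_eval = bn(x)   # running statistics only
print(y.shape, y_eval.shape)   # torch.Size([8, 16, 32, 32]) twice
```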
@@ -0,0 +1,199 @@
# for CustomImageDataset
import os
import glob
import random
import pandas as pd
import torch
from torchvision.io import read_image
from pathlib import Path
from PIL import Image
from pyts.image import GramianAngularField, MarkovTransitionField
# for transforms
from torchvision import transforms
import attr
# for CustomDataloader
from torch.utils.data import DataLoader
import numpy as np


# this recreates the ImageFolder function
class CustomImageDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir, transform=None, target_transform=None, img_size=400):
        # self.img_labels = pd.read_csv(annotations_file)
        self.img_labels = self.path_mapper(img_dir)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.gaf_function = GramianAngularField(image_size=img_size,
                                                method="difference",
                                                sample_range=(0, 1))
        self.mtf_function = MarkovTransitionField(image_size=img_size,
                                                  n_bins=5)

    def path_mapper(self, img_dir):
        df = pd.DataFrame()
        img_path = Path(img_dir)
        # grab all the label folders
        dirs = [f for f in img_path.iterdir() if f.is_dir()]
        for label_dir in dirs:
            # label_dir is already a full path, so iterate over it directly
            for f in label_dir.iterdir():
                new_row = {'file': f.name, 'label': label_dir.name}
                # we have the file name, and then the label of the folder it belongs to
                df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)
        return df

    def __len__(self):
        return len(self.img_labels)

    def normalize(self, data):
        # min-max scale to [0, 1]
        data_normalize = ((data - data.min()) / (data.max() - data.min()))
        return data_normalize

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir,
                                self.img_labels.iloc[idx, 1],
                                self.img_labels.iloc[idx, 0])
        data = np.load(img_path)['data']

        data = data.reshape((1, -1))
        gaf_image = self.normalize(self.gaf_function.transform(data)[0])
        mtf_image = gaf_image  # to turn off mtf
        # mtf_image = self.normalize(self.mtf_function.transform(data)[0])
        image = torch.from_numpy((np.stack([gaf_image, mtf_image], axis=0)).astype(np.float32))
        # assert image.dtype == torch.float32, "Tensor is not float32!"

        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label


@attr.s(auto_attribs=True)
class ImageTransforms:
    img_size: int = 400
    crop_size: tuple[int, int] = (224, 224)
    normalize_means: list = [0.0, 0.0]
    normalize_stds: list = [1.0, 1.0]

    def split_transform(self, img) -> torch.Tensor:
        # produce the two augmented views used by the self-supervised objective
        transform = self.single_transform()
        return torch.stack((transform(img), transform(img)))

    def single_transform(self):
        transform_list = [
            # transforms.ToTensor(),
            transforms.RandomResizedCrop(self.crop_size,
                                         scale=(0.3, 0.7),
                                         antialias=False),
            # transforms.RandomCrop(size=(int(self.img_size * 0.4), int(self.img_size * 0.4)))
            transforms.Normalize(mean=self.normalize_means, std=self.normalize_stds)
        ]
        return transforms.Compose(transform_list)


class CustomDataloader:

    def __init__(self,
                 img_dir: str,
                 img_size: int = 400,
                 batch_size: int = 64,
                 num_workers: int = 4,
                 persistent_workers: bool = True,
                 shuffle: bool = True
                 ):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.persistent_workers = persistent_workers
        self.shuffle = shuffle
        normalize_means, normalize_stds = normalize_params(img_dir).calculate_mean_std()
        self.image_transforms = ImageTransforms(img_size=img_size,
                                                crop_size=(224, 224),
                                                normalize_means=normalize_means,
                                                normalize_stds=normalize_stds)
        self.dataset = CustomImageDataset(img_dir=img_dir,
                                          img_size=img_size,
                                          transform=self.image_transforms.split_transform)

    def get_dataloader(self):
        return DataLoader(self.dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          persistent_workers=self.persistent_workers,
                          shuffle=self.shuffle)


class normalize_params:
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.img_size = len(np.load(self.get_random_file())['data'])
        self.gaf_function = GramianAngularField(image_size=self.img_size,
                                                method="difference",
                                                sample_range=(0, 1))
        self.mtf_function = MarkovTransitionField(image_size=self.img_size,
                                                  n_bins=5)

    def get_random_file(self):
        # Search for all files in the directory and subdirectories
        file_list = glob.glob(os.path.join(self.root_dir, '**', '*'), recursive=True)
        # Filter out directories from the list
        file_list = [f for f in file_list if os.path.isfile(f)]
        # If there are no files found, raise an exception
        if not file_list:
            raise FileNotFoundError("No files found in the specified directory")
        # Select and return a random file path
        return random.choice(file_list)

    def normalize(self, data):
        data_normalize = ((data - data.min()) / (data.max() - data.min()))
        return data_normalize

    def load_image(self, filepath):
        data = np.load(filepath)['data'].astype(np.float32)
        data = data.reshape((1, -1))
        gaf_image = self.gaf_function.transform(data)[0]
        mtf_image = gaf_image
        # mtf_image = self.mtf_function.transform(data)[0]
        image = (np.stack([gaf_image, mtf_image], axis=0)).astype(np.float32)
        return image

    def calculate_mean_std(self):
        # Accumulate per-channel sums and squared sums of pixel values
        mean_1, mean_2 = 0.0, 0.0
        std_1, std_2 = 0.0, 0.0
        num_pixels = 0
        image_dir = self.root_dir

        # Iterate through all images in the directory
        for dirpath, dirnames, filenames in os.walk(image_dir):
            for filename in filenames:
                # Full path of the file
                file_path = os.path.join(dirpath, filename)

                if os.path.isfile(file_path) and file_path.endswith('npz'):
                    img_np = self.load_image(file_path)

                    num_pixels += img_np.shape[1] * img_np.shape[2]

                    mean_1 += np.sum(img_np[0, :, :])
                    mean_2 += np.sum(img_np[1, :, :])

                    std_1 += np.sum(img_np[0, :, :] ** 2)
                    std_2 += np.sum(img_np[1, :, :] ** 2)

        # Calculate mean
        mean_1 /= num_pixels
        mean_2 /= num_pixels

        # Calculate standard deviation: sqrt(E[x^2] - E[x]^2)
        std_1 = (std_1 / num_pixels - mean_1 ** 2) ** 0.5
        std_2 = (std_2 / num_pixels - mean_2 ** 2) ** 0.5

        return [mean_1, mean_2], [std_1, std_2]
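To make the data flow above concrete, a hedged sketch of driving the loader (the directory layout and sizes are assumptions; each `.npz` is expected to hold a 1-D series under the key `'data'`):

```python
# Sketch: CustomDataloader expects <img_dir>/<label>/<sample>.npz. Each series
# is encoded as a 2-channel Gramian Angular Field image, and split_transform
# returns two augmented views per sample.
from dataload import CustomDataloader

loader = CustomDataloader(img_dir="gaf_data/train",  # assumed layout: train/<label>/*.npz
                          img_size=400,
                          batch_size=64).get_dataloader()
views, labels = next(iter(loader))
print(views.shape)   # torch.Size([64, 2, 2, 224, 224]): batch x views x channels x H x W
```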
@@ -0,0 +1,113 @@
"""
Layer-wise adaptive rate scaling for SGD in PyTorch!
Based on https://github.com/noahgolmant/pytorch-lars
"""
import torch
from torch.optim.optimizer import Optimizer


class LARS(Optimizer):
    r"""Implements layer-wise adaptive rate scaling for SGD.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): base learning rate (\gamma_0)
        momentum (float, optional): momentum factor (default: 0) ("m")
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
            ("\beta")
        eta (float, optional): LARS coefficient
        max_epoch: maximum training epoch (kept for compatibility; this
            implementation applies only a linear warmup, not polynomial decay).

    Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
    Large Batch Training of Convolutional Networks:
    https://arxiv.org/abs/1708.03888

    Example:
        >>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    """

    def __init__(self, params, lr=1.0, momentum=0.9, weight_decay=0.0005, eta=0.001, max_epoch=200, warmup_epochs=1):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if eta < 0.0:
            raise ValueError("Invalid LARS coefficient value: {}".format(eta))

        self.epoch = 0
        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            eta=eta,
            max_epoch=max_epoch,
            warmup_epochs=warmup_epochs,
            use_lars=True,
        )
        super().__init__(params, defaults)

    def step(self, epoch=None, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            epoch: current epoch used for the warmup schedule.
                If None, uses self.epoch and increments it.
        """
        loss = None
        if closure is not None:
            loss = closure()

        if epoch is None:
            epoch = self.epoch
            self.epoch += 1

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            eta = group["eta"]
            lr = group["lr"]
            warmup_epochs = group["warmup_epochs"]
            use_lars = group["use_lars"]
            group["lars_lrs"] = []

            for p in group["params"]:
                if p.grad is None:
                    continue

                param_state = self.state[p]
                d_p = p.grad.data

                weight_norm = torch.norm(p.data)
                grad_norm = torch.norm(d_p)

                # Global LR ramps up linearly over the warmup epochs
                warmup = min((1 + float(epoch)) / warmup_epochs, 1)
                global_lr = lr * warmup

                # Update the momentum term
                if use_lars:
                    # Compute the local (trust-ratio) learning rate for this layer:
                    # eta * ||w|| / (||g|| + weight_decay * ||w||)
                    local_lr = eta * weight_norm / (grad_norm + weight_decay * weight_norm)
                    actual_lr = local_lr * global_lr
                    group["lars_lrs"].append(actual_lr.item())
                else:
                    actual_lr = global_lr
                    group["lars_lrs"].append(global_lr)

                if "momentum_buffer" not in param_state:
                    buf = param_state["momentum_buffer"] = torch.zeros_like(p.data)
                else:
                    buf = param_state["momentum_buffer"]

                buf.mul_(momentum).add_(d_p + weight_decay * p.data, alpha=actual_lr)
                p.data.add_(-buf)

        return loss
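To make the trust ratio concrete, a small hedged example with made-up numbers:

```python
# The per-layer LARS learning rate computed in step() above:
# local_lr = eta * ||w|| / (||g|| + weight_decay * ||w||), with toy values.
import torch

w = torch.full((10,), 0.1)    # ||w|| ~ 0.316
g = torch.full((10,), 0.01)   # ||g|| ~ 0.0316
eta, weight_decay = 1e-3, 5e-4

local_lr = eta * w.norm() / (g.norm() + weight_decay * w.norm())
print(float(local_lr))   # ~0.00995: roughly eta * ||w||/||g||, so ~10x eta here
```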
File diff suppressed because one or more lines are too long
@@ -0,0 +1,80 @@
from attr import evolve
# from lightning.pytorch.callbacks import ModelCheckpoint
# from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch import Trainer

from vicreg import SelfSupervisedMethod
from model_params import VICRegParams
# import utils

# from torch.utils.data import DataLoader
from dataload import CustomDataloader
import torch
import numpy as np
from utils import get_random_file

# speedup
torch.set_float32_matmul_precision('medium')

# data parameters
data_params = list()
data_params.append({'train_path': "/home/richard/Projects/06_research/semes_gaf/gaf_data/train",
                    'test_path': "/home/richard/Projects/06_research/semes_gaf/gaf_data/test",
                    'checkpoint': 'checkpoint_semes/last.ckpt'})

batch_size = 256
num_epochs = 20
selector = 0


def main():

    configs = {
        "vicreg": evolve(VICRegParams(),
                         encoder_arch="ws_resnet18",  # resnet18, resnet34, resnet50
                         max_epochs=num_epochs
                         ),
    }
    train_path = data_params[selector]['train_path']
    random_file = get_random_file(train_path)
    img_size = len(np.load(random_file)['data'])

    for seed in range(1):  # number of repeats
        for name, config in configs.items():

            method = SelfSupervisedMethod(config)
            # logger = TensorBoardLogger("tb_logs", name=f"{name}_{seed}")
            # Define the checkpoint callback to save the best model only at the end of training
            # checkpoint_callback = ModelCheckpoint(
            #     filename=f'last-v{seed}',
            #     dirpath='checkpoint_semes',  # Directory to save the checkpoints
            #     every_n_epochs=10,
            #     # every_n_train_steps=2,
            #     save_last=True,  # Save the last model
            #     save_top_k=1,
            #     save_weights_only=True,  # Only save the model weights (not the optimizer, etc.)
            #     # save_on_train_epoch_end=True  # Save only at the end of the training epoch
            # )

            trainer = Trainer(accelerator="gpu",
                              devices=[1],
                              max_epochs=num_epochs)
            # strategy=DDPStrategy(find_unused_parameters=False))  # to enable multi-gpu, but not necessary for now

            print("--------------------------------------")
            print(data_params[selector]['checkpoint'])

            train_loader = CustomDataloader(img_dir=data_params[selector]['train_path'],
                                            img_size=img_size,
                                            batch_size=batch_size,
                                            ).get_dataloader()
            trainer.fit(model=method, train_dataloaders=train_loader)
            trainer.save_checkpoint(data_params[selector]['checkpoint'])


if __name__ == "__main__":
    main()
@@ -0,0 +1,63 @@
from functools import partial
from typing import List
from typing import Optional

import attr


@attr.s(auto_attribs=True)
class ModelParams:
    # encoder model selection
    encoder_arch: str = "ws_resnet18"  # resnet18, resnet34, resnet50
    # note that embedding_dim is 512 * expansion parameter in ws_resnet
    embedding_dim: int = 512 * 1  # must match embedding dim of encoder
    # projection size
    dim: int = 64

    # optimization parameters
    optimizer_name: str = 'lars'
    lr: float = 0.5
    momentum: float = 0.9
    weight_decay: float = 1e-4
    max_epochs: int = 10
    final_lr_schedule_value: float = 0.0
    lars_warmup_epochs: int = 1
    lars_eta: float = 1e-3
    exclude_matching_parameters_from_lars: List[str] = []  # set to [".bias", ".bn"] to match paper

    # loss parameters
    loss_constant_factor: float = 1
    invariance_loss_weight: float = 25.0
    variance_loss_weight: float = 25.0
    covariance_loss_weight: float = 1.0
    variance_loss_epsilon: float = 1e-04
    kmeans_weight: float = 1e-03

    # MLP parameters
    projection_mlp_layers: int = 2
    prediction_mlp_layers: int = 0  # by default prediction mlp is identity
    mlp_hidden_dim: int = 512
    mlp_normalization: Optional[str] = None
    prediction_mlp_normalization: Optional[str] = "same"  # if same will use mlp_normalization
    use_mlp_weight_standardization: bool = False


# Differences between these parameters and those used in the paper (on ImageNet):
# max_epochs=1000,
# lr=1.6,
# batch_size=2048,
# weight_decay=1e-6,
# mlp_hidden_dim=8192,
# dim=8192,
VICRegParams = partial(
    ModelParams,
    exclude_matching_parameters_from_lars=[".bias", ".bn"],
    projection_mlp_layers=3,
    final_lr_schedule_value=0.002,
    mlp_normalization="bn",
    lars_warmup_epochs=2,
    kmeans_weight=1e-03,
)
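A short sketch of how these presets compose (values are illustrative; `VICRegParams` accepts any `ModelParams` field as an override):

```python
# VICRegParams is a functools.partial preset over ModelParams, so any field
# can still be overridden per run, and attrs' evolve() copies with changes.
from attr import evolve
from model_params import VICRegParams

params = VICRegParams(max_epochs=100)   # preset defaults + a per-run override
print(params.projection_mlp_layers)     # 3, pinned by the preset
tweaked = evolve(params, lr=0.3)
print(tweaked.lr, params.lr)            # 0.3 0.5
```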
@@ -0,0 +1,132 @@
aiohappyeyeballs==2.3.5
aiohttp==3.10.3
aiosignal==1.3.1
anyio==4.4.0
arrow==1.3.0
asttokens==2.4.1
attrs==24.2.0
beautifulsoup4==4.12.3
blessed==1.20.0
boto3==1.34.159
botocore==1.34.159
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
comm==0.2.2
contourpy==1.2.1
croniter==1.3.15
cycler==0.12.1
dateutils==0.6.12
debugpy==1.8.5
decorator==5.1.1
deepdiff==7.0.1
editor==1.6.6
executing==2.0.1
fastapi==0.88.0
fastjsonschema==2.20.0
filelock==3.15.4
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2023.12.2
h11==0.14.0
idna==3.7
inquirer==3.4.0
ipykernel==6.29.5
ipython==8.26.0
itsdangerous==2.2.0
jedi==0.19.1
Jinja2==3.1.4
jmespath==1.0.1
joblib==1.4.2
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter_client==8.6.2
jupyter_core==5.7.2
kiwisolver==1.4.5
lightning==1.9.5
lightning-cloud==0.5.70
lightning-utilities==0.11.6
llvmlite==0.43.0
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.2
matplotlib-inline==0.1.7
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.3
numba==0.60.0
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.20
nvidia-nvtx-cu12==12.1.105
ordered-set==4.1.0
packaging==24.1
pandas==2.2.2
parso==0.8.4
pexpect==4.9.0
pillow==10.4.0
platformdirs==4.2.2
plotly==5.23.0
prompt_toolkit==3.0.47
protobuf==5.27.3
psutil==6.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pydantic==1.10.17
Pygments==2.18.0
PyJWT==2.9.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
python-multipart==0.0.9
pytorch-lightning==1.9.5
pyts==0.13.0
pytz==2024.1
PyYAML==6.0.2
pyzmq==26.1.0
readchar==4.2.0
referencing==0.35.1
requests==2.32.3
rich==13.7.1
rpds-py==0.20.0
runs==1.2.2
s3transfer==0.10.2
scikit-learn==1.5.1
scipy==1.14.0
six==1.16.0
sniffio==1.3.1
soupsieve==2.5
stack-data==0.6.3
starlette==0.22.0
starsessions==1.3.0
sympy==1.13.2
tenacity==9.0.0
threadpoolctl==3.5.0
torch==2.4.0
torchmetrics==1.4.1
torchvision==0.19.0
tornado==6.4.1
tqdm==4.66.5
traitlets==5.14.3
triton==3.0.0
types-python-dateutil==2.9.0.20240316
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.2
uvicorn==0.30.5
wcwidth==0.2.13
websocket-client==1.8.0
websockets==11.0.3
xmod==1.8.1
yarl==1.9.4
@@ -0,0 +1,388 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchvision import transforms\n",
    "from PIL import Image\n",
    "import glob, os\n",
    "import numpy as np\n",
    "\n",
    "import utils\n",
    "from moco import SelfSupervisedMethod\n",
    "# from model_params import EigRegParams\n",
    "from model_params import VICRegParams\n",
    "\n",
    "from attr import evolve\n",
    "\n",
    "import pandas as pd\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.metrics import rand_score, normalized_mutual_info_score\n",
    "\n",
    "# data parameters\n",
    "data_params = list()\n",
    "# 0 Beef\n",
    "data_params.append({'resize': 471,\n",
    "                    'batch_size': 30,\n",
    "                    'num_clusters': 5,\n",
    "                    'train_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/Beef/train\",\n",
    "                    'test_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/Beef/test\",\n",
    "                    'checkpoint': 'checkpoint_beef'})\n",
    "# 1 dist.phal.outl.agegroup\n",
    "data_params.append({'resize': 81,\n",
    "                    'batch_size': 139,\n",
    "                    'num_clusters': 3,\n",
    "                    'train_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/DistalPhalanxOutlineAgeGroup/train\",\n",
    "                    'test_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/DistalPhalanxOutlineAgeGroup/test\",\n",
    "                    'checkpoint': 'checkpoint_dist_agegroup'})\n",
    "# 2 ECG200\n",
    "data_params.append({'resize': 97,\n",
    "                    'batch_size': 100,\n",
    "                    'num_clusters': 2,\n",
    "                    'train_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ECG200/train\",\n",
    "                    'test_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ECG200/test\",\n",
    "                    'checkpoint': 'checkpoint_ecg200'})\n",
    "# 3 ECGFiveDays\n",
    "data_params.append({'resize': 137,\n",
    "                    'batch_size': 23,\n",
    "                    'num_clusters': 2,\n",
    "                    'train_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ECGFiveDays/train\",\n",
    "                    'test_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ECGFiveDays/test\",\n",
    "                    'checkpoint': 'checkpoint_ecg5days'})\n",
    "# 4 Meat\n",
    "data_params.append({'resize': 449,\n",
    "                    'batch_size': 60,\n",
    "                    'num_clusters': 3,\n",
    "                    'train_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/Meat/train\",\n",
    "                    'test_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/Meat/test\",\n",
    "                    'checkpoint': 'checkpoint_meat'})\n",
    "# 5 mote strain\n",
    "data_params.append({'resize': 85,\n",
    "                    'batch_size': 20,\n",
    "                    'num_clusters': 2,\n",
    "                    'train_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/MoteStrain/train\",\n",
    "                    'test_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/MoteStrain/test\",\n",
    "                    'checkpoint': 'checkpoint_motestrain'})\n",
    "# 6 osuleaf\n",
    "data_params.append({'resize': 428,\n",
    "                    'batch_size': 64,  # 200\n",
    "                    'num_clusters': 6,\n",
    "                    'train_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/OSULeaf/train\",\n",
    "                    'test_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/OSULeaf/test\",\n",
    "                    'checkpoint': 'checkpoint_osuleaf'})\n",
    "# 7 plane\n",
    "data_params.append({'resize': 145,\n",
    "                    'batch_size': 105,\n",
    "                    'num_clusters': 7,\n",
    "                    'train_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/Plane/train\",\n",
    "                    'test_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/Plane/test\",\n",
    "                    'checkpoint': 'checkpoint_plane'})\n",
    "# 8 proximal_agegroup\n",
    "data_params.append({'resize': 81,\n",
    "                    'batch_size': 205,\n",
    "                    'num_clusters': 3,\n",
    "                    'train_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ProximalPhalanxOutlineAgeGroup/train\",\n",
    "                    'test_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ProximalPhalanxOutlineAgeGroup/test\",\n",
    "                    'checkpoint': 'checkpoint_prox_agegroup'})\n",
    "# 9 proximal_tw\n",
    "data_params.append({'resize': 81,\n",
    "                    'batch_size': 100,  # 400\n",
    "                    'num_clusters': 6,\n",
    "                    'train_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ProximalPhalanxTW/train\",\n",
    "                    'test_path': \"/home/richard/Projects/06_research/gaf_vicreg/gaf_data/ucr/ProximalPhalanxTW/test\",\n",
    "                    'checkpoint': 'checkpoint_prox_tw'})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_mean_std(image_dir):\n",
    "    # Initialize lists to store the sum and squared sum of pixel values\n",
    "    mean_1, mean_2 = 0.0, 0.0\n",
    "    std_1, std_2 = 0.0, 0.0\n",
    "    num_pixels = 0\n",
    "\n",
    "    # Iterate through all images in the directory\n",
    "    for dirpath, dirnames, filenames in os.walk(image_dir):\n",
    "        for filename in filenames:\n",
    "            # Full path of the file\n",
    "            file_path = os.path.join(dirpath, filename)\n",
    "\n",
    "            # for img_name in os.listdir(image_dir):\n",
    "            #     img_path = os.path.join(image_dir, img_name)\n",
    "            if os.path.isfile(file_path) and file_path.endswith(('png', 'jpg', 'jpeg', 'bmp', 'tiff')):\n",
    "                with Image.open(file_path) as img:\n",
    "                    # img = img.convert('RGB')  # Ensure image is in RGB format\n",
    "                    img_np = np.array(img) / 255.0  # Normalize to range [0, 1]\n",
    "\n",
    "                    num_pixels += img_np.shape[0] * img_np.shape[1]\n",
    "\n",
    "                    mean_1 += np.sum(img_np[:, :, 0])\n",
    "                    mean_2 += np.sum(img_np[:, :, 1])\n",
    "\n",
    "                    std_1 += np.sum(img_np[:, :, 0] ** 2)\n",
    "                    std_2 += np.sum(img_np[:, :, 1] ** 2)\n",
    "\n",
    "    # Calculate mean\n",
    "    mean_1 /= num_pixels\n",
    "    mean_2 /= num_pixels\n",
    "\n",
    "    # Calculate standard deviation\n",
    "    std_1 = (std_1 / num_pixels - mean_1 ** 2) ** 0.5\n",
    "    std_2 = (std_2 / num_pixels - mean_2 ** 2) ** 0.5\n",
    "\n",
    "    return [mean_1, mean_2], [std_1, std_2]\n",
    "\n",
    "def list_directories(path):\n",
    "    entries = os.listdir(path)\n",
    "    directories = [entry for entry in entries if os.path.isdir(os.path.join(path, entry))]\n",
    "    return directories\n",
    "\n",
    "\n",
    "def inference(method, classes, transform, path):\n",
    "    batch_size = 32\n",
    "    image_tensors = []\n",
    "    result = []\n",
    "    labels = []\n",
    "\n",
    "    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    # device = torch.device('cpu')\n",
    "    device = torch.device('cuda:2')\n",
    "    method.model.to(device)\n",
    "    method.projection_model.to(device)\n",
    "\n",
    "    for key in classes:\n",
    "        image_dir = path + '/' + key\n",
    "        for img_name in os.listdir(image_dir):\n",
    "            image_path = os.path.join(image_dir, img_name)\n",
    "            image = Image.open(image_path)\n",
    "            # image = image.convert('RGB')\n",
    "\n",
    "            # Preprocess the image\n",
    "            input_tensor = transform(image).unsqueeze(0)  # Add batch dimension\n",
    "            image_tensors.append(input_tensor)\n",
    "\n",
    "            # perform batching\n",
    "            if len(image_tensors) == batch_size:\n",
    "                batch_tensor = torch.cat(image_tensors).to(device)\n",
    "                # Use the pre-trained model to extract features\n",
    "                with torch.no_grad():\n",
    "                    emb = method.model(batch_tensor)\n",
    "                    projection = method.projection_model(emb)\n",
    "                    # projection = method.model(input_tensor)\n",
    "                result.extend(projection.cpu())\n",
    "                # reset back to 0\n",
    "                image_tensors = []\n",
    "\n",
    "            labels.append(int(key))\n",
    "\n",
    "    if len(image_tensors) > 0:\n",
    "        batch_tensor = torch.cat(image_tensors).to(device)\n",
    "        # Use the pre-trained model to extract features\n",
    "        with torch.no_grad():\n",
    "            emb = method.model(batch_tensor)\n",
    "            projection = method.projection_model(emb)\n",
    "            # projection = method.model(input_tensor)\n",
    "        result.extend(projection.cpu())\n",
    "\n",
    "    return result, labels\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of runs\n",
    "num_runs = 10\n",
    "# Number of results/metrics per run\n",
    "num_results_per_run = 10\n",
    "# Create a 2D NumPy array to store the results\n",
    "ri_results = np.zeros((num_runs, num_results_per_run))\n",
    "nmi_results = np.zeros((num_runs, num_results_per_run))\n",
    "\n",
    "\n",
    "start = 0\n",
    "end = 9\n",
    "for run_num in range(num_runs):\n",
    "    for selector in range(start, end + 1):\n",
    "\n",
    "        config = evolve(VICRegParams(),\n",
    "                        encoder_arch=\"ws_resnet18\",  # resnet18, resnet34, resnet50\n",
    "                        dataset_name=\"custom\",\n",
    "                        train_path=data_params[selector]['train_path'],\n",
    "                        test_path=data_params[selector]['test_path'],\n",
    "                        kmeans_weight=0,  # it doesn't matter since this is not used in the model\n",
    "                        num_clusters=data_params[selector]['num_clusters'])\n",
    "        method = SelfSupervisedMethod(config)\n",
    "        # Initialize your ResNet model\n",
    "        checkpoint = data_params[selector]['checkpoint']\n",
    "        # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=49-step=50.ckpt'\n",
    "        # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=99-step=100.ckpt'\n",
    "        # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=149-step=150.ckpt'\n",
    "        # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=199-step=200.ckpt'\n",
    "        # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=299-step=300.ckpt'\n",
    "        # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=399-step=400.ckpt'\n",
    "        # path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/epoch=499-step=500.ckpt'\n",
    "        path = f'/home/richard/Projects/06_research/gaf_vicreg/self_supervised/{checkpoint}/last-v{run_num}.ckpt'\n",
    "        method = method.load_from_checkpoint(path)\n",
    "        # Set the model to evaluation mode\n",
    "        method.eval()\n",
    "\n",
    "        # Define transform\n",
    "        path = data_params[selector]['test_path']\n",
    "        normalize_means, normalize_stds = calculate_mean_std(path)\n",
    "        # image_size = data_params[selector]['resize']\n",
    "        # crop_size = int(0.4 * image_size)\n",
    "        transform = transforms.Compose([\n",
    "            # transforms.Resize((image_size, image_size)),\n",
    "            # transforms.CenterCrop(size=(crop_size, crop_size)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(mean=normalize_means, std=normalize_stds),\n",
    "        ])\n",
    "\n",
    "        # get all the classes\n",
    "        classes = list_directories(path)\n",
    "\n",
    "        result, labels = inference(method, classes, transform, path)\n",
    "\n",
    "        data = np.array(result)\n",
    "        # pca = PCA(n_components=2)\n",
    "        # reduced_data = pca.fit_transform(data)\n",
    "\n",
    "        # Choose the number of clusters, say 3\n",
    "        kmeans = KMeans(n_clusters=data_params[selector]['num_clusters'], random_state=42, n_init=10)\n",
    "        clusters = kmeans.fit_predict(data)\n",
    "\n",
    "        # print(data_params[selector]['checkpoint'])\n",
    "        # print(\"Rand Index: \", rand_score(clusters, labels))\n",
    "        # print(\"NMI: \", normalized_mutual_info_score(clusters, labels))\n",
    "        rand_index = rand_score(clusters, labels)\n",
    "        nmi = normalized_mutual_info_score(clusters, labels)\n",
    "\n",
    "        ri_results[run_num, selector] = rand_index\n",
    "        nmi_results[run_num, selector] = nmi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "RI mean: 0.6374712643678162\n",
      "RI std: 0.03729597247277404\n",
      "NMI mean: 0.2602497793608286\n",
      "NMI std: 0.02985710765468422\n",
      "1\n",
      "RI mean: 0.6458659159628819\n",
      "RI std: 0.01396857176725652\n",
      "NMI mean: 0.30136574266475413\n",
      "NMI std: 0.03170806441677236\n",
      "2\n",
      "RI mean: 0.6408888888888888\n",
      "RI std: 0.021672013185081274\n",
      "NMI mean: 0.20592658284184645\n",
      "NMI std: 0.048248741710938625\n",
      "3\n",
      "RI mean: 0.5765183804661967\n",
      "RI std: 0.03417303145285537\n",
      "NMI mean: 0.11716054993167128\n",
      "NMI std: 0.05173307476499848\n",
      "4\n",
      "RI mean: 0.7636723163841806\n",
      "RI std: 0.0838674066635877\n",
      "NMI mean: 0.6087294263576666\n",
      "NMI std: 0.10910741463199608\n",
      "5\n",
      "RI mean: 0.6088675385570139\n",
      "RI std: 0.041236238284731705\n",
      "NMI mean: 0.17859344790373669\n",
      "NMI std: 0.08743358257833596\n",
      "6\n",
      "RI mean: 0.7343060937553582\n",
      "RI std: 0.020174409290336055\n",
      "NMI mean: 0.22234048756150684\n",
      "NMI std: 0.029705953611425088\n",
      "7\n",
      "RI mean: 0.9384065934065934\n",
      "RI std: 0.019608200939834567\n",
      "NMI mean: 0.8406540399495203\n",
      "NMI std: 0.03016596675386891\n",
      "8\n",
      "RI mean: 0.7505643232902918\n",
      "RI std: 0.022756611198669806\n",
      "NMI mean: 0.48619057929963566\n",
      "NMI std: 0.01338604938860034\n",
      "9\n",
      "RI mean: 0.832415112386418\n",
      "RI std: 0.019248159640681852\n",
      "NMI mean: 0.5497811436968876\n",
      "NMI std: 0.023003346601586715\n"
     ]
    }
   ],
   "source": [
    "for data_select in range(10):\n",
    "\tprint(data_select)\n",
    "\tprint(\"RI mean: \", np.mean(ri_results[:,data_select]))\n",
    "\tprint(\"RI std: \", np.std(ri_results[:,data_select]))\n",
    "\tprint(\"NMI mean: \", np.mean(nmi_results[:,data_select]))\n",
    "\tprint(\"NMI std: \", np.std(nmi_results[:,data_select]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@ -0,0 +1,93 @@
|
|||
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "4\n",
      "NVIDIA GeForce RTX 4070\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "print(torch.cuda.is_available()) # Should return True if CUDA is available\n",
    "print(torch.cuda.device_count()) # Should return the number of GPUs available\n",
    "print(torch.cuda.get_device_name(0)) # Should return the name of the first GPU, if available\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2643050/1394991573.py:4: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      " checkpoint = torch.load(checkpoint_path)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "checkpoint_path = 'checkpoint_semes/last.ckpt'\n",
    "checkpoint = torch.load(checkpoint_path)\n",
    "print(checkpoint.keys()) # Check what is saved in the checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint['hparams_name']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@ -0,0 +1,199 @@
# import os
# import random
# from typing import Any
# from typing import Callable
# from typing import Optional

import pandas as pd
import glob
import os
import random

# import attr
import torch
import torchvision

import ws_resnet
# from model_params import ModelParams

################
# main train utils


def get_random_file(root_dir):
    # Search for all files in the directory and subdirectories
    file_list = glob.glob(os.path.join(root_dir, '**', '*'), recursive=True)
    # Filter out directories from the list
    file_list = [f for f in file_list if os.path.isfile(f)]
    # Raise if no files were found
    if not file_list:
        raise FileNotFoundError("No files found in the specified directory")
    # Select and return a random file path
    return random.choice(file_list)
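
# Example usage (hypothetical directory layout):
#   path = get_random_file("train")  # picks one file anywhere under train/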


#####################
# Parallelism utils #
#####################


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
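
# Note: this assumes the default process group has been initialized
# (torch.distributed.init_process_group); the result is the tensors from all
# ranks concatenated along dim 0, i.e. world_size * local_batch rows.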


class BatchShuffleDDP:
    @staticmethod
    @torch.no_grad()
    def shuffle(x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).to(x.device)

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @staticmethod
    @torch.no_grad()
    def unshuffle(x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]


###############
# Model utils #
###############


class MLP(torch.nn.Module):
    def __init__(
        self, input_dim, output_dim, hidden_dim, num_layers, weight_standardization=False, normalization=None
    ):
        super().__init__()
        assert num_layers >= 0, "num_layers must be non-negative"
        if normalization is not None:
            assert callable(normalization), "normalization must be callable"

        if num_layers == 0:
            self.net = torch.nn.Identity()
            return

        if num_layers == 1:
            self.net = torch.nn.Linear(input_dim, output_dim)
            return

        linear_net = ws_resnet.Linear if weight_standardization else torch.nn.Linear

        layers = []
        prev_dim = input_dim
        for _ in range(num_layers - 1):
            layers.append(linear_net(prev_dim, hidden_dim))
            if normalization is not None:
                layers.append(normalization())
            layers.append(torch.nn.ReLU())
            prev_dim = hidden_dim

        layers.append(torch.nn.Linear(hidden_dim, output_dim))

        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
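
# Layer-count semantics, for reference: num_layers == 0 gives an identity,
# num_layers == 1 a single linear map, and num_layers == n stacks n linear
# layers with ReLU (plus the optional normalization) in between; e.g.
# MLP(512, 128, 2048, num_layers=3) maps 512 -> 2048 -> 2048 -> 128.
# Note that the final layer is always a plain torch.nn.Linear, even when
# weight_standardization=True.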


def get_encoder(name: str, **kwargs) -> torch.nn.Module:
    """
    Gets just the encoder portion of a torchvision model (replaces final layer with identity)
    :param name: (str) name of the model
    :param kwargs: kwargs to send to the model
    :return: the encoder module
    """

    if name in ws_resnet.__dict__:
        model_creator = ws_resnet.__dict__.get(name)
    elif name in torchvision.models.__dict__:
        model_creator = torchvision.models.__dict__.get(name)
    else:
        raise AttributeError(f"Unknown architecture {name}")

    assert model_creator is not None, f"no torchvision model named {name}"
    model = model_creator(**kwargs)
    if hasattr(model, "fc"):  # in resnet
        model.fc = torch.nn.Identity()
    elif hasattr(model, "classifier"):  # not in resnet
        model.classifier = torch.nn.Identity()
    else:
        raise NotImplementedError(f"Unknown class {model.__class__}")

    return model


####################
# Evaluation utils #
####################


def calculate_accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def log_softmax_with_factors(logits: torch.Tensor, log_factor: float = 1, neg_factor: float = 1) -> torch.Tensor:
    exp_sum_neg_logits = torch.exp(logits).sum(dim=-1, keepdim=True) - torch.exp(logits)
    softmax_result = logits - log_factor * torch.log(torch.exp(logits) + neg_factor * exp_sum_neg_logits)
    return softmax_result
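

# log_softmax_with_factors generalizes log-softmax:
#     logits - log_factor * log(exp(logits) + neg_factor * sum_{j != i} exp(logits_j))
# With log_factor == neg_factor == 1 it reduces to the standard log-softmax.
# A minimal sanity-check sketch of that special case:
if __name__ == "__main__":
    z = torch.randn(4, 10)
    reference = torch.nn.functional.log_softmax(z, dim=-1)
    assert torch.allclose(log_softmax_with_factors(z), reference, atol=1e-5)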
@ -0,0 +1,285 @@
import copy
import math
import warnings
from functools import partial
from typing import Optional
from typing import Union

import attr
import torch
import torch.nn.functional as F
# from pytorch_lightning.utilities import AttributeDict
from lightning.fabric.utilities.data import AttributeDict
from torch.utils.data import DataLoader
import lightning as L

import utils
from batchrenorm import BatchRenorm1d
from lars import LARS
from model_params import ModelParams
# from sklearn.linear_model import LogisticRegression
# from sklearn.cluster import KMeans
# from sklearn.metrics import rand_score, normalized_mutual_info_score

import pandas as pd


def get_mlp_normalization(hparams: ModelParams, prediction=False):
    normalization_str = hparams.mlp_normalization
    if prediction and hparams.prediction_mlp_normalization != "same":
        normalization_str = hparams.prediction_mlp_normalization

    if normalization_str is None:
        return None
    elif normalization_str == "bn":
        return partial(torch.nn.BatchNorm1d, num_features=hparams.mlp_hidden_dim)
    elif normalization_str == "br":
        return partial(BatchRenorm1d, num_features=hparams.mlp_hidden_dim)
    elif normalization_str == "ln":
        return partial(torch.nn.LayerNorm, normalized_shape=[hparams.mlp_hidden_dim])
    elif normalization_str == "gn":
        return partial(torch.nn.GroupNorm, num_channels=hparams.mlp_hidden_dim, num_groups=32)
    else:
        raise NotImplementedError(f"mlp normalization {normalization_str} not implemented")


# class KMeansLoss:
#
#     def __init__(self, num_clusters, embedding_dim, device):
#         self.num_clusters = num_clusters
#         self.centroids = torch.randn(num_clusters, embedding_dim, device=device)
#         self.device = device
#
#     def update_centroids(self, embeddings, assignments):
#         for i in range(self.num_clusters):
#             assigned_embeddings = embeddings[assignments == i]
#             if len(assigned_embeddings) > 1:  # good if more than singleton
#                 # ewma update for centroids
#                 weight1 = torch.tensor(0.3, device='cpu')
#                 weight2 = torch.tensor(0.7, device='cpu')  # give more weight to new embeddings
#                 self.centroids[i] = self.centroids[i] * weight1 + assigned_embeddings.mean(dim=0).cpu() * weight2
#
#     def set_centroids(self, embeddings, assignments):
#         for i in range(self.num_clusters):
#             assigned_embeddings = embeddings[assignments == i]
#             if len(assigned_embeddings) > 1:  # good if more than singleton
#                 self.centroids[i] = assigned_embeddings.mean(dim=0).cpu()
#
#     def compute_loss(self, embeddings):
#         # move centroids to same device as embeddings
#         centroids = self.centroids.to(embeddings.device)
#         distances = torch.cdist(embeddings, centroids, p=2)  # Euclidean; the original passed p=self.num_clusters, almost certainly a typo
#         min_distances, assignments = distances.min(dim=1)
#         loss = min_distances.pow(2).sum()
#         return loss, assignments
#
#     def forward(self, embeddings, step_count):
#         loss, assignments = self.compute_loss(embeddings)
#         detached_embeddings = embeddings.detach()
#         detached_assignments = assignments.detach()
#
#         if step_count < 5:
#             self.set_centroids(detached_embeddings, detached_assignments)
#         if step_count % 2 == 0:
#             self.update_centroids(detached_embeddings, detached_assignments)
#         return loss


class SelfSupervisedMethod(L.LightningModule):
    model: torch.nn.Module
    hparams: AttributeDict
    embedding_dim: Optional[int]

    def __init__(
        self,
        hparams: Union[ModelParams, dict, None] = None,
        **kwargs,
    ):
        super().__init__()

        # disable automatic optimization for lightning2
        self.automatic_optimization = False
        self.optimizer = None
        self.lr_scheduler = None

        # load from arguments
        if hparams is None:
            hparams = self.params(**kwargs)
        # accept a plain dict as well as a ModelParams instance
        if isinstance(hparams, dict):
            hparams = ModelParams(**hparams)
        self.save_hyperparameters(attr.asdict(hparams))

        # Create encoder model
        self.model = utils.get_encoder(hparams.encoder_arch)

        # projection_mlp_layers = 3
        self.projection_model = utils.MLP(
            hparams.embedding_dim,
            hparams.dim,
            hparams.mlp_hidden_dim,
            num_layers=hparams.projection_mlp_layers,
            normalization=get_mlp_normalization(hparams),
            weight_standardization=hparams.use_mlp_weight_standardization,
        )

        # by default it is identity
        # prediction_mlp_layers = 0
        self.prediction_model = utils.MLP(
            hparams.dim,
            hparams.dim,
            hparams.mlp_hidden_dim,
            num_layers=hparams.prediction_mlp_layers,
            normalization=get_mlp_normalization(hparams, prediction=True),
            weight_standardization=hparams.use_mlp_weight_standardization,
        )

        # kmeans loss
        # self.kmeans_loss = KMeansLoss(num_clusters=hparams.num_clusters, embedding_dim=hparams.dim, device=self.device)

    def _get_embeddings(self, x):
        """
        Input:
            x: a batch of paired images, shaped N2CHW (query and key views stacked on dim 1)
        Output:
            emb_q: encoder embeddings for the query views
            q, k: projection/prediction embeddings for the query and key views
        """
        bsz, nd, nc, nh, nw = x.shape
        assert nd == 2, "second dimension should be the split image -- dims should be N2CHW"
        im_q = x[:, 0].contiguous()
        im_k = x[:, 1].contiguous()

        # compute query features
        emb_q = self.model(im_q)
        q_projection = self.projection_model(emb_q)
        # by default vicreg gives an identity for prediction model
        q = self.prediction_model(q_projection)  # queries: NxC
        emb_k = self.model(im_k)
        k_projection = self.projection_model(emb_k)
        k = self.prediction_model(k_projection)  # keys: NxC
        # q and k are the projection embeddings

        return emb_q, q, k

    def _get_vicreg_loss(self, z_a, z_b, batch_idx):
        assert z_a.shape == z_b.shape and len(z_a.shape) == 2

        # invariance loss
        loss_inv = F.mse_loss(z_a, z_b)

        # variance loss
        std_z_a = torch.sqrt(z_a.var(dim=0) + self.hparams.variance_loss_epsilon)
        std_z_b = torch.sqrt(z_b.var(dim=0) + self.hparams.variance_loss_epsilon)
        loss_v_a = torch.mean(F.relu(1 - std_z_a))  # differentiable max
        loss_v_b = torch.mean(F.relu(1 - std_z_b))
        loss_var = loss_v_a + loss_v_b

        # covariance loss
        N, D = z_a.shape
        z_a = z_a - z_a.mean(dim=0)
        z_b = z_b - z_b.mean(dim=0)
        cov_z_a = ((z_a.T @ z_a) / (N - 1)).square()  # DxD
        cov_z_b = ((z_b.T @ z_b) / (N - 1)).square()  # DxD
        loss_c_a = (cov_z_a.sum() - cov_z_a.diagonal().sum()) / D
        loss_c_b = (cov_z_b.sum() - cov_z_b.diagonal().sum()) / D
        loss_cov = loss_c_a + loss_c_b

        weighted_inv = loss_inv * self.hparams.invariance_loss_weight
        weighted_var = loss_var * self.hparams.variance_loss_weight
        weighted_cov = loss_cov * self.hparams.covariance_loss_weight

        loss = weighted_inv + weighted_var + weighted_cov

        return loss
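
    # The three terms computed above follow the VICReg loss (Bardes et al., 2022):
    #   invariance:  MSE between the two views' embeddings,
    #   variance:    hinge term mean(relu(1 - std)) per dimension, keeping each
    #                dimension's std from collapsing below 1,
    #   covariance:  mean of squared off-diagonal covariance entries,
    #                decorrelating embedding dimensions.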

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, class_labels = batch  # batch is a tuple, we just want the image

        emb_q, q, k = self._get_embeddings(x)

        vicreg_loss = self._get_vicreg_loss(q, k, batch_idx)

        total_loss = vicreg_loss.mean() * self.hparams.loss_constant_factor

        # manual optimization (automatic_optimization is False)
        self.optimizer.zero_grad()
        self.manual_backward(total_loss)
        self.optimizer.step()
        self.lr_scheduler.step()

        log_data = {
            "step_train_loss": total_loss,
        }

        self.log_dict(log_data, sync_dist=True, prog_bar=True)
        return {"loss": total_loss}

    def configure_optimizers(self):
        # exclude bias and batch norm from LARS and weight decay
        regular_parameters = []
        regular_parameter_names = []
        excluded_parameters = []
        excluded_parameter_names = []
        for name, parameter in self.named_parameters():
            if parameter.requires_grad is False:
                continue

            # for vicreg
            # exclude_matching_parameters_from_lars=[".bias", ".bn"],
            if any(x in name for x in self.hparams.exclude_matching_parameters_from_lars):
                excluded_parameters.append(parameter)
                excluded_parameter_names.append(name)
            else:
                regular_parameters.append(parameter)
                regular_parameter_names.append(name)

        param_groups = [
            {
                "params": regular_parameters,
                "names": regular_parameter_names,
                "use_lars": True,
            },
            {
                "params": excluded_parameters,
                "names": excluded_parameter_names,
                "use_lars": False,
                "weight_decay": 0,
            },
        ]
        if self.hparams.optimizer_name == "sgd":
            optimizer = torch.optim.SGD
        elif self.hparams.optimizer_name == "lars":
            optimizer = partial(LARS, warmup_epochs=self.hparams.lars_warmup_epochs, eta=self.hparams.lars_eta)
        else:
            raise NotImplementedError(f"No such optimizer {self.hparams.optimizer_name}")

        self.optimizer = optimizer(
            param_groups,
            lr=self.hparams.lr,
            momentum=self.hparams.momentum,
            weight_decay=self.hparams.weight_decay,
        )
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            self.hparams.max_epochs,
            eta_min=self.hparams.final_lr_schedule_value,
        )
        # optimization is manual, so nothing is handed back to Lightning here
        # (the automatic form would be: return [encoding_optimizer], [self.lr_scheduler])
        return None

    @classmethod
    def params(cls, **kwargs) -> ModelParams:
        return ModelParams(**kwargs)
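
# A hedged usage sketch (hyperparameter names are taken from the code above;
# ModelParams defaults live in model_params.py, which is not shown here):
#
#   method = SelfSupervisedMethod(ModelParams(encoder_arch="ws_resnet18"))
#   trainer = L.Trainer(max_epochs=method.hparams.max_epochs)
#   trainer.fit(method, train_dataloaders=train_loader)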
@ -0,0 +1,244 @@
"""
|
||||
From https://github.com/joe-siyuan-qiao/pytorch-classification
|
||||
@article{weightstandardization,
|
||||
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille},
|
||||
title = {Weight Standardization},
|
||||
journal = {arXiv preprint arXiv:1903.10520},
|
||||
year = {2019},
|
||||
}
|
||||
"""
|
||||
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class Conv2d(nn.Conv2d):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
||||
super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
|
||||
|
||||
def forward(self, x):
|
||||
# return super(Conv2d, self).forward(x)
|
||||
weight = self.weight
|
||||
weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
|
||||
weight = weight - weight_mean
|
||||
std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
|
||||
weight = weight / std.expand_as(weight)
|
||||
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class Linear(nn.Linear):
    def forward(self, x):
        # weight standardization for fully connected layers: zero mean,
        # approximately unit std per output unit
        weight = self.weight
        weight_mean = weight.mean(dim=1, keepdim=True)
        weight = weight - weight_mean
        std = weight.std(dim=1, keepdim=True) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.linear(x, weight, self.bias)


def BatchNorm2d(num_features):
    # despite the name, this swaps in GroupNorm (32 groups), the normalization
    # used together with weight standardization in the paper cited above
    return nn.GroupNorm(num_channels=num_features, num_groups=32)


__all__ = ["ws_resnet18", "ws_resnet34", "ws_resnet50", "ws_resnet101", "ws_resnet152"]


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4  # output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


# number of input channels expected by the first conv layer
chan_size = 2


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = Conv2d(in_channels=chan_size, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if (stride != 1) or (self.inplanes != planes * block.expansion):
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def ws_resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): accepted for interface compatibility; pretrained weights are not loaded here
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model


def ws_resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): accepted for interface compatibility; pretrained weights are not loaded here
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    return model


def ws_resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): accepted for interface compatibility; pretrained weights are not loaded here
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    return model


def ws_resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): accepted for interface compatibility; pretrained weights are not loaded here
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    return model


def ws_resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): accepted for interface compatibility; pretrained weights are not loaded here
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    return model
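

# A minimal smoke-test sketch (the 2-channel input matches chan_size above;
# the 64x64 spatial size is an arbitrary choice for illustration):
if __name__ == "__main__":
    model = ws_resnet18(num_classes=10)
    x = torch.randn(1, chan_size, 64, 64)
    print(model(x).shape)  # expected: torch.Size([1, 10])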