# PyTorch-Lightning Implementation of Self-Supervised Learning Methods
This is a [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) implementation of the following self-supervised representation learning methods:
- [MoCo](https://arxiv.org/abs/1911.05722)
- [MoCo v2](https://arxiv.org/abs/2003.04297)
- [SimCLR](https://arxiv.org/abs/2002.05709)
- [BYOL](https://arxiv.org/abs/2006.07733)
- [EqCo](https://arxiv.org/abs/2010.01929)
- [VICReg](https://arxiv.org/abs/2105.04906)

Supported datasets: ImageNet, STL-10, and CIFAR-10.

During training, the top1/top5 accuracies (out of 1+K examples) are reported where possible. During validation, an `sklearn` linear classifier is trained on half the test set and validated on the other half. The top1 accuracy is logged as `train_class_acc` / `valid_class_acc`.
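The validation step is conceptually similar to the following sketch (illustrative only, not the repository's exact code); `embeddings` and `labels` are hypothetical stand-ins for the encoder outputs and class labels on the test set:
```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Placeholder features and labels; in practice these come from the frozen encoder.
embeddings = np.random.randn(1000, 512)
labels = np.random.randint(0, 10, size=1000)

# Fit a linear classifier on one half of the test set and validate on the other half.
train_x, valid_x, train_y, valid_y = train_test_split(embeddings, labels, test_size=0.5)
classifier = LogisticRegression(max_iter=1000).fit(train_x, train_y)
print("train_class_acc:", classifier.score(train_x, train_y))
print("valid_class_acc:", classifier.score(valid_x, valid_y))
```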
## Installing
Make sure you're in a fresh `conda` or `venv` environment, then run:
```bash
git clone https://github.com/untitled-ai/self_supervised
cd self_supervised
pip install -r requirements.txt
```
## Replicating our BYOL blog post
We found some surprising results about the role of batch norm in BYOL. See the blog post [Understanding self-supervised and contrastive learning with "Bootstrap Your Own Latent" (BYOL)](https://untitled-ai.github.io/understanding-self-supervised-contrastive-learning.html) for more details about our experiments.
You can replicate the results of our blog post by running `python train_blog.py`. The cosine similarity between z and z' is reported as `step_neg_cos` (for negative examples) and `step_pos_cos` (for positive examples). Classification accuracy is reported as `valid_class_acc`.
## Getting started with MoCo v2
To get started with training a ResNet-18 with MoCo v2 on STL-10 (the default configuration):
```python
import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import ModelParams
os.environ["DATA_PATH"] = "~/data"
params = ModelParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```
For convenience, you can instead pass these parameters as keyword args, for example with `model = SelfSupervisedMethod(batch_size=128)`.
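For example, a quick run with a smaller batch size and fewer epochs (illustrative values, assuming the imports and `DATA_PATH` from the snippet above) could look like:
```python
# Override a couple of hparams as keyword args instead of building ModelParams directly.
model = SelfSupervisedMethod(batch_size=128, max_epochs=100)
trainer = pl.Trainer(gpus=1, max_epochs=100)
trainer.fit(model)
```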
## VICReg
To train VICReg rather than MoCo v2, use the following parameters:
```python
import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import VICRegParams
os.environ["DATA_PATH"] = "~/data"
params = VICRegParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```
Note that we have not tuned these parameters for STL-10, and the parameters used for ImageNet are slightly different. See the comment on `VICRegParams` for details.
## BYOL
To train BYOL rather than MoCo v2, use the following parameters:
```python
import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import BYOLParams
os.environ["DATA_PATH"] = "~/data"
params = BYOLParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```
## SimCLR
To train SimCLR rather than MoCo v2, use the following parameters:
```python
import os
import pytorch_lightning as pl
from moco import SelfSupervisedMethod
from model_params import SimCLRParams
os.environ["DATA_PATH"] = "~/data"
params = SimCLRParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```
**Note for multi-GPU setups**: SimCLR currently uses only the negatives on the same GPU and does not sync negatives across multiple GPUs.
# Evaluating a trained model
To train a linear classifier on the result:
```python
import pytorch_lightning as pl
from linear_classifier import LinearClassifierMethod
linear_model = LinearClassifierMethod.from_moco_checkpoint("example.ckpt")
trainer = pl.Trainer(gpus=1, max_epochs=100)
trainer.fit(linear_model)
```
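If you want to keep the fitted classifier around, you can checkpoint it just like the pretraining runs (the filename here is only an example):
```python
# Save the trained linear classifier for later evaluation or reuse.
trainer.save_checkpoint("linear_eval.ckpt")
```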
# Results on STL-10 and ImageNet
Training a ResNet-18 for 320 epochs on STL-10 achieved 85% linear classification accuracy on the test set (1 fold of 5000), using all default parameters.

Training a ResNet-50 for 200 epochs on ImageNet achieved 65.6% linear classification accuracy on the test set. This run used 8 GPUs with `ddp` and the following parameters:
```python
hparams = ModelParams(
    encoder_arch="resnet50",
    shuffle_batch_norm=True,
    embedding_dim=2048,
    mlp_hidden_dim=2048,
    dataset_name="imagenet",
    batch_size=32,
    lr=0.03,
    max_epochs=200,
    transform_crop_size=224,
    num_data_workers=32,
    gather_keys_for_queue=True,
)
```
(The `batch_size` here differs from the MoCo documentation because of the way PyTorch-Lightning handles multi-GPU
training in `ddp`: each GPU receives `batch_size=32`, giving an effective batch size of 256.) **Note that for ImageNet we suggest using
`val_percent_check=0.1` when calling `pl.Trainer`** to reduce the time spent fitting the sklearn model, as in the sketch below.
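The corresponding Trainer call might look like the sketch below, which assumes a PyTorch-Lightning version old enough to accept the `gpus`, `distributed_backend`, and `val_percent_check` arguments (newer releases renamed the latter two):
```python
# Sketch of the multi-GPU Trainer call for the ImageNet run above.
model = SelfSupervisedMethod(hparams)
trainer = pl.Trainer(
    gpus=8,
    distributed_backend="ddp",
    max_epochs=200,
    val_percent_check=0.1,  # fit the sklearn classifier on only 10% of the validation set
)
trainer.fit(model)
```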
# All training options
All possible `hparams` for `SelfSupervisedMethod`, along with defaults:
```python
class ModelParams:
    # encoder model selection
    encoder_arch: str = "resnet18"
    shuffle_batch_norm: bool = False
    embedding_dim: int = 512  # must match embedding dim of encoder

    # data-related parameters
    dataset_name: str = "stl10"
    batch_size: int = 256

    # MoCo parameters
    K: int = 65536  # number of examples in queue
    dim: int = 128
    m: float = 0.996
    T: float = 0.2

    # eqco parameters
    eqco_alpha: int = 65536
    use_eqco_margin: bool = False
    use_negative_examples_from_batch: bool = False

    # optimization parameters
    lr: float = 0.5
    momentum: float = 0.9
    weight_decay: float = 1e-4
    max_epochs: int = 320
    final_lr_schedule_value: float = 0.0

    # transform parameters
    transform_s: float = 0.5
    transform_apply_blur: bool = True

    # Change these to make more like BYOL
    use_momentum_schedule: bool = False
    loss_type: str = "ce"
    use_negative_examples_from_queue: bool = True
    use_both_augmentations_as_queries: bool = False
    optimizer_name: str = "sgd"
    lars_warmup_epochs: int = 1
    lars_eta: float = 1e-3
    exclude_matching_parameters_from_lars: List[str] = []  # set to [".bias", ".bn"] to match paper
    loss_constant_factor: float = 1

    # Change these to make more like VICReg
    use_vicreg_loss: bool = False
    use_lagging_model: bool = True
    use_unit_sphere_projection: bool = True
    invariance_loss_weight: float = 25.0
    variance_loss_weight: float = 25.0
    covariance_loss_weight: float = 1.0
    variance_loss_epsilon: float = 1e-04

    # MLP parameters
    projection_mlp_layers: int = 2
    prediction_mlp_layers: int = 0
    mlp_hidden_dim: int = 512
    mlp_normalization: Optional[str] = None
    prediction_mlp_normalization: Optional[str] = "same"  # if same will use mlp_normalization
    use_mlp_weight_standardization: bool = False

    # data loader parameters
    num_data_workers: int = 4
    drop_last_batch: bool = True
    pin_data_memory: bool = True
    gather_keys_for_queue: bool = False
```
A few options require more explanation:
- **encoder_arch** can be any torchvision model, or can be one of the ResNet models with weight standardization defined in
`ws_resnet.py`.
- **dataset_name** can be `imagenet`, `stl10`, or `cifar10`. `os.environ["DATA_PATH"]` will be used as the path to the data. STL-10 and CIFAR-10 will
be downloaded if they do not already exist.
- **loss_type** can be `ce` (cross entropy) with one of the `use_negative_examples_*` options set to `True` to correspond to MoCo,
or `ip` (inner product) with both `use_negative_examples_*` options set to `False` to correspond to BYOL. It can also be `bce`,
which is similar to `ip` but applies the binary cross-entropy loss function to the result, or `vic` for the VICReg loss.
- **optimizer_name** can currently be either `sgd` or `lars`.
- **exclude_matching_parameters_from_lars** will remove weight decay and the LARS learning rate from matching parameters. Set
to `[".bias", ".bn"]` to match the BYOL paper's implementation.
- **mlp_normalization** can be None for no normalization, `bn` for batch normalization, `ln` for layer norm, `gn` for group
norm, or `br` for [batch renormalization](https://github.com/ludvb/batchrenorm).
- **prediction_mlp_normalization** defaults to `same`, which uses the same normalization as above, but it can be set to any of the
options above to use a different normalization.
- **shuffle_batch_norm** and **gather_keys_for_queue** are both related to multi-GPU training. **shuffle_batch_norm**
will shuffle the *key* images among GPUs, which is needed for training if batch norm is used. **gather_keys_for_queue**
will gather key projections (z' in the blog post) from all GPUs to add to the MoCo queue.
# Training with custom options
You can train using any combination of the above parameters. This configuration reproduces the settings from BYOL:
```python
hparams = ModelParams(
    prediction_mlp_layers=2,
    mlp_normalization="bn",
    loss_type="ip",
    use_negative_examples_from_queue=False,
    use_both_augmentations_as_queries=True,
    use_momentum_schedule=True,
    optimizer_name="lars",
    exclude_matching_parameters_from_lars=[".bias", ".bn"],
    loss_constant_factor=2,
)
```
Or here is our recommended way to modify VICReg for CIFAR-10:
```python
from model_params import VICRegParams
hparams = VICRegParams(
    dataset_name="cifar10",
    transform_apply_blur=False,
    mlp_hidden_dim=2048,
    dim=2048,
    batch_size=256,
    lr=0.3,
    final_lr_schedule_value=0,
    weight_decay=1e-4,
    lars_warmup_epochs=10,
    lars_eta=0.02,
)
```
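In either case, training then follows the same pattern as the earlier examples (assuming the same imports and `DATA_PATH` setup):
```python
# Train with the custom hparams defined above, mirroring the earlier examples.
model = SelfSupervisedMethod(hparams)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```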