# PyTorch-Lightning Implementation of Self-Supervised Learning Methods

This is a [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) implementation of the following self-supervised representation learning methods:

- [MoCo](https://arxiv.org/abs/1911.05722)
- [MoCo v2](https://arxiv.org/abs/2003.04297)
- [SimCLR](https://arxiv.org/abs/2002.05709)
- [BYOL](https://arxiv.org/abs/2006.07733)
- [EqCo](https://arxiv.org/abs/2010.01929)
- [VICReg](https://arxiv.org/abs/2105.04906)

Supported datasets: ImageNet, STL-10, and CIFAR-10.

During training, the top-1/top-5 accuracies (out of 1+K examples) are reported where possible. During validation, an `sklearn` linear classifier is trained on half the test set and validated on the other half. The top-1 accuracy is logged as `train_class_acc` / `valid_class_acc`.
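The validation metric amounts to a standard linear probe: fit an `sklearn` classifier on embeddings from one half of the test set and score it on the other half. The sketch below illustrates the idea with random stand-in features and a plain `LogisticRegression`; the repo's actual classifier settings may differ.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Stand-in embeddings; in practice these come from the frozen encoder.
rng = np.random.default_rng(0)
features = rng.normal(size=(1000, 512))
labels = rng.integers(0, 10, size=1000)

# Fit on one half of the test set, validate on the other half,
# mirroring train_class_acc / valid_class_acc.
half = len(features) // 2
clf = LogisticRegression(max_iter=200)
clf.fit(features[:half], labels[:half])
train_class_acc = clf.score(features[:half], labels[:half])
valid_class_acc = clf.score(features[half:], labels[half:])
```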
## Installing

Make sure you're in a fresh `conda` or `venv` environment, then run:

```bash
git clone https://github.com/untitled-ai/self_supervised
cd self_supervised
pip install -r requirements.txt
```
## Replicating our BYOL blog post

We found some surprising results about the role of batch norm in BYOL. See the blog post [Understanding self-supervised and contrastive learning with "Bootstrap Your Own Latent" (BYOL)](https://untitled-ai.github.io/understanding-self-supervised-contrastive-learning.html) for more details about our experiments.

You can replicate the results of our blog post by running `python train_blog.py`. The cosine similarity between z and z' is reported as `step_neg_cos` (for negative examples) and `step_pos_cos` (for positive examples). Classification accuracy is reported as `valid_class_acc`.
## Getting started with MoCo v2

To get started with training a ResNet-18 with MoCo v2 on STL-10 (the default configuration):

```python
import os

import pytorch_lightning as pl

from moco import SelfSupervisedMethod
from model_params import ModelParams

os.environ["DATA_PATH"] = "~/data"

params = ModelParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```

For convenience, you can instead pass these parameters as keyword args, for example with `model = SelfSupervisedMethod(batch_size=128)`.
## VICReg

To train VICReg rather than MoCo v2, use the following parameters:

```python
import os

import pytorch_lightning as pl

from moco import SelfSupervisedMethod
from model_params import VICRegParams

os.environ["DATA_PATH"] = "~/data"

params = VICRegParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```

Note that we have not tuned these parameters for STL-10, and the parameters used for ImageNet are slightly different. See the comment on VICRegParams for details.
## BYOL

To train BYOL rather than MoCo v2, use the following parameters:

```python
import os

import pytorch_lightning as pl

from moco import SelfSupervisedMethod
from model_params import BYOLParams

os.environ["DATA_PATH"] = "~/data"

params = BYOLParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```
## SimCLR

To train SimCLR rather than MoCo v2, use the following parameters:

```python
import os

import pytorch_lightning as pl

from moco import SelfSupervisedMethod
from model_params import SimCLRParams

os.environ["DATA_PATH"] = "~/data"

params = SimCLRParams()
model = SelfSupervisedMethod(params)
trainer = pl.Trainer(gpus=1, max_epochs=320)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
```

**Note for multi-GPU setups**: this currently only uses negatives on the same GPU, and will not sync negatives across multiple GPUs.
# Evaluating a trained model

To train a linear classifier on the result:

```python
import pytorch_lightning as pl

from linear_classifier import LinearClassifierMethod

linear_model = LinearClassifierMethod.from_moco_checkpoint("example.ckpt")
trainer = pl.Trainer(gpus=1, max_epochs=100)

trainer.fit(linear_model)
```
# Results on STL-10 and ImageNet

Training a ResNet-18 for 320 epochs on STL-10 achieved 85% linear classification accuracy on the test set (1 fold of 5000). This used all default parameters.

Training a ResNet-50 for 200 epochs on ImageNet achieves 65.6% linear classification accuracy on the test set. This used 8 GPUs with `ddp` and the following parameters:

```python
hparams = ModelParams(
    encoder_arch="resnet50",
    shuffle_batch_norm=True,
    embedding_dim=2048,
    mlp_hidden_dim=2048,
    dataset_name="imagenet",
    batch_size=32,
    lr=0.03,
    max_epochs=200,
    transform_crop_size=224,
    num_data_workers=32,
    gather_keys_for_queue=True,
)
```

(The `batch_size` here differs from the MoCo documentation because of the way PyTorch-Lightning handles multi-GPU training in `ddp`: each GPU loads its own batch, so the effective batch size is 256.) **Note that for ImageNet we suggest using `val_percent_check=0.1` when calling `pl.Trainer`** to reduce the time spent fitting the sklearn model.
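The effective batch size works out as simple arithmetic over the number of `ddp` processes:

```python
# Under PyTorch-Lightning ddp, each GPU loads its own batch_size,
# so the global batch matches the 256 used in the MoCo setup.
per_gpu_batch_size = 32
num_gpus = 8
effective_batch_size = per_gpu_batch_size * num_gpus
print(effective_batch_size)  # 256
```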
# All training options

All possible `hparams` for SelfSupervisedMethod, along with defaults:

```python
class ModelParams:
    # encoder model selection
    encoder_arch: str = "resnet18"
    shuffle_batch_norm: bool = False
    embedding_dim: int = 512  # must match embedding dim of encoder

    # data-related parameters
    dataset_name: str = "stl10"
    batch_size: int = 256

    # MoCo parameters
    K: int = 65536  # number of examples in queue
    dim: int = 128
    m: float = 0.996
    T: float = 0.2

    # EqCo parameters
    eqco_alpha: int = 65536
    use_eqco_margin: bool = False
    use_negative_examples_from_batch: bool = False

    # optimization parameters
    lr: float = 0.5
    momentum: float = 0.9
    weight_decay: float = 1e-4
    max_epochs: int = 320
    final_lr_schedule_value: float = 0.0

    # transform parameters
    transform_s: float = 0.5
    transform_apply_blur: bool = True

    # Change these to make more like BYOL
    use_momentum_schedule: bool = False
    loss_type: str = "ce"
    use_negative_examples_from_queue: bool = True
    use_both_augmentations_as_queries: bool = False
    optimizer_name: str = "sgd"
    lars_warmup_epochs: int = 1
    lars_eta: float = 1e-3
    exclude_matching_parameters_from_lars: List[str] = []  # set to [".bias", ".bn"] to match paper
    loss_constant_factor: float = 1

    # Change these to make more like VICReg
    use_vicreg_loss: bool = False
    use_lagging_model: bool = True
    use_unit_sphere_projection: bool = True
    invariance_loss_weight: float = 25.0
    variance_loss_weight: float = 25.0
    covariance_loss_weight: float = 1.0
    variance_loss_epsilon: float = 1e-04

    # MLP parameters
    projection_mlp_layers: int = 2
    prediction_mlp_layers: int = 0
    mlp_hidden_dim: int = 512

    mlp_normalization: Optional[str] = None
    prediction_mlp_normalization: Optional[str] = "same"  # if same will use mlp_normalization
    use_mlp_weight_standardization: bool = False

    # data loader parameters
    num_data_workers: int = 4
    drop_last_batch: bool = True
    pin_data_memory: bool = True
    gather_keys_for_queue: bool = False
```
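The `m` parameter above is the momentum for updating the lagging ("key") model when `use_lagging_model` is enabled. A minimal sketch of that EMA step, using toy linear modules as hypothetical stand-ins for the encoders:

```python
import torch

torch.manual_seed(0)
m = 0.996  # momentum for the lagging model, as in ModelParams above
query_encoder = torch.nn.Linear(4, 4)
key_encoder = torch.nn.Linear(4, 4)
old_key_params = [p.detach().clone() for p in key_encoder.parameters()]

# After each training step the key encoder drifts toward the query
# encoder: theta_k <- m * theta_k + (1 - m) * theta_q.
with torch.no_grad():
    for theta_q, theta_k in zip(query_encoder.parameters(), key_encoder.parameters()):
        theta_k.mul_(m).add_((1.0 - m) * theta_q)
```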
A few options require more explanation:

- **encoder_arch** can be any torchvision model, or one of the ResNet models with weight standardization defined in `ws_resnet.py`.

- **dataset_name** can be `imagenet`, `stl10`, or `cifar10`. `os.environ["DATA_PATH"]` will be used as the path to the data. STL-10 and CIFAR-10 will be downloaded if they do not already exist.

- **loss_type** can be `ce` (cross entropy) with one of the `use_negative_examples` options set to `True`, corresponding to MoCo, or `ip` (inner product) with both `use_negative_examples` options set to `False`, corresponding to BYOL. It can also be `bce`, which is similar to `ip` but applies the binary cross entropy loss function to the result, or `vic` for the VICReg loss.

- **optimizer_name** is currently just `sgd` or `lars`.

- **exclude_matching_parameters_from_lars** will remove weight decay and the LARS learning rate from matching parameters. Set to `[".bias", ".bn"]` to match the BYOL paper implementation.

- **mlp_normalization** can be `None` for no normalization, `bn` for batch normalization, `ln` for layer norm, `gn` for group norm, or `br` for [batch renormalization](https://github.com/ludvb/batchrenorm).

- **prediction_mlp_normalization** defaults to `same` to use the same normalization as above, but can be given any of the above values to use a different normalization.

- **shuffle_batch_norm** and **gather_keys_for_queue** are both related to multi-GPU training. **shuffle_batch_norm** will shuffle the *key* images among GPUs, which is needed for training if batch norm is used. **gather_keys_for_queue** will gather key projections (z' in the blog post) from all GPUs to add to the MoCo queue.
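To make the MoCo-style `ce` loss concrete with the parameters above (`K`, `dim`, `T`): each query's logits are its inner product with the positive key plus its inner products with the `K` queued negatives, and cross entropy is taken over the resulting 1+K scores with the positive as class 0. The shapes and random tensors below are illustrative, not the repo's actual code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, dim, K, T = 8, 128, 1024, 0.2
q = F.normalize(torch.randn(batch, dim), dim=1)  # query projections
k = F.normalize(torch.randn(batch, dim), dim=1)  # positive key projections
queue = F.normalize(torch.randn(K, dim), dim=1)  # queued negative keys

# One positive logit per query, followed by K negative logits.
l_pos = torch.einsum("nc,nc->n", q, k).unsqueeze(-1)
l_neg = torch.einsum("nc,kc->nk", q, queue)
logits = torch.cat([l_pos, l_neg], dim=1) / T

# The positive is always class 0, so this is cross entropy over 1+K
# examples; the top-1/top-5 accuracies reported during training are
# computed on these logits.
labels = torch.zeros(batch, dtype=torch.long)
loss = F.cross_entropy(logits, labels)
```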
# Training with custom options

You can train using any settings of the above parameters. This configuration represents the settings from BYOL:

```python
hparams = ModelParams(
    prediction_mlp_layers=2,
    mlp_normalization="bn",
    loss_type="ip",
    use_negative_examples_from_queue=False,
    use_both_augmentations_as_queries=True,
    use_momentum_schedule=True,
    optimizer_name="lars",
    exclude_matching_parameters_from_lars=[".bias", ".bn"],
    loss_constant_factor=2,
)
```
Or here is our recommended way to modify VICReg for CIFAR-10:

```python
from model_params import VICRegParams

hparams = VICRegParams(
    dataset_name="cifar10",
    transform_apply_blur=False,
    mlp_hidden_dim=2048,
    dim=2048,
    batch_size=256,
    lr=0.3,
    final_lr_schedule_value=0,
    weight_decay=1e-4,
    lars_warmup_epochs=10,
    lars_eta=0.02,
)
```