# Standard library
from functools import partial
from typing import List, Optional

# Third-party
import attr
@attr.s(auto_attribs=True)
class ModelParams:
    """Hyperparameter record for a VICReg-style self-supervised model.

    A plain attrs data container: every field below is a constructor keyword
    with the given default, so presets can be built with ``functools.partial``
    (see ``VICRegParams`` in this module) or per-field keyword overrides.
    """

    # --- encoder model selection ---
    # Encoder architecture identifier (e.g. resnet18, resnet34, resnet50 variants).
    encoder_arch: str = "ws_resnet18"
    # Must match the encoder's output width; for ws_resnet this is
    # 512 * the block-expansion factor (expansion 1 here).
    embedding_dim: int = 512 * 1
    # Output size of the projection head.
    dim: int = 64

    # --- optimization parameters ---
    optimizer_name: str = 'lars'
    lr: float = 0.5
    momentum: float = 0.9
    weight_decay: float = 1e-4
    max_epochs: int = 10
    # Terminal value of the LR schedule (0.0 = decay fully to zero).
    final_lr_schedule_value: float = 0.0
    lars_warmup_epochs: int = 1
    # LARS trust coefficient.
    lars_eta: float = 1e-3
    # Parameter-name substrings excluded from LARS adaptation;
    # set to [".bias", ".bn"] to match the paper.
    # NOTE: attrs (auto_attribs) replaces mutable defaults such as this []
    # with a per-instance factory, so the list is not shared across instances.
    exclude_matching_parameters_from_lars: List[str] = []

    # --- loss parameters ---
    loss_constant_factor: float = 1
    invariance_loss_weight: float = 25.0
    variance_loss_weight: float = 25.0
    covariance_loss_weight: float = 1.0
    # Numerical-stability epsilon inside the variance term.
    variance_loss_epsilon: float = 1e-04
    kmeans_weight: float = 1e-03

    # --- MLP parameters ---
    projection_mlp_layers: int = 2
    # 0 layers means the prediction MLP is the identity.
    prediction_mlp_layers: int = 0
    mlp_hidden_dim: int = 512
    # Normalization layer name for the projection MLP (None = no normalization).
    mlp_normalization: Optional[str] = None
    # "same" means: reuse whatever mlp_normalization specifies.
    prediction_mlp_normalization: Optional[str] = "same"
    use_mlp_weight_standardization: bool = False
# VICReg preset: ModelParams pre-configured with the VICReg-specific choices.
# Differences between these parameters and those used in the paper (on ImageNet):
#   max_epochs=1000, lr=1.6, batch_size=2048, weight_decay=1e-6,
#   mlp_hidden_dim=8192, dim=8192
VICRegParams = partial(
    ModelParams,
    exclude_matching_parameters_from_lars=[".bias", ".bn"],
    projection_mlp_layers=3,
    final_lr_schedule_value=0.002,
    mlp_normalization="bn",
    lars_warmup_epochs=2,
    kmeans_weight=1e-03,
)
 |