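# semes_gaf/self_supervised/main_train.py
#
# Training entry point: fits the configured self-supervised method (VICReg)
# on the GAF training data and saves the final checkpoint.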

from attr import evolve
# from lightning.pytorch.callbacks import ModelCheckpoint
# from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch import Trainer
from vicreg import SelfSupervisedMethod
from model_params import VICRegParams
# import utils
# from torch.utils.data import DataLoader
from dataload import CustomDataloader
import torch
import numpy as np
from utils import get_random_file
# Speed-up: let PyTorch use reduced internal precision for float32 matmuls.
torch.set_float32_matmul_precision('medium')

# Dataset entries; `selector` chooses which one main() trains on.
data_params = [
    {
        'train_path': "/home/richard/Projects/06_research/semes_gaf/gaf_data/train",
        'test_path': "/home/richard/Projects/06_research/semes_gaf/gaf_data/test",
        'checkpoint': 'checkpoint_semes/last.ckpt',
    },
]

# Training hyper-parameters.
batch_size = 256
num_epochs = 20
selector = 0
def _infer_img_size(sample_file):
    """Return the length of the ``'data'`` entry stored in ``sample_file``.

    ``sample_file`` is indexed by key after ``np.load``, so it is expected to be
    an ``.npz`` archive. ``np.load`` returns an ``NpzFile`` that keeps the file
    handle open; the ``with`` block closes it once the length has been read.
    """
    with np.load(sample_file) as sample:
        return len(sample['data'])


def main():
    """Train every configured self-supervised method on the selected dataset.

    Uses the dataset entry ``data_params[selector]``. Infers the input size
    from one randomly chosen training file. Then, for each repeat and each
    config, fits a fresh ``SelfSupervisedMethod`` with a Lightning ``Trainer``
    on GPU index 1 and saves the final checkpoint to the entry's
    ``'checkpoint'`` path.
    """
    configs = {
        "vicreg": evolve(
            VICRegParams(),
            encoder_arch="ws_resnet18",  # resnet18, resnet34, resnet50
            max_epochs=num_epochs,
        ),
    }

    dataset = data_params[selector]
    train_path = dataset['train_path']
    checkpoint_path = dataset['checkpoint']

    img_size = _infer_img_size(get_random_file(train_path))

    # The loader depends only on the dataset, not on the repeat or config,
    # so it is built once and reused by every run.
    train_loader = CustomDataloader(
        img_dir=train_path,
        img_size=img_size,
        batch_size=batch_size,
    ).get_dataloader()

    for _repeat in range(1):  # number of repeats
        for config in configs.values():
            method = SelfSupervisedMethod(config)
            # To log or checkpoint per run, pass logger=/callbacks= here
            # (e.g. TensorBoardLogger, ModelCheckpoint).
            # For multi-GPU, add strategy=DDPStrategy(find_unused_parameters=False).
            trainer = Trainer(
                accelerator="gpu",
                devices=[1],
                max_epochs=num_epochs,
            )
            print("--------------------------------------")
            print(checkpoint_path)
            trainer.fit(model=method, train_dataloaders=train_loader)
            # NOTE(review): every run writes to the same path, so with more than
            # one repeat or config, later runs overwrite earlier checkpoints.
            trainer.save_checkpoint(checkpoint_path)
# Start training only when run as a script, not when this module is imported.
if __name__ == '__main__':
    main()