"""Self-supervised VICReg pre-training script.

Builds a ``SelfSupervisedMethod`` from a ``VICRegParams`` configuration,
trains it with a Lightning ``Trainer`` on the dataset selected by
``selector`` and saves the final state with ``Trainer.save_checkpoint``.
"""
from pathlib import Path

import numpy as np
import torch
from attr import evolve
from lightning.pytorch import Trainer, seed_everything

from dataload import CustomDataloader
from model_params import VICRegParams
from utils import get_random_file
from vicreg import SelfSupervisedMethod

# Speed-up: allow lower-precision float32 matmuls
# (see torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision('medium')

# Dataset locations and checkpoint output path; ``selector`` picks one entry.
data_params = [
    {
        'train_path': "/home/richard/Projects/06_research/semes_gaf/gaf_data/train",
        'test_path': "/home/richard/Projects/06_research/semes_gaf/gaf_data/test",
        'checkpoint': 'checkpoint_semes/last.ckpt',
    },
]

batch_size = 256
num_epochs = 20
selector = 0
num_repeats = 1  # training runs per config; run i is seeded with i
gpu_devices = [1]  # device indices passed to the Lightning Trainer


def _infer_img_size(train_path):
    """Return the length of the ``'data'`` entry of one random training file.

    NOTE(review): assumes training files are ``.npz`` archives with a
    ``'data'`` key (as the original indexing implies) — confirm against
    ``CustomDataloader``.
    """
    random_file = get_random_file(train_path)
    # np.load on an .npz returns an NpzFile that holds the file open until
    # closed; the context manager releases it as soon as we have the length.
    with np.load(random_file) as archive:
        return len(archive['data'])


def _checkpoint_path(base, name, seed, unique):
    """Return the checkpoint path for the run identified by ``name``/``seed``.

    If ``unique`` is False, ``base`` is returned unchanged. Otherwise
    ``-{name}-seed{seed}`` is inserted before the file suffix so that
    several runs do not overwrite each other's checkpoint.
    """
    if not unique:
        return base
    path = Path(base)
    return str(path.with_name(f"{path.stem}-{name}-seed{seed}{path.suffix}"))


def main():
    """Train each config ``num_repeats`` times and save a checkpoint per run."""
    configs = {
        "vicreg": evolve(
            VICRegParams(),
            encoder_arch="ws_resnet18",  # resnet18, resnet34, resnet50
            max_epochs=num_epochs,
        ),
    }

    params = data_params[selector]
    img_size = _infer_img_size(params['train_path'])
    # Only disambiguate checkpoint names when more than one run would save.
    multiple_runs = num_repeats * len(configs) > 1

    for seed in range(num_repeats):
        for name, config in configs.items():
            # Seed the RNGs before building the model so each repeat is
            # reproducible and repeats differ only by their seed.
            seed_everything(seed)
            method = SelfSupervisedMethod(config)
            trainer = Trainer(
                accelerator="gpu",
                devices=gpu_devices,
                max_epochs=num_epochs,
                # Multi-GPU: strategy=DDPStrategy(find_unused_parameters=False)
            )
            checkpoint = _checkpoint_path(
                params['checkpoint'], name, seed, multiple_runs
            )

            print("--------------------------------------")
            print(checkpoint)

            train_loader = CustomDataloader(
                img_dir=params['train_path'],
                img_size=img_size,
                batch_size=batch_size,
            ).get_dataloader()

            trainer.fit(model=method, train_dataloaders=train_loader)
            trainer.save_checkpoint(checkpoint)


if __name__ == "__main__":
    main()