81 lines
2.9 KiB
Python
81 lines
2.9 KiB
Python
from attr import evolve
|
|
# from lightning.pytorch.callbacks import ModelCheckpoint
|
|
# from lightning.pytorch.loggers import TensorBoardLogger
|
|
from lightning.pytorch import Trainer
|
|
|
|
from vicreg import SelfSupervisedMethod
|
|
from model_params import VICRegParams
|
|
# import utils
|
|
|
|
# from torch.utils.data import DataLoader
|
|
from dataload import CustomDataloader
|
|
import torch
|
|
import numpy as np
|
|
from utils import get_random_file
|
|
|
|
# Speedup: per the PyTorch docs, 'medium' lets float32 matrix multiplications
# use lower-precision internal formats, trading some precision for throughput.
torch.set_float32_matmul_precision('medium')

# Data parameters: one dict per dataset; `selector` below picks the active one.
data_params = [
    {
        'train_path': "/home/richard/Projects/06_research/semes_gaf/gaf_data/train",
        'test_path': "/home/richard/Projects/06_research/semes_gaf/gaf_data/test",
        'checkpoint': 'checkpoint_semes/last.ckpt',
    },
]

batch_size = 256
num_epochs = 20
selector = 0  # index into data_params

def _probe_image_size(train_path):
    """Return ``len()`` of the ``'data'`` array in one randomly chosen training file."""
    sample_file = get_random_file(train_path)
    # np.load on an .npz archive keeps the underlying file open until it is
    # closed; the context manager releases the handle once the length is read.
    # NOTE(review): assumes the training files are .npz archives holding a
    # 'data' entry -- confirm against the data-generation code.
    with np.load(sample_file) as archive:
        return len(archive['data'])


def main():
    """Train the configured self-supervised model(s) and save a checkpoint.

    For every entry in the local ``configs`` mapping this builds a
    ``SelfSupervisedMethod``, fits it on the training data selected by
    ``data_params[selector]`` for ``num_epochs`` epochs, and writes the
    trainer state to the ``'checkpoint'`` path of the same entry.

    The image size handed to the data loader is probed once from a random
    file in the training directory.
    """
    configs = {
        "vicreg": evolve(
            VICRegParams(),
            encoder_arch="ws_resnet18",  # alternatives: resnet18, resnet34, resnet50
            max_epochs=num_epochs,
        ),
    }

    run_params = data_params[selector]
    train_path = run_params['train_path']
    checkpoint_path = run_params['checkpoint']
    img_size = _probe_image_size(train_path)

    # NOTE(review): the repeat index is not used to seed any RNG, and every
    # repeat/config writes to the same checkpoint path, so raising the repeat
    # count or adding configs would overwrite earlier checkpoints.
    for _ in range(1):  # number of repeats
        for config in configs.values():
            method = SelfSupervisedMethod(config)

            # A fresh Trainer per run. Per the Lightning docs a list passed to
            # `devices` is a list of device indices, so [1] means "GPU index 1",
            # not "one GPU". A logger or ModelCheckpoint callback can be passed
            # here if periodic checkpoints are wanted.
            trainer = Trainer(
                accelerator="gpu",
                devices=[1],
                max_epochs=num_epochs,
            )

            print("--------------------------------------")
            print(checkpoint_path)

            train_loader = CustomDataloader(
                img_dir=train_path,
                img_size=img_size,
                batch_size=batch_size,
            ).get_dataloader()

            trainer.fit(model=method, train_dataloaders=train_loader)
            trainer.save_checkpoint(checkpoint_path)


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()