# domain_mapping/vicreg/train.py
#
# 388 lines
# 11 KiB
# Python

# %%
# %%
from torch.utils.data import Dataset, DataLoader
# from datasets import load_from_disk
import os
import json
os.environ['NCCL_P2P_DISABLE'] = '1'
os.environ['NCCL_IB_DISABLE'] = '1'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
from dataclasses import dataclass
import re
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
AutoTokenizer,
AutoModel,
DataCollatorWithPadding,
Trainer,
EarlyStoppingCallback,
TrainingArguments,
TrainerCallback
)
import evaluate
import numpy as np
import pandas as pd
from functools import partial
import warnings
from tqdm import tqdm
from dataload import DynamicDataset, custom_collate_fn
from sklearn.neighbors import KNeighborsClassifier
torch.set_float32_matmul_precision('high')
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Covers Python's ``random`` module, NumPy, and PyTorch (CPU, the current
    CUDA device and all CUDA devices), then configures cuDNN to favour
    reproducible results over auto-tuned kernels.

    Args:
        seed: Integer seed handed to each generator.
    """
    seeders = (
        random.seed,                 # Python random module
        np.random.seed,              # NumPy random
        torch.manual_seed,           # PyTorch CPU
        torch.cuda.manual_seed,      # PyTorch GPU (current device)
        torch.cuda.manual_seed_all,  # PyTorch GPU (every device)
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic kernels on, benchmark-driven kernel selection off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(42)

# Passed to DynamicDataset as sample_size_per_class.
SAMPLES = 40
# Passed to DataLoader as batch_size.
BATCH_SIZE = 256
# %%
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)

# Build a dense 0..K-1 index over the distinct entity ids (sorted so the
# mapping is stable between runs) and attach it to every row.
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(set(entity_ids))
id2label = dict(enumerate(target_id_list))
label2id = {entity_id: index for index, entity_id in id2label.items()}
df["training_id"] = df["entity_id"].map(label2id)
# %%
# make our dataset and dataloader
# Other checkpoints tried here: 'distilbert/distilbert-base-uncased',
# 'bert-base-cased', 'distilbert-base-cased'.
MODEL_NAME = 'prajjwal1/bert-small'
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    clean_up_tokenization_spaces=False,
)

dataset = DynamicDataset(df, sample_size_per_class=SAMPLES)

# Bind the tokenizer up front so DataLoader can invoke the collate function
# with the list of samples as its only argument.
custom_collate_fn_with_tokenizer = partial(custom_collate_fn, tokenizer=tokenizer)

loader_options = {
    'batch_size': BATCH_SIZE,
    'shuffle': True,
    'collate_fn': custom_collate_fn_with_tokenizer,
}
dataloader = DataLoader(dataset, **loader_options)
# %%
# enable BERT with projection layer
class VICRegProjection(nn.Module):
    """Two-layer MLP head: Linear -> ReLU -> Linear.

    Args:
        input_dim: Size of the incoming feature vectors.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the projected embeddings.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Kept as a Sequential named ``projection`` so parameter names stay
        # ``projection.0.*`` / ``projection.2.*`` in the state dict.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.projection = nn.Sequential(*layers)

    def forward(self, x):
        """Return the projection of ``x`` through the MLP head."""
        projected = self.projection(x)
        return projected
class BertWithVICReg(nn.Module):
    """Encoder followed by a ``VICRegProjection`` head.

    Args:
        bert_model: Encoder exposing ``config.hidden_size`` and returning an
            object with ``last_hidden_state`` when called.
        projection_dim: Output size of the projection head.
    """

    def __init__(self, bert_model, projection_dim=256):
        super().__init__()
        self.bert = bert_model
        encoder_width = bert_model.config.hidden_size
        # The head's hidden layer has the same width as the encoder output.
        self.projection = VICRegProjection(
            input_dim=encoder_width,
            hidden_dim=encoder_width,
            output_dim=projection_dim,
        )

    def forward(self, input_ids, attention_mask=None):
        """Encode the inputs and project the first-position hidden state."""
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        # Index 0 along the second axis of last_hidden_state is the first token.
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return self.projection(first_token_state)
#######################################################
# %%
@dataclass
class Hyperparameters:
    """Weights and constants read by ``get_vicreg_loss``."""

    # Not read by get_vicreg_loss in this file.
    loss_constant_factor: float = 1
    # Multiplier applied to the invariance (MSE between the two views) term.
    invariance_loss_weight: float = 25.0
    # Multiplier applied to the variance (hinge on per-dimension std) term.
    variance_loss_weight: float = 25.0
    # Multiplier applied to the covariance (off-diagonal) term.
    covariance_loss_weight: float = 1.0
    # Added to the variance before the square root when computing the std.
    variance_loss_epsilon: float = 1e-5
# compute vicreg loss
def get_vicreg_loss(z_a, z_b, hparams):
    """Return the weighted sum of the three VICReg loss terms.

    Args:
        z_a: 2-D tensor of embeddings for the first view.
        z_b: 2-D tensor of embeddings for the second view, same shape as
            ``z_a``.
        hparams: Object providing ``invariance_loss_weight``,
            ``variance_loss_weight``, ``covariance_loss_weight`` and
            ``variance_loss_epsilon``.

    Returns:
        Scalar tensor: ``inv * w_inv + var * w_var + cov * w_cov``.

    Raises:
        AssertionError: If the shapes differ or the inputs are not 2-D.
    """
    assert z_a.shape == z_b.shape and len(z_a.shape) == 2

    num_samples, num_dims = z_a.shape

    def std_hinge(z):
        # Per-dimension std (epsilon added under the root), penalised only
        # while it is below 1; relu acts as a differentiable max(0, .).
        std = torch.sqrt(z.var(dim=0) + hparams.variance_loss_epsilon)
        return torch.mean(F.relu(1 - std))

    def off_diagonal_penalty(z):
        # Squared entries of the DxD covariance matrix, diagonal excluded,
        # normalised by the embedding dimension.
        centered = z - z.mean(dim=0)
        cov_squared = ((centered.T @ centered) / (num_samples - 1)).square()
        return (cov_squared.sum() - cov_squared.diagonal().sum()) / num_dims

    invariance = F.mse_loss(z_a, z_b)
    variance = std_hinge(z_a) + std_hinge(z_b)
    covariance = off_diagonal_penalty(z_a) + off_diagonal_penalty(z_b)

    return (
        invariance * hparams.invariance_loss_weight
        + variance * hparams.variance_loss_weight
        + covariance * hparams.covariance_loss_weight
    )
# %%
#
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
# Use CUDA when it is available, otherwise run on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained encoder wrapped with a 128-dimensional projection head; the
# wrapper is moved to the device as a single module.
bert_model = AutoModel.from_pretrained(MODEL_NAME)
model = BertWithVICReg(bert_model, projection_dim=128)
model.to(DEVICE)

# One AdamW optimizer over every parameter of the wrapper (encoder + head).
params = model.parameters()
optimizer = torch.optim.AdamW(params, lr=5e-6)

hparams = Hyperparameters()
# Loss history filled in by the training loop.
losses = []
# # %%
# batch = next(iter(dataloader))
# input_ids_0 = batch['input_ids_0'].to(DEVICE)
# attention_mask_0 = batch['attention_mask_0'].to(DEVICE)
#
# # %%
# # outputs from reprojection layer
# bert_output = model(
# input_ids=input_ids_0,
# attention_mask=attention_mask_0
# )
# %%
# parameters
# Number of passes over the (regenerated) dataset.
epochs = 80
# NOTE(review): model.train() is never called in this file; if the encoder
# comes back from from_pretrained in eval mode, dropout stays off during
# training — confirm this is intended.
# NOTE(review): the source's indentation was lost; the per-step bookkeeping is
# placed inside the batch loop and the loss print once per epoch — confirm.
for epoch in tqdm(range(epochs)):
    # Project-defined hook on DynamicDataset, called before each epoch
    # (presumably resamples the training pairs — see dataload.py).
    dataset.regenerate_data()
    for batch in dataloader:
        optimizer.zero_grad()

        # compute cls 0
        input_ids_0 = batch['input_ids_0'].to(DEVICE)
        attention_mask_0 = batch['attention_mask_0'].to(DEVICE)
        # outputs from reprojection layer
        outputs_0 = model(
            input_ids=input_ids_0,
            attention_mask=attention_mask_0
        )

        # compute cls 1
        input_ids_1 = batch['input_ids_1'].to(DEVICE)
        attention_mask_1 = batch['attention_mask_1'].to(DEVICE)
        # outputs from reprojection layer
        outputs_1 = model(
            input_ids=input_ids_1,
            attention_mask=attention_mask_1
        )

        loss = get_vicreg_loss(outputs_0, outputs_1, hparams=hparams)
        loss.backward()
        optimizer.step()

        # Store a plain float, not the tensor: keeping the tensor would pin
        # device memory and its autograd history for every step of the run.
        losses.append(loss.item())
        torch.cuda.empty_cache()
    # Loss of the last batch of this epoch.
    print(loss.detach().item())
# %%
# Make sure the target directory exists first: torch.save does not create
# parent directories, and failing here would discard the whole training run.
os.makedirs('./checkpoint', exist_ok=True)
torch.save(model.state_dict(), './checkpoint/simple.pt')
####################################################
# %%
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# map_location lets a checkpoint written from GPU tensors load on a CPU-only host.
state_dict = torch.load('./checkpoint/simple.pt', map_location=DEVICE)
# Size the projection head from the checkpoint itself instead of hard-coding
# it: the model is trained with projection_dim=128, so rebuilding it with 256
# makes load_state_dict fail with a size mismatch. The final Linear of the head
# is stored as 'projection.projection.2.weight' with shape (output_dim, hidden_dim).
projection_dim = state_dict['projection.projection.2.weight'].shape[0]
model = BertWithVICReg(bert_model, projection_dim=projection_dim)
model.load_state_dict(state_dict)
# %%
# Step 2: Pick the device and rebuild the tokenizer for inference
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# state_dict = torch.load('./checkpoint/siamese.pt')
# model = model.bert
# %%
# Step 3: Move the model to the device and switch it to evaluation mode
# (both calls return the module itself, so they can be chained).
model = model.to(DEVICE).eval()
# %%
# entity_id -> entity_name for every record in the entity catalogue.
with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {
    entity['entity_id']: entity['entity_name']
    for entity in entities['data'].values()
}

# For the entities present in the training file: their mentions, and their
# names looked up in the catalogue.
with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {
    data['entity_id']: data['mentions']
    for data in train['data'].values()
}
train_entity_id_name = {
    data['entity_id']: all_entity_id_name[data['entity_id']]
    for data in train['data'].values()
}
# %%
def preprocess_text(text):
    """Normalise ``text``: lowercase it and standardise its spacing.

    Every run of whitespace becomes a single space, and leading/trailing
    whitespace is removed.
    """
    lowered = text.lower()
    single_spaced = re.sub(r'\s+', ' ', lowered)
    return single_spaced.strip()
with open('../esAppMod/infer.json', 'r') as file:
    test = json.load(file)

# Normalised query mentions and their gold entity ids, in file order.
x_test = [preprocess_text(d['mention']) for d in test['data'].values()]
y_test = [d['entity_id'] for d in test['data'].values()]

# Reference set: one normalised entity name per training entity; ``labels``
# holds the matching entity ids in the same order.
labels = list(train_entity_id_name.keys())
train_entities = [
    preprocess_text(element) for element in train_entity_id_name.values()
]
def batch_list(data, batch_size):
    """Yield consecutive slices of ``data``, each at most ``batch_size`` long.

    The final slice is shorter when ``len(data)`` is not a multiple of
    ``batch_size``.
    """
    starts = range(0, len(data), batch_size)
    yield from (data[start:start + batch_size] for start in starts)
# Embed the reference entity names in batches of 64.
batches = batch_list(train_entities, 64)
embedding_list = []
# Inference only: without no_grad every forward pass records an autograd graph,
# which wastes memory for embeddings that are never back-propagated through.
with torch.no_grad():
    for batch in batches:
        inputs = tokenizer(batch, padding=True, return_tensors='pt')
        outputs = model(
            input_ids=inputs['input_ids'].to(DEVICE),
            attention_mask=inputs['attention_mask'].to(DEVICE)
        )
        # output = outputs.last_hidden_state[:,0,:]
        outputs = outputs.detach().cpu().numpy()
        embedding_list.append(outputs)
# One row per entry of train_entities, in the same order as ``labels``.
cls = np.concatenate(embedding_list)
# %%
torch.cuda.empty_cache()
# %%
# Embed the test mentions in batches of 64.
batches = batch_list(x_test, 64)
embedding_list = []
# Inference only: without no_grad every forward pass records an autograd graph,
# which wastes memory for embeddings that are never back-propagated through.
with torch.no_grad():
    for batch in batches:
        inputs = tokenizer(batch, padding=True, return_tensors='pt')
        outputs = model(
            input_ids=inputs['input_ids'].to(DEVICE),
            attention_mask=inputs['attention_mask'].to(DEVICE)
        )
        # output = outputs.last_hidden_state[:,0,:]
        outputs = outputs.detach().cpu().numpy()
        embedding_list.append(outputs)
# One row per entry of x_test, in the same order as ``y_test``.
cls_test = np.concatenate(embedding_list)
# %%
# Index the reference embeddings under cosine distance, then report top-k
# accuracy: a query counts as a hit when its gold entity id is among the ids
# of its k nearest reference embeddings.
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
n_neighbors = [1, 3, 5, 10]
for n in n_neighbors:
    distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
    num = sum(
        1
        for gold_id, neighbour_rows in zip(y_test, indices)
        if gold_id in [labels[i] for i in neighbour_rows]
    )
    print(f'Top-{n:<3} accuracy: {num / len(y_test)}')
    # Smallest and largest neighbour distance observed for this k.
    print(np.min(distances), np.max(distances))
# with open("results/output.txt", "w") as f:
# for n in n_neighbors:
# distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
# num = 0
# for a,b in zip(y_test, indices):
# b = [labels[i] for i in b]
# if a in b:
# num += 1
# print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
# print(np.min(distances), np.max(distances), file=f)
# %%
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
# %%
# Project the reference embeddings to 2-D with t-SNE and draw them as a
# scatter plot coloured by entity id.
embeddings = cls
tsne = TSNE(n_components=2, random_state=42)
embeddings_reduced = tsne.fit_transform(embeddings)

plt.figure(figsize=(10, 8))
component_1 = embeddings_reduced[:, 0]
component_2 = embeddings_reduced[:, 1]
scatter = plt.scatter(
    component_1,
    component_2,
    c=labels,
    cmap='viridis',
    alpha=0.6,
)
plt.colorbar(scatter)
plt.title('Visualization of Embeddings')
plt.xlabel('Component 1')
plt.ylabel('Component 2')
plt.show()
# %%