242 lines
8.7 KiB
Python
242 lines
8.7 KiB
Python
# %%
|
|
import torch
|
|
import json
|
|
import random
|
|
import numpy as np
|
|
import pandas as pd
|
|
from transformers import AutoTokenizer
|
|
from transformers import AutoModel
|
|
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
|
from sklearn.neighbors import KNeighborsClassifier
|
|
from tqdm import tqdm
|
|
# parallel utilities
|
|
import torch.distributed as dist
|
|
from torch.utils.data import Dataset, DataLoader, DistributedSampler
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
import os
|
|
|
|
# %%
|
|
# JSON is specified as UTF-8; don't depend on the platform's locale encoding.
with open('../esAppMod/tca_entities.json', 'r', encoding='utf-8') as file:
    entities = json.load(file)

# map from entity_id to entity_name, over every entity in the catalogue
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for entity in entities['data'].values()}

# Named train_json (not `train`) so it is not confused with the train()
# function defined later in this file, which rebinds that name.
with open('../esAppMod/train.json', 'r', encoding='utf-8') as file:
    train_json = json.load(file)

# map from entity_id to its list of mention strings
train_entity_id_mentions = {data['entity_id']: data['mentions'] for data in train_json['data'].values()}

# map from entity_id to its single entity_name (looked up in the full catalogue;
# raises KeyError if a training entity is missing from tca_entities.json)
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for data in train_json['data'].values()}
|
|
|
|
|
|
|
|
# %%
|
|
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    """Split each entity's mentions into shuffled groups and flatten them into rows.

    Args:
        entity_id_mentions: Mapping from entity id to its list of mention
            strings. The lists are NOT modified; a copy is shuffled.
        entity_id_name: Mapping from entity id to the entity's name.
        group_size: Maximum number of items drawn from the shuffled pool per
            group; must be >= 1.
        anchor: If True, the entity name is prepended to every group of up to
            ``group_size`` shuffled mentions (so a group has up to
            ``group_size + 1`` items). If False, the entity name is merged
            into the mention pool (deduplicated) and treated as an ordinary
            mention; the shuffled pool is chunked into groups of
            ``group_size``.

    Returns:
        pandas.DataFrame with columns ``entity_id`` and ``mention``: one row
        per item of every group, groups emitted consecutively. When there is
        nothing to emit the frame is empty but still has both columns.

    Raises:
        ValueError: If ``group_size`` is less than 1.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be >= 1, got {group_size}")

    rows = []
    for entity_id, mentions in entity_id_mentions.items():
        name = entity_id_name[entity_id]
        if anchor:
            pool = list(mentions)
        else:
            # dict.fromkeys dedups while keeping first-seen order. A set's
            # iteration order depends on per-process string hashing, which
            # would make results differ between processes even under the
            # same random seed.
            pool = list(dict.fromkeys([name, *mentions]))
        # Shuffle a copy so the caller's mention lists are left untouched.
        random.shuffle(pool)

        for start in range(0, len(pool), group_size):
            group = pool[start:start + group_size]
            if anchor:
                group = [name] + group
            rows.extend({'entity_id': entity_id, 'mention': mention} for mention in group)

    # Explicit columns keep the schema stable even when `rows` is empty.
    return pd.DataFrame(rows, columns=['entity_id', 'mention'])
|
|
|
|
# def batchGenerator(data, batch_size):
|
|
# for i in range(0, len(data), batch_size):
|
|
# batch = data[i:i+batch_size]
|
|
# x, y = [], []
|
|
# for t in batch:
|
|
# x.extend(t[0])
|
|
# y.extend([t[1]]*len(t[0]))
|
|
# yield x, y
|
|
|
|
|
|
class CustomDataset(Dataset):
    """Map-style dataset backed by a DataFrame with 'mention' and 'entity_id' columns.

    Item ``idx`` is the positional row ``idx`` of the frame, returned as the
    pair ``(mention, entity_id)``. The frame is kept as-is on ``self.data``;
    any preprocessing has to be done before it is handed in.
    """

    def __init__(self, df):
        self.data = df

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        mention, label = row['mention'], row['entity_id']
        return mention, label
|
|
|
|
|
|
|
|
|
|
# %%
|
|
|
|
# Group size target: train() asks generate_train_entity_sets for groups of
# num_sample_per_class - 1 mentions, to which the entity-name anchor is added.
num_sample_per_class = 10

# NOTE(review): train()'s DataLoader batches the individual (mention, entity_id)
# rows produced by generate_train_entity_sets, so this is the number of
# mentions per batch per process, NOT the number of groups. The per-group
# "batch_size * num_sample_per_class" effective size does not currently hold.
batch_size = 16

margin = 2  # triplet-loss margin passed to batch_all/batch_hard_triplet_loss

epochs = 200  # train(): first half uses batch_all_triplet_loss, second half batch_hard_triplet_loss

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Alternatives tried: 'distilbert-base-cased', 'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# NOTE(review): train() builds its own model and optimizer per rank; these
# module-level objects are only referenced by the commented-out single-process
# loop below. Presumably torch.multiprocessing.spawn re-imports this module in
# every worker, so they are also loaded once per process — consider removing.
model = AutoModel.from_pretrained(MODEL_NAME)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
|
|
|
|
# model.to(DEVICE)
|
|
# model.train()
|
|
#
|
|
# losses = []
|
|
|
|
# %%
|
|
# for epoch in tqdm(range(epochs)):
|
|
# data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
|
# dataset = CustomDataset(data)
|
|
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
|
# for x, y in dataloader:
|
|
# # print(len(x), len(y), end='-->')
|
|
# optimizer.zero_grad()
|
|
# inputs = tokenizer(x, padding=True, return_tensors='pt')
|
|
# inputs = inputs.to(DEVICE)
|
|
# outputs = model(**inputs)
|
|
# cls = outputs.last_hidden_state[:,0,:]
|
|
# # for training less than half the time, train on easy
|
|
# if epoch < epochs / 2:
|
|
# loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
|
|
# # for training after half the time, train on hard
|
|
# else:
|
|
# loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
|
|
# loss.backward()
|
|
# optimizer.step()
|
|
# # print(epoch, loss)
|
|
# losses.append(loss)
|
|
# del inputs, outputs, cls, loss
|
|
# torch.cuda.empty_cache()
|
|
#
|
|
# torch.save(model.state_dict(), './checkpoint/siamese.pt')
|
|
# %%
|
|
|
|
def save_checkpoint(model, optimizer, epoch, path, rank):
    """Write a training checkpoint to ``path``, but only from rank 0.

    ``model`` is expected to be a wrapper exposing the real network as
    ``model.module`` (as DDP does); the unwrapped network's state_dict is
    stored so the file can be loaded without the wrapper. Other ranks return
    immediately without touching the filesystem.
    """
    if rank != 0:
        return

    checkpoint = {
        'model_state_dict': model.module.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint, path)
|
|
|
|
# %%
|
|
def setup(rank, world_size):
|
|
os.environ['MASTER_ADDR'] = 'localhost'
|
|
os.environ['MASTER_PORT'] = '12355'
|
|
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
|
|
|
def cleanup():
    """Destroy the default process group if one is active.

    Safe to call more than once or when no group was ever created, so it can
    sit in a ``finally`` clause.
    """
    if dist.is_initialized():
        dist.destroy_process_group()
|
|
|
|
def reduce_mean(tensor, nprocs):
    """
    Reduces and averages the tensor across all processes.

    This function reduces a tensor from all processes to all processes.
    The resulting tensor is identical in all processes.

    The input is detached and copied first: the all-reduce is not
    differentiable, so there is no reason to clone or keep the autograd
    graph, and the caller's tensor is never modified. Division is
    out-of-place true division, so integer tensors yield a floating-point
    mean instead of failing on in-place ``/=``.

    Args:
        tensor: Tensor to average; must live on a device supported by the
            active backend.
        nprocs: Number of processes to divide the sum by (normally the
            world size).

    Returns:
        A new tensor holding ``sum(tensor over ranks) / nprocs``.
    """
    rt = tensor.detach().clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt / nprocs
|
|
|
|
def train(rank, world_size, seed=0):
    """Run the DDP triplet-loss training loop in one process.

    Intended to be launched once per GPU by ``torch.multiprocessing.spawn``,
    which passes the process index as ``rank``. Each epoch regroups the
    training mentions, trains with batch-all triplet loss for the first half
    of ``epochs`` and batch-hard triplet loss afterwards. The rank-averaged
    mean loss per epoch is printed from rank 0. At the end rank 0 saves the
    unwrapped model's state_dict to ``./checkpoint/siamese.pt``.

    Args:
        rank: Process index; also used as this process's CUDA device index.
        world_size: Total number of processes.
        seed: Base seed for the per-epoch regrouping. Every rank reseeds
            ``random`` with ``seed + epoch`` so all ranks build the same
            DataFrame. DistributedSampler splits indices on the assumption
            that every replica sees the same dataset.
    """
    setup(rank, world_size)
    try:
        device = torch.device(f"cuda:{rank}")
        # Bind this process to its own GPU before any collective (the DDP
        # constructor broadcasts parameters), rather than relying on the
        # current-device default.
        torch.cuda.set_device(device)

        model = AutoModel.from_pretrained(MODEL_NAME).to(device)
        # Only last_hidden_state is used below, so presumably some parameters
        # (e.g. a pooler head) get no gradient; find_unused_parameters=True
        # lets DDP tolerate that.
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

        is_master = rank == 0
        pbar = tqdm(total=epochs, desc='batch progress') if is_master else None

        for epoch in range(epochs):
            total_loss = 0.0
            num_batches = 0

            # Identical seed on every rank -> identical regrouping -> the
            # sampler's per-rank index split is a true partition of one dataset.
            random.seed(seed + epoch)
            groups_df = generate_train_entity_sets(
                train_entity_id_mentions, train_entity_id_name, num_sample_per_class - 1, anchor=True
            )
            train_dataset = CustomDataset(groups_df)
            train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
            train_sampler.set_epoch(epoch)
            # NOTE(review): the dataset has one row per mention, so each batch
            # holds `batch_size` randomly drawn mentions rather than whole groups;
            # many batches may contain few positive pairs. A group-aware batch
            # sampler would be needed to keep groups together.
            train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=train_sampler)

            for mentions, targets in train_loader:
                optimizer.zero_grad()
                inputs = tokenizer(mentions, padding=True, return_tensors='pt').to(device)
                outputs = model(**inputs)
                cls = outputs.last_hidden_state[:, 0, :]
                labels = targets.to(device)

                if epoch < epochs / 2:
                    # first half of training: easy (batch-all) triplets
                    loss, _ = batch_all_triplet_loss(labels, cls, margin, squared=False)
                else:
                    # second half of training: hard triplets
                    loss = batch_hard_triplet_loss(labels, cls, margin, squared=False)

                loss.backward()
                optimizer.step()

                # Average across ranks for logging only; detach so the
                # collective does not touch the autograd graph.
                reduced_loss = reduce_mean(loss.detach(), world_size)
                total_loss += reduced_loss.item()
                num_batches += 1

                del inputs, outputs, cls, loss
                torch.cuda.empty_cache()

            dist.barrier()
            if pbar is not None:
                # tqdm.update takes an increment: one epoch finished.
                pbar.update(1)
                epoch_loss = total_loss / num_batches if num_batches else float('nan')
                tqdm.write(f'loss: {epoch_loss}')

        if pbar is not None:
            pbar.close()
            path = './checkpoint/siamese.pt'
            os.makedirs(os.path.dirname(path), exist_ok=True)
            # Save the underlying model, not the DDP wrapper.
            torch.save(model.module.state_dict(), path)
    finally:
        cleanup()
|
|
|
|
|
|
if __name__ == '__main__':
    # One training process per visible GPU; spawn passes each its index as `rank`.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        # With nprocs=0 spawn would launch zero workers and nothing would be
        # trained; the NCCL backend and per-rank cuda:{rank} device need a GPU.
        raise SystemExit('No CUDA devices found: this script trains with NCCL/DDP and needs at least one GPU.')
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)