# %%
import json
import os
import random

import numpy as np
import pandas as pd
import torch
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

# parallel utilities
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler

from loss import batch_all_triplet_loss, batch_hard_triplet_loss

# %%
with open('../esAppMod/tca_entities.json', 'r', encoding='utf-8') as file:
    entities = json.load(file)

# map from entity_id to entity_name, over every known entity
all_entity_id_name = {
    entity['entity_id']: entity['entity_name'] for entity in entities['data'].values()
}

# Named train_json (not `train`) so it is not confused with / clobbered by the
# `train` function defined later in this file.
with open('../esAppMod/train.json', 'r', encoding='utf-8') as file:
    train_json = json.load(file)

# map from entity_id to list of mentions
train_entity_id_mentions = {
    data['entity_id']: data['mentions'] for data in train_json['data'].values()
}
# map from entity_id to entity_name, restricted to entities present in train.json
train_entity_id_name = {
    data['entity_id']: all_entity_id_name[data['entity_id']]
    for data in train_json['data'].values()
}


# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True, rng=None):
    """Split each entity's mentions into shuffled groups and flatten them into rows.

    Args:
        entity_id_mentions: mapping ``entity_id -> list of mention strings``.
            The lists are NOT modified (a shuffled copy is used).
        entity_id_name: mapping ``entity_id -> entity name``.
        group_size: maximum number of mentions per group (excluding the anchor).
        anchor: if True, prepend the entity name to every group as an anchor.
            If False, the entity name is treated as one more (deduplicated)
            mention and shuffled in with the others.
        rng: object with a ``shuffle`` method (e.g. ``random.Random(seed)``);
            defaults to the module-level ``random``.

    Returns:
        ``pandas.DataFrame`` with columns ``entity_id`` and ``mention``, one row
        per mention; rows belonging to the same group are contiguous.
    """
    rng = random if rng is None else rng
    entity_sets = []  # list of ([mentions in group], entity_id)
    for entity_id, mentions in entity_id_mentions.items():
        if anchor:
            pool = list(mentions)
        else:
            # dict.fromkeys dedupes while keeping a deterministic order
            # (set order depends on per-process string hash randomization).
            pool = list(dict.fromkeys([entity_id_name[entity_id]] + list(mentions)))
        rng.shuffle(pool)
        groups = [pool[i:i + group_size] for i in range(0, len(pool), group_size)]
        if anchor:
            groups = [[entity_id_name[entity_id]] + group for group in groups]
        entity_sets.extend((group, entity_id) for group in groups)

    rows = [
        (entity_id, mention)
        for members, entity_id in entity_sets
        for mention in members
    ]
    # Explicit columns so an empty result still has the expected schema.
    return pd.DataFrame(rows, columns=['entity_id', 'mention'])


class CustomDataset(Dataset):
    """Map-style dataset yielding ``(mention, entity_id)`` pairs.

    ``df`` must contain ``mention`` and ``entity_id`` columns (as produced by
    ``generate_train_entity_sets``) and should be preprocessed beforehand.
    """

    def __init__(self, df):
        self.data = df
        # Cache the columns as plain lists: DataFrame.iloc builds a new Series
        # on every item access, which is slow inside a DataLoader loop.
        self._mentions = df['mention'].tolist()
        self._entity_ids = df['entity_id'].tolist()

    def __len__(self):
        return len(self._mentions)

    def __getitem__(self, idx):
        # Return the data and label as a tuple
        return self._mentions[idx], self._entity_ids[idx]


# %%
# Each group holds the entity-name anchor plus (num_sample_per_class - 1) mentions
# when generated with anchor=True, as train() does.
num_sample_per_class = 10
# Number of mention rows per DataLoader batch, per process. NOTE(review): the
# original intent was "batch_size groups" (effective triplet batch =
# batch_size * num_sample_per_class), but train() batches individual rows.
batch_size = 16
margin = 2
epochs = 200

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
MODEL_NAME = 'prajjwal1/bert-small'  # alternatives: 'bert-base-cased', 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# NOTE(review): train() builds its own model/optimizer; these module-level ones are
# unused by the DDP path, and with the spawn start method used by
# torch.multiprocessing.spawn every worker re-imports this module and loads them again.
model = AutoModel.from_pretrained(MODEL_NAME)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# %%
def save_checkpoint(model, optimizer, epoch, path, rank):
    """Save model/optimizer state and epoch to ``path`` from rank 0 only.

    ``model`` is expected to be DDP-wrapped; the underlying ``model.module``
    state_dict is saved rather than the wrapper's. The parent directory is
    created if missing.
    """
    if rank != 0:  # only the master process writes
        return
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save({
        'model_state_dict': model.module.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
    }, path)


# %%
def setup(rank, world_size):
    """Bind this process to GPU ``rank`` and join the NCCL process group.

    MASTER_ADDR / MASTER_PORT default to localhost:12355 but can be overridden
    through the environment.
    """
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    # Without this, NCCL collectives (e.g. barrier) and empty_cache() run
    # against the current device, which defaults to GPU 0 on every rank.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    """Tear down the default process group."""
    dist.destroy_process_group()


def reduce_mean(tensor, nprocs):
    """
    Reduces and averages the tensor across all processes.
    This function reduces a tensor from all processes to all processes.
    The resulting tensor is identical in all processes.

    The input is detached first, so the result is for logging only and the
    caller's autograd graph is never touched by the in-place all_reduce.
    """
    rt = tensor.detach().clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= nprocs
    return rt


def train(rank, world_size, seed=0):
    """Train the siamese encoder with triplet loss on one DDP rank.

    Args:
        rank: process / GPU index, supplied by ``torch.multiprocessing.spawn``.
        world_size: total number of processes.
        seed: base seed for per-epoch group generation; every rank uses the
            same value so all ranks build identical datasets.

    Uses batch-all triplet loss for the first half of training and batch-hard
    afterwards. Rank 0 shows progress and saves the final weights to
    ``./checkpoint/siamese.pt``.
    """
    setup(rank, world_size)
    try:
        device = torch.device(f"cuda:{rank}")
        is_master = rank == 0

        model = AutoModel.from_pretrained(MODEL_NAME).to(device)
        # from_pretrained returns the model in eval mode; enable dropout for training.
        model.train()
        # The pooler output is not used by the loss, so its parameters get no grad.
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

        pbar = tqdm(total=epochs, desc='epoch progress') if is_master else None

        for epoch in range(epochs):
            # DistributedSampler assumes every replica sees the SAME dataset, so
            # the random grouping must be identical on all ranks.
            random.seed(seed + epoch)
            epoch_df = generate_train_entity_sets(
                train_entity_id_mentions, train_entity_id_name,
                num_sample_per_class - 1, anchor=True)
            train_dataset = CustomDataset(epoch_df)
            train_sampler = DistributedSampler(
                train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
            train_sampler.set_epoch(epoch)
            # NOTE(review): the sampler shuffles individual rows, so mentions of one
            # group are scattered across batches/ranks and a batch of `batch_size`
            # rows may contain few positives. A group-aware batch sampler would match
            # the grouping intent — confirm whether that is desired.
            train_loader = DataLoader(
                dataset=train_dataset, batch_size=batch_size, sampler=train_sampler)

            total_loss = 0.0
            num_batches = 0
            for mentions, targets in train_loader:
                optimizer.zero_grad()
                inputs = tokenizer(mentions, padding=True, truncation=True,
                                   return_tensors='pt').to(device)
                outputs = model(**inputs)
                cls = outputs.last_hidden_state[:, 0, :]  # [CLS] token embedding
                labels = targets.to(device)

                if epoch < epochs / 2:
                    # first half of training: all valid triplets ("easy")
                    loss, _ = batch_all_triplet_loss(labels, cls, margin, squared=False)
                else:
                    # second half: hardest positive/negative per anchor
                    loss = batch_hard_triplet_loss(labels, cls, margin, squared=False)

                loss.backward()
                optimizer.step()

                # Collective call: every rank must run the same number of batches,
                # which DistributedSampler ensures by padding (drop_last=False).
                total_loss += reduce_mean(loss.detach(), world_size).item()
                num_batches += 1
                del inputs, outputs, cls, loss

            dist.barrier()
            if is_master:
                pbar.update(1)  # one step per epoch
                epoch_loss = total_loss / max(num_batches, 1)
                tqdm.write(f'loss: {epoch_loss}')

        if is_master:
            pbar.close()
            path = './checkpoint/siamese.pt'
            os.makedirs(os.path.dirname(path), exist_ok=True)
            torch.save(model.module.state_dict(), path)
    finally:
        cleanup()


if __name__ == '__main__':
    # One process per visible GPU.
    world_size = torch.cuda.device_count()
    if world_size == 0:
        raise RuntimeError('No CUDA devices found; NCCL-based DDP training requires at least one GPU.')
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)