import os
import logging
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, AutoModel
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm  # tqdm's notebook widgets need ipywidgets==7.7.1; the newest version doesn't work

from data import generate_train_entity_sets
from loss import batch_all_triplet_loss, batch_hard_triplet_loss


def setup(rank, world_size):
    """Initialize the NCCL process group and pin this process to its GPU."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def train_dataloader(vocab_entity_id_mentions, num_sample_per_class, rank, world_size,
                     batch_size=32, pin_memory=True, num_workers=8):
    """Build the distributed training loader over per-entity mention groups."""
    dataset = generate_train_entity_sets(vocab_entity_id_mentions, entity_id_name=None,
                                         group_size=num_sample_per_class, anchor=False)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=True, drop_last=False)
    return DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory,
                      num_workers=num_workers, drop_last=False, shuffle=False, sampler=sampler)


def test_dataloader(test_mentions, batch_size=32):
    return DataLoader(test_mentions, batch_size=batch_size, shuffle=False)


def train(rank, epoch, epochs, train_dataloader, model, optimizer, tokenizer, margin):
    DEVICE = torch.device(f'cuda:{rank}')
    model.train()
    epoch_loss, epoch_len = [epoch], [epoch]
    for groups in tqdm(train_dataloader, desc=f'Training batches on {DEVICE}'):
        # Transpose the default-collated batch back into per-entity mention groups.
        groups[0][:] = zip(*groups[0][::-1])
        x, y = [], []
        for mention, label in zip(groups[0], groups[1]):
            mention = [m for m in mention if m != 'PAD']  # drop placeholders used to pad groups
            x.extend(mention)
            y.extend([label.item()] * len(mention))
        optimizer.zero_grad()
        inputs = tokenizer(x, padding=True, return_tensors='pt')
        inputs = inputs.to(DEVICE)
        cls = model(inputs)
        # cls = torch.nn.functional.normalize(cls)  # normalizing embeddings before the loss did not help
        # cls = torch.nn.Dropout(p=0.25)(cls)       # extra dropout did not help
        # Batch-all mining for the first half of training, batch-hard afterwards.
        # Squared distances (squared=True), other alternation schedules, and circle
        # loss were also tried and showed no clear advantage.
        if epoch < epochs / 2:
            loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
        else:
            loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
        epoch_len.append(len(x))
    logging.info(f'{DEVICE}{epoch_len}')
    logging.info(f'{DEVICE}{epoch_loss}')
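
# For reference, a minimal sketch of the batch-hard mining idea (Hermans et al.,
# "In Defense of the Triplet Loss") that the imported batch_hard_triplet_loss is
# assumed to follow. It is illustrative only, not the implementation in loss.py,
# which may differ (e.g. squared distances, masking details). It is never called.
def _batch_hard_triplet_loss_sketch(labels, embeddings, margin):
    dists = torch.cdist(embeddings, embeddings, p=2)    # pairwise Euclidean distances
    same = labels.unsqueeze(0).eq(labels.unsqueeze(1))  # [i, j] = (label_i == label_j)
    eye = torch.eye(len(labels), dtype=torch.bool, device=dists.device)
    pos_mask, neg_mask = same & ~eye, ~same
    hardest_pos = (dists * pos_mask.float()).max(dim=1).values                  # furthest same-entity mention
    hardest_neg = dists.masked_fill(~neg_mask, float('inf')).min(dim=1).values  # closest other-entity mention
    return torch.relu(hardest_pos - hardest_neg + margin).mean()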
def check_label(predicted_cui: str, golden_cui: str) -> int:
    """
    Some composite annotations do not preserve CUI order, so return 1 if any CUI in
    the predicted composite (or single) CUI matches any CUI in the gold one, else 0.
    """
    predicted = set(predicted_cui.replace('+', '|').split('|'))
    golden = set(golden_cui.replace('+', '|').split('|'))
    return int(len(predicted & golden) > 0)


def getEmbeddings(mentions, model, tokenizer, DEVICE, batch_size=200):
    """Encode a list of mention strings into [CLS] embeddings."""
    model.to(DEVICE)
    model.eval()
    dataloader = DataLoader(mentions, batch_size, shuffle=False)
    embeddings = np.empty((0, 768), np.float32)  # assumes a 768-dim encoder hidden size
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Getting embeddings'):
            inputs = tokenizer(batch, padding=True, return_tensors='pt')
            inputs = inputs.to(DEVICE)
            cls = model(inputs)
            embeddings = np.append(embeddings, cls.detach().cpu().numpy(), axis=0)
    return embeddings


def evaluate(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, model, tokenizer):
    """Top-k accuracy of nearest-neighbour retrieval against the dictionary embeddings."""
    DEVICE = torch.device(f'cuda:{rank}')
    vocab_embeddings = getEmbeddings(vocab_mentions, model, tokenizer, DEVICE)
    test_embeddings = getEmbeddings(test_mentions, model, tokenizer, DEVICE)
    knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(vocab_embeddings, vocab_ids)
    res = []
    for n in [1, 3, 5, 10]:
        distances, indices = knn.kneighbors(test_embeddings, n_neighbors=n)
        num = 0
        for gold_cui, idx in zip(test_cuis, indices):
            candidates = [id_to_cui[vocab_ids[i]] for i in idx]
            if any(check_label(c, gold_cui) for c in candidates):
                num += 1
        res.append(num / len(test_cuis))
        # print(f'Top-{n:<3} accuracy: {num / len(test_cuis)}')
    return res


def save_checkpoint(model, res, epoch, dataName):
    logging.info(f'Saving model {epoch} {res}')
    torch.save(model.state_dict(), './checkpoints/' + dataName + '.pt')


class Model(nn.Module):
    """Wrap a pretrained encoder and return the [CLS] embedding of each input."""

    def __init__(self, MODEL_NAME):
        super(Model, self).__init__()
        self.model = AutoModel.from_pretrained(MODEL_NAME)

    def forward(self, inputs):
        outputs = self.model(**inputs)
        cls = outputs.last_hidden_state[:, 0, :]
        return cls
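
# Expected input layout, inferred from how main() parses the files (adjust to your data):
#   ./data/biomedical/<dataName>/<dictionary>           one "CUI||mention" pair per line,
#                                                       e.g. C0006142||breast cancer
#   ./data/biomedical/<dataName>/<test_set>/0.concept   '||'-delimited lines whose last two
#                                                       fields are the mention text and its
#                                                       gold CUI (composites joined by '|' or '+')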
def main(rank, world_size, config):
    print(f"Running main() on rank {rank}.")
    setup(rank, world_size)

    dataName = config['DEFAULT']['dataName']
    logging.basicConfig(format='%(asctime)s %(message)s',
                        filename=config['train']['ckt_path'] + dataName + '.log',
                        filemode='a', level=logging.INFO)

    # Entity dictionary: CUI -> set of lower-cased mention strings.
    vocab = defaultdict(set)
    with open('./data/biomedical/' + dataName + '/' + config['train']['dictionary']) as f:
        for line in f:
            fields = line.strip().split('||')
            vocab[fields[0]].add(fields[1].lower())

    # Map CUIs to integer ids and group mentions by entity id.
    cui_to_id, id_to_cui = {}, {}
    vocab_entity_id_mentions = {}
    for id, cui in enumerate(vocab):
        cui_to_id[cui] = id
        id_to_cui[id] = cui
    for cui, mention in vocab.items():
        vocab_entity_id_mentions[cui_to_id[cui]] = mention

    vocab_mentions, vocab_ids = [], []
    for id, mentions in vocab_entity_id_mentions.items():
        vocab_mentions.extend(mentions)
        vocab_ids.extend([id] * len(mentions))

    # Test set: the last two '||'-separated fields are the mention and its gold CUI.
    test_mentions, test_cuis = [], []
    with open('./data/biomedical/' + dataName + '/' + config['train']['test_set'] + '/0.concept') as f:
        for line in f:
            fields = line.strip().split('||')
            test_cuis.append(fields[-1])
            test_mentions.append(fields[-2].lower())

    num_sample_per_class = int(config['data']['group_size'])  # mentions sampled per entity group
    batch_size = int(config['train']['batch_size'])  # groups per batch; effective triplet-loss batch = batch_size * num_sample_per_class
    margin = int(config['model']['margin'])
    epochs = int(config['train']['epochs'])
    lr = float(config['train']['lr'])
    MODEL_NAME = config['model']['model_name']

    trainDataLoader = train_dataloader(vocab_entity_id_mentions, num_sample_per_class, rank,
                                       world_size, batch_size, pin_memory=False, num_workers=0)
    # test_dataloader = test_dataloader(test_mentions, batch_size=200)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = Model(MODEL_NAME).to(rank)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    ddp_model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)

    best = 0
    if rank == 0:
        logging.info(f'epochs:{epochs} group_size:{num_sample_per_class} batch_size:{batch_size} '
                     f'device:{torch.cuda.get_device_name()} count:{torch.cuda.device_count()} base:{MODEL_NAME}')

    for epoch in tqdm(range(epochs)):
        trainDataLoader.sampler.set_epoch(epoch)  # reshuffle the distributed shards each epoch
        train(rank, epoch, epochs, trainDataLoader, ddp_model, optimizer, tokenizer, margin)
        # Evaluate (and checkpoint on improvement) on rank 0 only; evaluating every
        # other epoch (epoch % 2 == 0) was also tried.
        if rank == 0:
            res = evaluate(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis,
                           id_to_cui, ddp_model.module, tokenizer)
            logging.info(f'{epoch} {res}')
            if res[0] > best:
                best = res[0]
                save_checkpoint(ddp_model.module, res, epoch, dataName)
        dist.barrier()

    cleanup()


if __name__ == '__main__':
    import configparser

    config = configparser.ConfigParser()
    config.read('config.ini')
    world_size = torch.cuda.device_count()
    print(f"You have {world_size} GPUs.")
    mp.spawn(main, args=(world_size, config), nprocs=world_size, join=True)
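
# A config.ini sketch matching the keys read above (section layout and values are
# illustrative assumptions, not a shipped configuration):
#
#   [DEFAULT]
#   dataName = ncbi-disease
#
#   [data]
#   group_size = 8
#
#   [model]
#   model_name = dmis-lab/biobert-base-cased-v1.1
#   margin = 2
#
#   [train]
#   ckt_path = ./checkpoints/
#   dictionary = train_dictionary.txt
#   test_set = processed_test
#   batch_size = 32
#   epochs = 50
#   lr = 2e-5
#
# Run this file directly with config.ini in the working directory; it spawns one
# training process per visible GPU via mp.spawn.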