import os, random
from collections import defaultdict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp

from transformers import AutoTokenizer, AutoModel

from data import generate_train_entity_sets
from tqdm import tqdm
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
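
# Note: `data.generate_train_entity_sets` and the two triplet-loss functions in
# `loss` are assumed to be local project modules (data.py / loss.py) that live
# next to this script, not third-party packages.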


def setup(rank, world_size):
    # Single-node DDP initialisation: the master address/port are hard-coded,
    # the NCCL backend is used, and one GPU is bound to each spawned process.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def train_dataloader(vocab_entity_id_mentions, num_sample_per_class, rank, world_size, batch_size=32, pin_memory=True, num_workers=8):
    dataset = generate_train_entity_sets(vocab_entity_id_mentions, entity_id_name=None, group_size=num_sample_per_class, anchor=False)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
    # Shuffling is delegated to the DistributedSampler, so the DataLoader itself must not shuffle.
    return DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers, drop_last=False, shuffle=False, sampler=sampler)


def test_dataloader(test_mentions, batch_size=32):
    return DataLoader(test_mentions, batch_size=batch_size, shuffle=False)
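

# Each batch yielded by train_dataloader appears to be a pair
# (mention_groups, labels): every group holds the mentions of one entity,
# padded to `group_size` with the literal string 'PAD' (filtered out again in
# train() below). This is an assumption inferred from how batches are unpacked
# in train(), since generate_train_entity_sets lives in the local data module.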


def train(rank, epoch, epochs, train_dataloader, model, optimizer, tokenizer, margin):
    # DEVICE = torch.device(f"cuda:{dist.get_rank()}")
    DEVICE = torch.device(f'cuda:{rank}')
    model.train()
    epoch_loss, epoch_len = [epoch], [epoch]

    for groups in tqdm(train_dataloader, desc=f'Training batches on {DEVICE}'):
        # Transpose the collated batch so that groups[0][i] is the tuple of
        # mentions belonging to the i-th entity group (the order of mentions
        # within a group does not matter for the triplet loss).
        groups[0][:] = zip(*groups[0][::-1])
        x, y = [], []
        for mention, label in zip(groups[0], groups[1]):
            mention = [m for m in mention if m != 'PAD']
            x.extend(mention)
            y.extend([label.item()] * len(mention))

        optimizer.zero_grad()
        inputs = tokenizer(x, padding=True, return_tensors='pt')
        inputs = inputs.to(DEVICE)
        cls = model(inputs)
        # cls = torch.nn.functional.normalize(cls)  ## normalize cls embedding before computing loss, didn't work
        # cls = torch.nn.Dropout(p=0.25)(cls)       ## add dropout, didn't work

        # loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=True)
        # loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=True)

        # Loss schedule: batch-all triplet loss for the first half of training,
        # batch-hard triplet loss for the second half.
        if epoch < epochs / 2:
        # if epoch // (epochs / 4) % 2 == 0:  ## various ways of alternating batch all and batch hard, no obvious advantage
        # if (epoch // 10) % 2 == 0:          ## various ways of alternating batch all and batch hard, no obvious advantage
            loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
        else:
            loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)

        #### tried circle loss, no obvious advantage

        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
        epoch_len.append(len(x))
        # del inputs, cls, loss
        # torch.cuda.empty_cache()


def check_label(predicted_cui: str, golden_cui: str) -> int:
    """
    Some composite annotations do not take the order of CUIs into account, so
    return 1 if any CUI within the predicted composite CUI (or single CUI)
    matches any CUI within the gold one, and 0 otherwise.
    e.g. check_label('C0001|C0002', 'C0002+C0003') == 1
    """
    return int(len(set(predicted_cui.replace('+', '|').split("|")).intersection(set(golden_cui.replace('+', '|').split("|")))) > 0)


def getEmbeddings(mentions, model, tokenizer, DEVICE, batch_size=200):
    # Encode the mentions in chunks and collect the CLS embeddings on the CPU.
    model.to(DEVICE)
    model.eval()
    dataloader = DataLoader(mentions, batch_size, shuffle=False)
    embeddings = np.empty((0, 768), np.float32)
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Getting embeddings'):
            inputs = tokenizer(batch, padding=True, return_tensors='pt')
            inputs = inputs.to(DEVICE)
            cls = model(inputs)
            embeddings = np.append(embeddings, cls.detach().cpu().numpy(), axis=0)
            # del inputs, cls
            # torch.cuda.empty_cache()
    return embeddings


def eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, model, tokenizer):
    """
    Embed every vocabulary mention and every test mention, then report top-k
    retrieval accuracy for k = 1, 3, 5, 10: a test mention counts as correct if
    any of its k nearest vocabulary mentions (cosine distance) maps to a CUI
    that matches the gold CUI under check_label().
    """
    DEVICE = torch.device(f'cuda:{rank}')
    vocab_embeddings = getEmbeddings(vocab_mentions, model, tokenizer, DEVICE)
    test_embeddings = getEmbeddings(test_mentions, model, tokenizer, DEVICE)

    knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(vocab_embeddings, vocab_ids)
    n_neighbors = [1, 3, 5, 10]
    res = []
    for n in n_neighbors:
        distances, indices = knn.kneighbors(test_embeddings, n_neighbors=n)
        num = 0
        for gold_cui, idx in zip(test_cuis, indices):
            candidates = [id_to_cui[vocab_ids[i]] for i in idx]
            for c in candidates:
                if check_label(c, gold_cui):
                    num += 1
                    break
        res.append(num / len(test_cuis))
        # print(f'Top-{n:<3} accuracy: {num / len(test_cuis)}')
        # print(np.min(distances), np.max(distances))
    return res


def save_checkpoint(model, res, epoch, dataName):
    # `res` and `epoch` are currently unused; the checkpoint is simply
    # overwritten whenever a new best top-1 score is reached. The ./checkpoint/
    # directory must already exist.
    torch.save(model.state_dict(), './checkpoint/' + dataName + '.pt')


class Model(nn.Module):
    def __init__(self, MODEL_NAME):
        super(Model, self).__init__()
        self.model = AutoModel.from_pretrained(MODEL_NAME)

    def forward(self, inputs):
        outputs = self.model(**inputs)
        # Use the final-layer hidden state of the first ([CLS]) token as the
        # mention embedding.
        cls = outputs.last_hidden_state[:, 0, :]
        return cls
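

# Note: Model.forward returns the encoder's last-layer [CLS] vector. The
# (0, 768) buffer in getEmbeddings() therefore assumes a base-size encoder with
# hidden size 768; a different MODEL_NAME may require adjusting that constant.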


def main(rank, world_size, config):
    print(f"Running main(**args) on rank {rank}.")
    setup(rank, world_size)

    dataName = config['DEFAULT']['dataName']
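
    # The dictionary file is expected to hold one `CUI||mention` pair per line
    # (only the first two ||-separated fields are used). A hypothetical line:
    #   C0004096||asthma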
    vocab = defaultdict(set)
    with open('../biomedical/' + dataName + '/' + config['train']['dictionary']) as f:
        for line in f:
            line_list = line.strip().split('||')
            vocab[line_list[0]].add(line_list[1].lower())

    # Map each CUI to an integer id, then flatten the dictionary into parallel
    # lists of mentions and their entity ids.
    cui_to_id, id_to_cui = {}, {}
    vocab_entity_id_mentions = {}
    for id, cui in enumerate(vocab):
        cui_to_id[cui] = id
        id_to_cui[id] = cui
    for cui, mentions in vocab.items():
        vocab_entity_id_mentions[cui_to_id[cui]] = mentions

    vocab_mentions, vocab_ids = [], []
    for id, mentions in vocab_entity_id_mentions.items():
        vocab_mentions.extend(mentions)
        vocab_ids.extend([id] * len(mentions))
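
    # The test concept file is also ||-delimited; only the last two fields are
    # used here: the second-to-last is the mention text, the last its gold CUI.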
    test_mentions, test_cuis = [], []
    with open('../biomedical/' + dataName + '/' + config['train']['test_set'] + '/0.concept') as f:
        for line in f:
            test_cuis.append(line.strip().split('||')[-1])
            test_mentions.append(line.strip().split('||')[-2].lower())

    num_sample_per_class = int(config['data']['group_size'])  # mentions sampled per entity group
    batch_size = int(config['train']['batch_size'])  # number of groups per batch; effective batch size for the triplet loss = batch_size * num_sample_per_class
    margin = int(config['model']['margin'])
    epochs = int(config['train']['epochs'])
    lr = float(config['train']['lr'])
    MODEL_NAME = config['model']['model_name']

    trainDataLoader = train_dataloader(vocab_entity_id_mentions, num_sample_per_class, rank, world_size, batch_size, pin_memory=False, num_workers=0)
    # test_dataloader = test_dataloader(test_mentions, batch_size=200)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = Model(MODEL_NAME).to(rank)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    ddp_model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)

    best = 0
    best_res = []

    for epoch in tqdm(range(epochs)):
        # Reseed the sampler each epoch so every rank sees a fresh partition.
        trainDataLoader.sampler.set_epoch(epoch)
        train(rank, epoch, epochs, trainDataLoader, ddp_model, optimizer, tokenizer, margin)
        # Only rank 0 evaluates, checkpoints, and logs; the other ranks wait at
        # the barrier below.
        # if rank == 0 and epoch % 2 == 0:
        if rank == 0:
            res = eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, ddp_model.module, tokenizer)
            if res[0] > best:
                best = res[0]
                best_res = res
                save_checkpoint(ddp_model.module, res, epoch, dataName)
                with open("biomedical_results/output.txt", "a") as f:
                    print('new best ----', file=f)
                    for idx, n in enumerate([1, 3, 5, 10]):
                        print(f'Top-{n:<3} accuracy: {best_res[idx]}', file=f)

        dist.barrier()

    cleanup()


if __name__ == '__main__':
    import configparser
    config = configparser.ConfigParser()
    config.read('config.ini')
    world_size = torch.cuda.device_count()
    print(f"You have {world_size} GPUs.")
    # Spawn one training process per visible GPU.
    mp.spawn(
        main,
        args=(world_size, config),
        nprocs=world_size,
        join=True
    )
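
# A minimal config.ini sketch matching the keys read above; every value is a
# placeholder and should be adapted to the actual dataset and hardware:
#
#   [DEFAULT]
#   dataName = ncbi-disease
#
#   [data]
#   group_size = 8
#
#   [train]
#   dictionary = train_dictionary.txt
#   test_set = processed_test
#   batch_size = 16
#   epochs = 100
#   lr = 2e-5
#
#   [model]
#   margin = 2
#   model_name = dmis-lab/biobert-base-cased-v1.1
#
# Running `python <this script>` spawns one process per GPU via mp.spawn; the
# output directory `biomedical_results/` and `./checkpoint/` are assumed to
# already exist.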