added cosines code
This commit is contained in:
parent: b6cf2d4416
commit: 94ee7beba7
@@ -23,7 +23,7 @@ for _, row in entity_df.iterrows():
train_df.sort_values(by=['entity_id']).to_markdown('out.md')

# %%
data_path = '../train/class_bert_augmentation/prediction/exports/result.csv'
data_path = '../esAppMod_train/class_bert_augmentation/prediction/exports/result.csv'
prediction_df = pd.read_csv(data_path)

predicted_entity_list = []

@@ -25,7 +25,6 @@ from transformers import (
import evaluate
import numpy as np
import pandas as pd
import math
from functools import partial
import warnings

@@ -55,14 +54,14 @@ set_seed(42)
# %%
# PARAMETERS
SAMPLES=20
SHUFFLES=5
AMPLIFY_FACTOR=5
SHUFFLES=3
AMPLIFY_FACTOR=3

# %%
###################################################
# import code
# import training file
data_path = '../../esAppMod_data_import/train.csv'
data_path = '../../../biomedical_data_import/bc2gm_train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use pattern, we use the real thing and property
entity_ids = df['entity_id'].to_list()
@@ -192,6 +191,8 @@ def augment_data(df):
    return new_df




###############################################################
# regeneration code
# %%
@@ -265,13 +266,20 @@ class DynamicDataset(Dataset):

# %%
class RegenerateDatasetCallback(TrainerCallback):
    def __init__(self, dataset):
    def __init__(self, dataset, every_n_epochs=2):
        """
        Args:
            dataset: The dataset instance that supports regeneration.
            every_n_epochs (int): Number of epochs to wait before regenerating the dataset.
        """
        self.dataset = dataset
        self.every_n_epochs = every_n_epochs

    def on_epoch_begin(self, args, state, control, **kwargs):
        print(f"Epoch {int(math.ceil(state.epoch + 1))}: Regenerating dataset")
        self.dataset.regenerate_data()

        # Check if the current epoch is a multiple of `every_n_epochs`
        if (state.epoch + 1) % self.every_n_epochs == 0:
            print(f"Epoch {int(state.epoch + 1)}: Regenerating dataset...")
            self.dataset.regenerate_data()


# %%
@@ -310,11 +318,11 @@ def train():


    # Define the callback
    lean_df = df.drop(columns=['entity_name'])
    dynamic_dataset = DynamicDataset(df = lean_df, sample_size_per_class=10, tokenizer=tokenizer)
    # lean_df = df.drop(columns=['entity_name'])
    dynamic_dataset = DynamicDataset(df = df, sample_size_per_class=SAMPLES, tokenizer=tokenizer)

    # create the regeneration callback
    regeneration_callback = RegenerateDatasetCallback(dynamic_dataset)
    regeneration_callback = RegenerateDatasetCallback(dynamic_dataset, every_n_epochs=2)

    # compute metrics
    metric = evaluate.load("accuracy")
@@ -346,18 +354,17 @@ def train():
        eval_strategy="no",
        logging_dir="tensorboard-log",
        logging_strategy="epoch",
        save_strategy="steps",
        save_steps=500,
        # save_strategy="epoch",
        load_best_model_at_end=False,
        learning_rate=5e-5,
        per_device_train_batch_size=64,
        # per_device_eval_batch_size=64,
        learning_rate=1e-4,
        per_device_train_batch_size=256,
        # per_device_eval_batch_size=256,
        auto_find_batch_size=False,
        ddp_find_unused_parameters=False,
        weight_decay=0.01,
        save_total_limit=1,
        num_train_epochs=120,
        warmup_steps=400,
        num_train_epochs=80,
        warmup_steps=200,
        bf16=True,
        push_to_hub=False,
        remove_unused_columns=False,

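A minimal hedged sketch of how the regeneration callback above can be wired into a Hugging Face Trainer; `dynamic_dataset`, `regeneration_callback`, `training_args`, `model`, `tokenizer` and `compute_metrics` are the objects from the diff, while the exact `Trainer(...)` call here is only illustrative:

from transformers import Trainer

trainer = Trainer(
    model,
    training_args,
    train_dataset=dynamic_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    # on_epoch_begin re-samples the dataset every `every_n_epochs` epochs
    callbacks=[regeneration_callback],
)
trainer.train()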
@@ -1,6 +1,6 @@

*******************************************************************************
Accuracy: 0.80655
F1 Score: 0.82821
Precision: 0.87847
Recall: 0.80655
Accuracy: 0.77215
F1 Score: 0.79997
Precision: 0.87183
Recall: 0.77215

@@ -33,7 +33,7 @@ BATCH_SIZE = 32

# %%
# construct the target id list
data_path = '../../../biomedical_data_import/bc2gm_train.csv'
data_path = '../../../../biomedical_data_import/bc2gm_train.csv'
train_df = pd.read_csv(data_path, skipinitialspace=True)
entity_ids = train_df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))
@@ -62,6 +62,13 @@ def preprocess_text(text):
    return text


def is_int_string(s):
    try:
        int(s)
        return True
    except ValueError:
        return False



# outputs a list of dictionaries
@@ -72,9 +79,12 @@ def preprocess_text(text):
def process_df_to_dict(df):
    output_list = []
    for _, row in df.iterrows():
        row_id = row['entity_id']
        if not is_int_string(row_id):
            continue
        row_id = int(row_id)
        desc = row['mention']
        desc = preprocess_text(desc)
        row_id = row['entity_id']
        element = {
            'text' : desc,
            'labels': label2id[row_id], # ensure labels starts from 0
@@ -86,7 +96,7 @@ def process_df_to_dict(df):

def create_dataset():
    # train
    data_path = '../../../biomedical_data_import/bc2gm_test.csv'
    data_path = '../../../../biomedical_data_import/bc2gm_test.csv'
    test_df = pd.read_csv(data_path, skipinitialspace=True)


@@ -51,7 +51,7 @@ SHUFFLES=0 # 0 shuffles means it does not re-sample

# We want to map the entity_id to a consecutive set of id's
# import training file
data_path = '../../../biomedical_data_import/bc2gm_train.csv'
data_path = '../../biomedical_data_import/bc2gm_train.csv'
train_df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use pattern, we use the real thing and property
entity_ids = train_df['entity_id'].to_list()
@@ -240,7 +240,7 @@ def process_df_to_dict(df):
def create_dataset():
    # train

    data_path = '../../../biomedical_data_import/bc2gm_train.csv'
    data_path = '../../biomedical_data_import/bc2gm_train.csv'
    train_df = pd.read_csv(data_path, skipinitialspace=True)


@@ -266,6 +266,7 @@ def train():
    # model_checkpoint = 'prajjwal1/bert-small'
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, return_tensors="pt", clean_up_tokenization_spaces=True)

    # max_length = 120

    # given a dataset entry, run it through the tokenizer
    def preprocess_function(example):

@@ -1,6 +1,6 @@

*******************************************************************************
Accuracy: 0.15093
F1 Score: 0.14063
Precision: 0.15594
Recall: 0.15093
Accuracy: 0.76047
F1 Score: 0.78441
Precision: 0.85810
Recall: 0.76047

@@ -54,9 +54,9 @@ set_seed(42)

# %%
# PARAMETERS
SAMPLES=20
SHUFFLES=5
AMPLIFY_FACTOR=5
SAMPLES=50
SHUFFLES=3
AMPLIFY_FACTOR=10

# %%
###################################################

@@ -1,6 +1,6 @@

*******************************************************************************
Accuracy: 0.76958
F1 Score: 0.79382
Precision: 0.88705
Recall: 0.76958
Accuracy: 0.77614
F1 Score: 0.80037
Precision: 0.89156
Recall: 0.77614

@@ -1,6 +1,6 @@

*******************************************************************************
Accuracy: 0.80689
F1 Score: 0.82527
Precision: 0.89684
Recall: 0.80689
Accuracy: 0.80033
F1 Score: 0.81484
Precision: 0.87456
Recall: 0.80033

@@ -78,7 +78,7 @@ def process_df_to_dict(df):
        index = row['entity_id']
        element = {
            'text' : desc,
            'label': label2id[index], # ensure labels starts from 0
            'labels': label2id[index], # ensure labels starts from 0
        }
        output_list.append(element)

@@ -144,9 +144,7 @@ def test():
        # there is no need to create a separate 'labels'
        model_inputs = tokenizer(
            input,
            max_length=max_length,
            # truncation=True,
            padding='max_length'
            truncation=True,
        )
        return model_inputs

@@ -160,7 +158,7 @@ def test():
    )


    datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
    # datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

    # %% temp
    # tokenized_datasets['train'].rename_columns()
@@ -168,7 +166,7 @@ def test():
    # %%
    # create data collator

    # data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length")
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # %%
    # compute metrics
@@ -197,13 +195,13 @@ def test():
    actual_labels = []


    dataloader = DataLoader(datasets, batch_size=BATCH_SIZE, shuffle=False)
    dataloader = DataLoader(datasets, batch_size=BATCH_SIZE, collate_fn=data_collator, shuffle=False)
    for batch in tqdm(dataloader):
        # Inference in batches
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        # save labels too
        actual_labels.extend(batch['label'])
        actual_labels.extend(batch['labels'])


        # Move to GPU if available

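The hunks above switch from padding every example to a fixed max_length at tokenization time to truncating only, and letting a DataCollatorWithPadding passed as collate_fn pad each batch to its longest member. A small hedged sketch of that pattern (the checkpoint name is taken from elsewhere in the diff; the example texts are toy values):

from transformers import AutoTokenizer, DataCollatorWithPadding
from torch.utils.data import DataLoader

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
collator = DataCollatorWithPadding(tokenizer=tokenizer)

# tokenize without padding; the collator pads each batch dynamically
features = [tokenizer(t, truncation=True) for t in ["db2", "websphere application server"]]
loader = DataLoader(features, batch_size=2, collate_fn=collator, shuffle=False)

batch = next(iter(loader))
print(batch["input_ids"].shape)  # padded only to the longest sequence in this batch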
@@ -306,7 +306,7 @@ def corrupt_string(sentence, corruption_probability=0.01):
# each element maps input to output
# input: tag_description
# output: class label
label_flag_list = []
# label_flag_list = []

def process_df_to_dict(df):
    output_list = []
@@ -331,13 +331,14 @@ def process_df_to_dict(df):
        for _ in range(10):
            element = {
                'text': parent_desc,
                'label': label2id[index],
                'labels': label2id[index],
            }
            output_list.append(element)


        # check if label is in label_flag_list
        if index not in label_flag_list:
        # if index not in label_flag_list:
        if False:

            entity_name = row['entity_name']
            # add the "entity_name" label as a mention
@@ -452,7 +453,7 @@ def train():
    model_checkpoint = "distilbert/distilbert-base-uncased"
    # model_checkpoint = 'google-bert/bert-base-cased'
    # model_checkpoint = 'prajjwal1/bert-small'
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, return_tensors="pt", clean_up_tokenization_spaces=True)
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, clean_up_tokenization_spaces=True)


    # given a dataset entry, run it through the tokenizer
@@ -475,6 +476,9 @@ def train():
        remove_columns="text",
    )

    # tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])


    # %% temp
    # tokenized_datasets['train'].rename_columns()

@@ -525,7 +529,7 @@ def train():
        per_device_eval_batch_size=64,
        auto_find_batch_size=False,
        ddp_find_unused_parameters=False,
        weight_decay=0.01,
        weight_decay=0.02,
        save_total_limit=1,
        num_train_epochs=40,
        warmup_steps=400,
@@ -538,7 +542,7 @@ def train():
    trainer = Trainer(
        model,
        training_args,
        train_dataset=tokenized_datasets["train"],
        train_dataset=tokenized_datasets['train'],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,

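The hunks above rename the per-example key from 'label' to 'labels', presumably because Hugging Face models only compute a loss when their forward pass receives a labels argument (a hedged illustration below; the checkpoint name appears elsewhere in the diff, the input text is a toy value, and the classification head is freshly initialized):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tok = AutoTokenizer.from_pretrained("prajjwal1/bert-small")
mdl = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-small", num_labels=2)

enc = tok(["ms sql server"], return_tensors="pt")
out_no_loss = mdl(**enc)                          # .loss is None without labels
out_loss = mdl(**enc, labels=torch.tensor([1]))   # .loss is a scalar tensor
print(out_no_loss.loss, out_loss.loss)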
@@ -0,0 +1,2 @@
__pycache__
checkpoint

@@ -0,0 +1,218 @@
import os, random
from collections import defaultdict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp

from transformers import AutoTokenizer, AutoModel

from data import generate_train_entity_sets

from tqdm import tqdm ### need to use ipywidgets==7.7.1 the newest version doesn't work
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import logging


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def train_dataloader(vocab_entity_id_mentions, num_sample_per_class, rank, world_size, batch_size=32, pin_memory=True, num_workers=8):
    dataset = generate_train_entity_sets(vocab_entity_id_mentions, entity_id_name=None, group_size=num_sample_per_class, anchor=False)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
    return DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers, drop_last=False, shuffle=False, sampler=sampler)

def test_dataloader(test_mentions, batch_size=32):
    return DataLoader(test_mentions, batch_size=batch_size, shuffle=False)

def train(rank, epoch, epochs, train_dataloader, model, optimizer, tokenizer, margin):
    # DEVICE = torch.device(f"cuda:{dist.get_rank()}")
    DEVICE = torch.device(f'cuda:{rank}')
    model.train()
    epoch_loss, epoch_len = [epoch], [epoch]

    for groups in tqdm(train_dataloader, desc =f'Training batches on {DEVICE}'):
        groups[0][:] = zip(*groups[0][::-1])
        x, y = [], []
        for mention, label in zip(groups[0], groups[1]):
            mention = [m for m in mention if m != 'PAD']
            x.extend(mention)
            y.extend([label.item()]*len(mention))

        optimizer.zero_grad()
        inputs = tokenizer(x, padding=True, return_tensors='pt')
        inputs = inputs.to(DEVICE)
        cls = model(inputs)
        # cls = torch.nn.functional.normalize(cls) ## normalize cls embedding before computing loss, didn't work
        # cls = torch.nn.Dropout(p= 0.25)(cls) ## add dropout, didn't work

        # loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=True)
        # loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=True)

        if epoch < epochs / 2:
        # if epoch // (epochs / 4) % 2 == 0: ## various ways of alternating batch all and batch hard, no obvious advantage
        # if (epoch // 10) % 2 == 0: ## various ways of alternating batch all and batch hard, no obvious advantage
            loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
        else:
            loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)

        #### tried circle loss, no obvious advantage

        loss.backward()
        optimizer.step()
        # logging.info(f'{epoch} {len(x)} {loss.item()}')
        epoch_loss.append(loss.item())
        epoch_len.append(len(x))
        # del inputs, cls, loss
        # torch.cuda.empty_cache()
    logging.info(f'{DEVICE}{epoch_len}')
    logging.info(f'{DEVICE}{epoch_loss}')

def check_label(predicted_cui: str, golden_cui: str) -> int:
    """
    Some composite annotation didn't consider orders
    So, set label '1' if any cui is matched within composite cui (or single cui)
    Otherwise, set label '0'
    """
    return int(len(set(predicted_cui.replace('+', '|').split("|")).intersection(set(golden_cui.replace('+', '|').split("|"))))>0)

def getEmbeddings(mentions, model, tokenizer, DEVICE, batch_size=200):
    model.to(DEVICE)
    model.eval()
    dataloader = DataLoader(mentions, batch_size, shuffle=False)
    embeddings = np.empty((0, 768), np.float32)
    with torch.no_grad():
        for mentions in tqdm(dataloader, desc ='Getting embeddings'):
            inputs = tokenizer(mentions, padding=True, return_tensors='pt')
            inputs = inputs.to(DEVICE)
            cls = model(inputs)
            embeddings = np.append(embeddings, cls.detach().cpu().numpy(), axis=0)
            # del inputs, cls
            # torch.cuda.empty_cache()
    return embeddings

def eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, model, tokenizer):
    DEVICE = torch.device(f'cuda:{rank}')
    vocab_embeddings = getEmbeddings(vocab_mentions, model, tokenizer, DEVICE)
    test_embeddings = getEmbeddings(test_mentions, model, tokenizer, DEVICE)

    knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(vocab_embeddings, vocab_ids)
    n_neighbors = [1, 3, 5, 10]
    res = []
    for n in n_neighbors:
        distances, indices = knn.kneighbors(test_embeddings, n_neighbors=n)
        num = 0
        for gold_cui, idx in zip(test_cuis, indices):
            candidates = [id_to_cui[vocab_ids[i]] for i in idx]
            for c in candidates:
                if check_label(c, gold_cui):
                    num += 1
                    break
        res.append(num / len(test_cuis))
        # print(f'Top-{n:<3} accuracy: {num / len(test_cuis)}')
    return res
    # print(np.min(distances), np.max(distances))

def save_checkpoint(model, res, epoch, dataName):
    logging.info(f'Saving model {epoch} {res} ')
    torch.save(model.state_dict(), './checkpoints/'+dataName+'.pt')

class Model(nn.Module):
    def __init__(self,MODEL_NAME):
        super(Model, self).__init__()
        self.model = AutoModel.from_pretrained(MODEL_NAME)

    def forward(self, inputs):
        outputs = self.model(**inputs)
        cls = outputs.last_hidden_state[:,0,:]
        return cls

def main(rank, world_size, config):

    print(f"Running main(**args) on rank {rank}.")
    setup(rank, world_size)

    dataName = config['DEFAULT']['dataName']
    logging.basicConfig(format='%(asctime)s %(message)s', filename=config['train']['ckt_path']+dataName+'.log', filemode='a', level=logging.INFO)

    vocab = defaultdict(set)
    with open('./data/biomedical/'+dataName+'/'+config['train']['dictionary']) as f:
        for line in f:
            vocab[line.strip().split('||')[0]].add(line.strip().split('||')[1].lower())

    cui_to_id, id_to_cui = {}, {}
    vocab_entity_id_mentions = {}
    for id, cui in enumerate(vocab):
        cui_to_id[cui] = id
        id_to_cui[id] = cui
    for cui, mention in vocab.items():
        vocab_entity_id_mentions[cui_to_id[cui]] = mention

    vocab_mentions, vocab_ids = [], []
    for id, mentions in vocab_entity_id_mentions.items():
        vocab_mentions.extend(mentions)
        vocab_ids.extend([id]*len(mentions))

    test_mentions, test_cuis = [], []
    with open('./data/biomedical/'+dataName+'/'+config['train']['test_set']+'/0.concept') as f:
        for line in f:
            test_cuis.append(line.strip().split('||')[-1])
            test_mentions.append(line.strip().split('||')[-2].lower())

    num_sample_per_class = int(config['data']['group_size']) # samples in each group
    batch_size = int(config['train']['batch_size']) # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
    margin = int(config['model']['margin'])
    epochs = int(config['train']['epochs'])
    lr = float(config['train']['lr'])
    MODEL_NAME = config['model']['model_name']

    trainDataLoader = train_dataloader(vocab_entity_id_mentions, num_sample_per_class, rank, world_size, batch_size, pin_memory=False, num_workers=0)
    # test_dataloader = test_dataloader(test_mentions, batch_size=200)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = Model(MODEL_NAME).to(rank)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    ddp_model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)

    best = 0
    if rank == 0:
        logging.info(f'epochs:{epochs} group_size:{num_sample_per_class} batch_size:{batch_size} %num:1 device:{torch.cuda.get_device_name()} count:{torch.cuda.device_count()} base:{MODEL_NAME}')

    for epoch in tqdm(range(epochs)):
        trainDataLoader.sampler.set_epoch(epoch)
        train(rank, epoch, epochs, trainDataLoader, ddp_model, optimizer, tokenizer, margin)
        # if rank == 0 and epoch % 2 == 0:
        if rank == 0:
            res = eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, ddp_model.module, tokenizer)
            logging.info(f'{epoch} {res}')
            if res[0] > best:
                best = res[0]
                save_checkpoint(ddp_model.module, res, epoch, dataName)
        dist.barrier()
    cleanup()

if __name__ == '__main__':
    import configparser
    config = configparser.ConfigParser()
    config.read('config.ini')
    world_size = torch.cuda.device_count()
    print(f"You have {world_size} GPUs.")
    mp.spawn(
        main,
        args=(world_size, config),
        nprocs=world_size,
        join=True
    )

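The check_label helper above treats composite CUIs joined by '+' or '|' as correct if any component overlaps with the gold CUI. A small worked example of the same logic (the CUI strings below are made up):

def check_label(predicted_cui: str, golden_cui: str) -> int:
    # identical logic to the diff: split composite CUIs on '+'/'|' and test for any overlap
    return int(len(set(predicted_cui.replace('+', '|').split("|"))
                   .intersection(set(golden_cui.replace('+', '|').split("|")))) > 0)

print(check_label("D001+D002", "D002"))        # 1: one component matches
print(check_label("D001|D003", "D002+D004"))   # 0: no overlap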
@@ -0,0 +1,35 @@
import random
def generate_train_entity_sets(entity_id_mentions, entity_id_name=None, group_size=10, anchor=False):
    # split entity mentions into groups
    # anchor = False, don't add entity name to each group, simply treat it as a normal mention
    entity_sets = []
    if anchor:
        for id, mentions in entity_id_mentions.items():
            mentions = list(mentions)
            random.shuffle(mentions)
            positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
            anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
            entity_sets.extend(anchor_positive)
    else:
        for id, mentions in entity_id_mentions.items():
            if entity_id_name:
                group = list(set([entity_id_name[id]] + mentions))
            else:
                group = list(mentions)
            if len(group) == 1:
                group.append(group[0])
            group.extend((group_size-len(group))%group_size * ['PAD'])
            random.shuffle(group)
            positives = [(group[i:i + group_size], id) for i in range(0, len(group), group_size)]
            entity_sets.extend(positives)
    return entity_sets

def batchGenerator(data, batch_size):
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        x, y = [], []
        for t in batch:
            t[0] = [e for e in t[0] if e != 'PAD']
            x.extend(t[0])
            y.extend([t[1]]*len(t[0]))
        yield x, y

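A quick hedged illustration of how generate_train_entity_sets groups mentions when anchor=False and no entity name is supplied: singleton entities are duplicated, and odd-sized groups are padded with 'PAD', which the training loop later filters out before tokenization (the vocabulary below is toy data and the printed order depends on the internal shuffle):

from data import generate_train_entity_sets

vocab_entity_id_mentions = {
    0: {"db2", "ibm db2", "db2 v9.7"},   # toy vocabulary: entity_id -> set of mentions
    1: {"websphere"},
}

sets = generate_train_entity_sets(vocab_entity_id_mentions, entity_id_name=None,
                                  group_size=2, anchor=False)
# e.g. [(['ibm db2', 'db2'], 0), (['PAD', 'db2 v9.7'], 0), (['websphere', 'websphere'], 1)]
print(sets)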
@@ -0,0 +1,106 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import gc

# %%
# Step 2: Load the state dictionary
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

# state_dict = torch.load('./checkpoint/siamese.pt')
state_dict = torch.load('./checkpoint/siamese_simple.pt')

# Step 3: Apply the state dictionary to the model
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()

# %%
with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}


# %%
with open('../esAppMod/infer.json', 'r') as file:
    test = json.load(file)
x_test = [d['mention'] for _, d in test['data'].items()]
y_test = [d['entity_id'] for _, d in test['data'].items()]
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())

def batch_list(data, batch_size):
    """Yield successive n-sized chunks from data."""
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]

batches = batch_list(train_entities, 64)

embedding_list = []
for batch in batches:
    inputs = tokenizer(batch, padding=True, return_tensors='pt')
    outputs = model(
        input_ids=inputs['input_ids'].to(DEVICE),
        attention_mask=inputs['attention_mask'].to(DEVICE)
    )
    output = outputs.last_hidden_state[:,0,:]
    output = output.detach().cpu().numpy()
    embedding_list.append(output)

cls = np.concatenate(embedding_list)
# %%
gc.collect()
torch.cuda.empty_cache()

# %%

batches = batch_list(x_test, 64)

embedding_list = []
for batch in batches:
    inputs = tokenizer(batch, padding=True, return_tensors='pt')
    outputs = model(
        input_ids=inputs['input_ids'].to(DEVICE),
        attention_mask=inputs['attention_mask'].to(DEVICE)
    )
    output = outputs.last_hidden_state[:,0,:]
    output = output.detach().cpu().numpy()
    embedding_list.append(output)

cls_test = np.concatenate(embedding_list)


# %%
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
n_neighbors = [1, 3, 5, 10]


with open("results/output.txt", "w") as f:
    for n in n_neighbors:
        distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
        num = 0
        for a,b in zip(y_test, indices):
            b = [labels[i] for i in b]
            if a in b:
                num += 1
        print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
        print(np.min(distances), np.max(distances), file=f)

# %%

@@ -0,0 +1,92 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm


# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    # split entity mentions into groups
    # anchor = False, don't add entity name to each group, simply treat it as a normal mention
    entity_sets = []
    if anchor:
        for id, mentions in entity_id_mentions.items():
            random.shuffle(mentions)
            positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
            anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
            entity_sets.extend(anchor_positive)
    else:
        for id, mentions in entity_id_mentions.items():
            group = list(set([entity_id_name[id]] + mentions))
            random.shuffle(group)
            positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
            entity_sets.extend(positives)
    return entity_sets

def batchGenerator(data, batch_size):
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        x, y = [], []
        for t in batch:
            x.extend(t[0])
            y.extend([t[1]]*len(t[0]))
        yield x, y


with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}


num_sample_per_class = 10 # samples in each group
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

model.to(DEVICE)
model.train()

losses = []

for epoch in tqdm(range(epochs)):
    data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
    random.shuffle(data)
    for x, y in batchGenerator(data, batch_size):
        # print(len(x), len(y), end='-->')
        optimizer.zero_grad()
        inputs = tokenizer(x, padding=True, return_tensors='pt')
        inputs = inputs.to(DEVICE)
        outputs = model(**inputs)
        cls = outputs.last_hidden_state[:,0,:]
        # for training less than half the time, train on easy
        if epoch < epochs / 2:
            loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
        # for training after half the time, train on hard
        else:
            loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
        loss.backward()
        optimizer.step()
        # print(epoch, loss)
        losses.append(loss)
        del inputs, outputs, cls, loss
        torch.cuda.empty_cache()

torch.save(model.state_dict(), './checkpoint/siamese_simple.pt')

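The loop above follows a curriculum that is only stated in comments: batch-all triplet mining (many, mostly easy triplets) for the first half of training, then batch-hard mining for the second half. A hedged sketch of that schedule factored into a helper (the helper name is invented; the loss functions are the ones imported from loss.py):

import torch
from loss import batch_all_triplet_loss, batch_hard_triplet_loss

def triplet_loss_for_epoch(epoch, epochs, labels, cls, margin):
    """Batch-all mining early on, batch-hard mining in the second half of training."""
    if epoch < epochs / 2:
        loss, _ = batch_all_triplet_loss(labels, cls, margin, squared=False)
    else:
        loss = batch_hard_triplet_loss(labels, cls, margin, squared=False)
    return loss

# usage inside the loop above (y is the list of entity ids, cls the CLS embeddings):
# loss = triplet_loss_for_epoch(epoch, epochs, torch.tensor(y).to(DEVICE), cls, margin)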
@@ -0,0 +1,242 @@
# %%
import torch
import json
import random
import numpy as np
import pandas as pd
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
# parallel utilities
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import os

# %%
with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
# produces a dictionary map from entity_id to entity_name
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
# map from entity_id to list of mentions
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
# map from entity_id to list of entity_names
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}



# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    # split entity mentions into groups
    # anchor = False, don't add entity name to each group, simply treat it as a normal mention
    entity_sets = []
    if anchor:
        for id, mentions in entity_id_mentions.items():
            random.shuffle(mentions)
            positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
            anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
            entity_sets.extend(anchor_positive)
    else:
        for id, mentions in entity_id_mentions.items():
            group = list(set([entity_id_name[id]] + mentions))
            random.shuffle(group)
            positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
            entity_sets.extend(positives)

    # entity sets are always ([list of mentions], id)
    # to convert it to dataset form, we will just use a dataframe
    id_mention_pairs = []
    for entity in entity_sets:
        entity_id = entity[1]
        for mention in entity[0]:
            id_mention_pairs.append({
                'entity_id': entity_id,
                'mention': mention
            })
    df = pd.DataFrame(id_mention_pairs)
    return df

# def batchGenerator(data, batch_size):
# for i in range(0, len(data), batch_size):
# batch = data[i:i+batch_size]
# x, y = [], []
# for t in batch:
# x.extend(t[0])
# y.extend([t[1]]*len(t[0]))
# yield x, y


class CustomDataset(Dataset):
    def __init__(self, df):
        self.data = df # data should be preprocessed if necessary before being passed here

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Return the data and label as tuples
        entry = self.data.iloc[idx]
        x = entry['mention']
        y = entry['entity_id']
        return x,y




# %%

num_sample_per_class = 10 # samples in each group
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' #'bert-base-cased' 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# model.to(DEVICE)
# model.train()
#
# losses = []

# %%
# for epoch in tqdm(range(epochs)):
# data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
# dataset = CustomDataset(data)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# for x, y in dataloader:
# # print(len(x), len(y), end='-->')
# optimizer.zero_grad()
# inputs = tokenizer(x, padding=True, return_tensors='pt')
# inputs = inputs.to(DEVICE)
# outputs = model(**inputs)
# cls = outputs.last_hidden_state[:,0,:]
# # for training less than half the time, train on easy
# if epoch < epochs / 2:
# loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
# # for training after half the time, train on hard
# else:
# loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
# loss.backward()
# optimizer.step()
# # print(epoch, loss)
# losses.append(loss)
# del inputs, outputs, cls, loss
# torch.cuda.empty_cache()
#
# torch.save(model.state_dict(), './checkpoint/siamese.pt')
# %%

def save_checkpoint(model, optimizer, epoch, path, rank):
    if rank == 0: # Only save on the master process
        # Save only the underlying model's state_dict, not the DDP wrapper
        torch.save({
            'model_state_dict': model.module.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch
        }, path)

# %%
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def reduce_mean(tensor, nprocs):
    """
    Reduces and averages the tensor across all processes.
    This function reduces a tensor from all processes to all processes.
    The resulting tensor is identical in all processes.
    """
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= nprocs
    return rt

def train(rank, world_size):
    setup(rank, world_size)

    # Setup model, DataLoader with DistributedSampler

    model = AutoModel.from_pretrained(MODEL_NAME)
    model = model.cuda(rank)
    model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    # initialize progress bar
    # Initialize tqdm on the master process
    if torch.distributed.get_rank() == 0: # Only print from the master process
        pbar = tqdm(total=epochs, desc='batch progress')



    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0

        data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
        train_dataset = CustomDataset(data)
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
        train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=train_sampler)

        device = torch.device(f"cuda:{rank}")
        train_sampler.set_epoch(epoch)
        for data, targets in train_loader:
            # data, targets = data.cuda(rank), targets.cuda(rank)
            optimizer.zero_grad()
            inputs = tokenizer(data, padding=True, return_tensors='pt')
            inputs = inputs.to(device)
            outputs = model(**inputs)
            cls = outputs.last_hidden_state[:,0,:]
            # for training less than half the time, train on easy
            if epoch < epochs / 2:
                loss, _ = batch_all_triplet_loss(targets.to(device), cls, margin, squared=False)
            # for training after half the time, train on hard
            else:
                loss = batch_hard_triplet_loss(targets.to(device), cls, margin, squared=False)


            loss.backward()
            optimizer.step()

            # Reduce and average the loss across all processes
            reduced_loss = reduce_mean(loss, world_size)
            total_loss += reduced_loss.item()
            num_batches += 1

            # print(epoch, loss)
            # losses.append(loss)
            del inputs, outputs, cls, loss
            torch.cuda.empty_cache()

        dist.barrier()
        # Close tqdm bar on master process
        if torch.distributed.get_rank() == 0: # Only print from the master process
            pbar.update(epoch)
            epoch_loss = total_loss / num_batches
            tqdm.write(f'loss: {epoch_loss}')

    if torch.distributed.get_rank() == 0: # Only print from the master process
        pbar.close()
        path = './checkpoint/siamese.pt'
        torch.save(model.module.state_dict(), path)

    cleanup()


if __name__ == '__main__':
    # Set the number of processes to the number of GPUs available
    world_size = torch.cuda.device_count()
    # Use torch.multiprocessing.spawn to launch the processes
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)

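reduce_mean above averages the per-rank loss with an all-reduce: every rank ends up with the sum of all losses divided by the world size. A minimal single-process sketch of the same call pattern, assuming a gloo process group of size 1 so it can run without GPUs (with two ranks holding 0.8 and 0.4, the sum would be 1.2 and the mean 0.6 on every rank):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

loss = torch.tensor(0.8)
rt = loss.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # sums the value held by every rank
rt /= dist.get_world_size()                # dividing by the world size gives the mean
print(rt)

dist.destroy_process_group()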
@@ -0,0 +1,186 @@
# standard functionalities for computing triplet loss, code borrowed from
# https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py
import torch
import torch.nn.functional as F
def _pairwise_distances(embeddings, squared=False):
    """Compute the 2D matrix of distances between all the embeddings.
    Args:
        embeddings: tensor of shape (batch_size, embed_dim)
        squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
                 If false, output is the pairwise euclidean distance matrix.
    Returns:
        pairwise_distances: tensor of shape (batch_size, batch_size)
    """
    dot_product = torch.matmul(embeddings, embeddings.t())

    # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
    # This also provides more numerical stability (the diagonal of the result will be exactly 0).
    # shape (batch_size,)
    square_norm = torch.diag(dot_product)

    # Compute the pairwise distance matrix as we have:
    # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
    # shape (batch_size, batch_size)
    distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)

    # Because of computation errors, some distances might be negative so we put everything >= 0.0
    distances[distances < 0] = 0

    if not squared:
        # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
        # we need to add a small epsilon where distances == 0.0
        mask = distances.eq(0).float()
        distances = distances + mask * 1e-16

        distances = (1.0 - mask) * torch.sqrt(distances)

    return distances

def _get_triplet_mask(labels):
    """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
    A triplet (i, j, k) is valid if:
        - i, j, k are distinct
        - labels[i] == labels[j] and labels[i] != labels[k]
    Args:
        labels: tf.int32 `Tensor` with shape [batch_size]
    """
    # Check that i, j and k are distinct
    indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
    indices_not_equal = ~indices_equal
    i_not_equal_j = indices_not_equal.unsqueeze(2)
    i_not_equal_k = indices_not_equal.unsqueeze(1)
    j_not_equal_k = indices_not_equal.unsqueeze(0)

    distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k


    label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
    i_equal_j = label_equal.unsqueeze(2)
    i_equal_k = label_equal.unsqueeze(1)

    valid_labels = ~i_equal_k & i_equal_j

    return valid_labels & distinct_indices


def _get_anchor_positive_triplet_mask(labels):
    """Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label.
    Args:
        labels: tf.int32 `Tensor` with shape [batch_size]
    Returns:
        mask: tf.bool `Tensor` with shape [batch_size, batch_size]
    """
    # Check that i and j are distinct
    indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
    indices_not_equal = ~indices_equal

    # Check if labels[i] == labels[j]
    # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
    labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)

    return labels_equal & indices_not_equal


def _get_anchor_negative_triplet_mask(labels):
    """Return a 2D mask where mask[a, n] is True iff a and n have distinct labels.
    Args:
        labels: tf.int32 `Tensor` with shape [batch_size]
    Returns:
        mask: tf.bool `Tensor` with shape [batch_size, batch_size]
    """
    # Check if labels[i] != labels[k]
    # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)

    return ~(labels.unsqueeze(0) == labels.unsqueeze(1))


# Cell
def batch_hard_triplet_loss(labels, embeddings, margin, squared=False):
    """Build the triplet loss over a batch of embeddings.
    For each anchor, we get the hardest positive and hardest negative to form a triplet.
    Args:
        labels: labels of the batch, of size (batch_size,)
        embeddings: tensor of shape (batch_size, embed_dim)
        margin: margin for triplet loss
        squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
                 If false, output is the pairwise euclidean distance matrix.
    Returns:
        triplet_loss: scalar tensor containing the triplet loss
    """
    # Get the pairwise distance matrix
    pairwise_dist = _pairwise_distances(embeddings, squared=squared)

    # For each anchor, get the hardest positive
    # First, we need to get a mask for every valid positive (they should have same label)
    mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()

    # We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
    anchor_positive_dist = mask_anchor_positive * pairwise_dist

    # shape (batch_size, 1)
    hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)

    # For each anchor, get the hardest negative
    # First, we need to get a mask for every valid negative (they should have different labels)
    mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()

    # We add the maximum value in each row to the invalid negatives (label(a) == label(n))
    max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
    anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)

    # shape (batch_size,)
    hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)

    # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
    tl = hardest_positive_dist - hardest_negative_dist + margin
    tl = F.relu(tl)
    triplet_loss = tl.mean()

    return triplet_loss

# Cell
def batch_all_triplet_loss(labels, embeddings, margin, squared=False):
    """Build the triplet loss over a batch of embeddings.
    We generate all the valid triplets and average the loss over the positive ones.
    Args:
        labels: labels of the batch, of size (batch_size,)
        embeddings: tensor of shape (batch_size, embed_dim)
        margin: margin for triplet loss
        squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
                 If false, output is the pairwise euclidean distance matrix.
    Returns:
        triplet_loss: scalar tensor containing the triplet loss
    """
    # Get the pairwise distance matrix
    pairwise_dist = _pairwise_distances(embeddings, squared=squared)

    anchor_positive_dist = pairwise_dist.unsqueeze(2)
    anchor_negative_dist = pairwise_dist.unsqueeze(1)

    # Compute a 3D tensor of size (batch_size, batch_size, batch_size)
    # triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k
    # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
    # and the 2nd (batch_size, 1, batch_size)
    triplet_loss = anchor_positive_dist - anchor_negative_dist + margin



    # Put to zero the invalid triplets
    # (where label(a) != label(p) or label(n) == label(a) or a == p)
    mask = _get_triplet_mask(labels)
    triplet_loss = mask.float() * triplet_loss

    # Remove negative losses (i.e. the easy triplets)
    triplet_loss = F.relu(triplet_loss)

    # Count number of positive triplets (where triplet_loss > 0)
    valid_triplets = triplet_loss[triplet_loss > 1e-16]
    num_positive_triplets = valid_triplets.size(0)
    num_valid_triplets = mask.sum()

    fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)

    # Get final mean triplet loss over the positive valid triplets
    triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)

    return triplet_loss, fraction_positive_triplets

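_pairwise_distances relies on the identity ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2 noted in its comments. A small numeric check of that step against torch.cdist, with a toy 3x2 embedding matrix (the cdist call is only for verification):

import torch

emb = torch.tensor([[1.0, 2.0],
                    [3.0, 4.0],
                    [0.0, 1.0]])

dot = emb @ emb.t()
sq = torch.diag(dot)
# ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2, exactly as in _pairwise_distances
d2 = sq.unsqueeze(0) - 2.0 * dot + sq.unsqueeze(1)

print(d2)
print(torch.cdist(emb, emb) ** 2)  # matches up to floating point error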
@@ -0,0 +1,5 @@
Top-1 accuracy: 0.6974169741697417
Top-3 accuracy: 0.8126281262812628
Top-5 accuracy: 0.8413284132841329
Top-10 accuracy: 0.8720787207872078
0.005117357 0.74772596

@@ -0,0 +1,5 @@
Top-1 accuracy: 0.8019680196801968
Top-3 accuracy: 0.8901189011890119
Top-5 accuracy: 0.9085690856908569
Top-10 accuracy: 0.9249692496924969
0.0 0.7323234

@@ -0,0 +1,5 @@
Top-1 accuracy: 0.8163181631816319
Top-3 accuracy: 0.8987289872898729
Top-5 accuracy: 0.9167691676916769
Top-10 accuracy: 0.9356293562935629
0.0 0.7410505

@@ -0,0 +1,5 @@
Top-1 accuracy: 0.7908979089790897
Top-3 accuracy: 0.8888888888888888
Top-5 accuracy: 0.914309143091431
Top-10 accuracy: 0.931119311193112
0.0 0.7351225