added cosine similarity code

Richard Wong 2025-01-18 12:14:06 +09:00
parent b6cf2d4416
commit 94ee7beba7
22 changed files with 978 additions and 57 deletions

View File

@ -23,7 +23,7 @@ for _, row in entity_df.iterrows():
train_df.sort_values(by=['entity_id']).to_markdown('out.md')
# %%
data_path = '../train/class_bert_augmentation/prediction/exports/result.csv'
data_path = '../esAppMod_train/class_bert_augmentation/prediction/exports/result.csv'
prediction_df = pd.read_csv(data_path)
predicted_entity_list = []

View File

@ -25,7 +25,6 @@ from transformers import (
import evaluate
import numpy as np
import pandas as pd
import math
from functools import partial
import warnings
@ -55,14 +54,14 @@ set_seed(42)
# %%
# PARAMETERS
SAMPLES=20
SHUFFLES=5
AMPLIFY_FACTOR=5
SHUFFLES=3
AMPLIFY_FACTOR=3
# %%
###################################################
# import code
# import training file
data_path = '../../esAppMod_data_import/train.csv'
data_path = '../../../biomedical_data_import/bc2gm_train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use pattern, we use the real thing and property
entity_ids = df['entity_id'].to_list()
@ -192,6 +191,8 @@ def augment_data(df):
return new_df
###############################################################
# regeneration code
# %%
@ -265,13 +266,20 @@ class DynamicDataset(Dataset):
# %%
class RegenerateDatasetCallback(TrainerCallback):
def __init__(self, dataset):
def __init__(self, dataset, every_n_epochs=2):
"""
Args:
dataset: The dataset instance that supports regeneration.
every_n_epochs (int): Number of epochs to wait before regenerating the dataset.
"""
self.dataset = dataset
self.every_n_epochs = every_n_epochs
def on_epoch_begin(self, args, state, control, **kwargs):
print(f"Epoch {int(math.ceil(state.epoch + 1))}: Regenerating dataset")
self.dataset.regenerate_data()
# Check if the current epoch is a multiple of `every_n_epochs`
if (state.epoch + 1) % self.every_n_epochs == 0:
print(f"Epoch {int(state.epoch + 1)}: Regenerating dataset...")
self.dataset.regenerate_data()
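# Illustrative note (a sketch, not part of the diff): assuming state.epoch holds the
# 0-based epoch index at on_epoch_begin, with every_n_epochs=2 the check
# (state.epoch + 1) % 2 == 0 fires on epochs 1, 3, 5, ... i.e. every second epoch:
# for epoch in range(6):
#     if (epoch + 1) % 2 == 0:
#         print(f"epoch {epoch}: regenerate")   # prints for 1, 3, 5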
# %%
@ -310,11 +318,11 @@ def train():
# Define the callback
lean_df = df.drop(columns=['entity_name'])
dynamic_dataset = DynamicDataset(df = lean_df, sample_size_per_class=10, tokenizer=tokenizer)
# lean_df = df.drop(columns=['entity_name'])
dynamic_dataset = DynamicDataset(df = df, sample_size_per_class=SAMPLES, tokenizer=tokenizer)
# create the regeneration callback
regeneration_callback = RegenerateDatasetCallback(dynamic_dataset)
regeneration_callback = RegenerateDatasetCallback(dynamic_dataset, every_n_epochs=2)
# compute metrics
metric = evaluate.load("accuracy")
@ -346,18 +354,17 @@ def train():
eval_strategy="no",
logging_dir="tensorboard-log",
logging_strategy="epoch",
save_strategy="steps",
save_steps=500,
# save_strategy="epoch",
load_best_model_at_end=False,
learning_rate=5e-5,
per_device_train_batch_size=64,
# per_device_eval_batch_size=64,
learning_rate=1e-4,
per_device_train_batch_size=256,
# per_device_eval_batch_size=256,
auto_find_batch_size=False,
ddp_find_unused_parameters=False,
weight_decay=0.01,
save_total_limit=1,
num_train_epochs=120,
warmup_steps=400,
num_train_epochs=80,
warmup_steps=200,
bf16=True,
push_to_hub=False,
remove_unused_columns=False,

View File

@ -1,6 +1,6 @@
*******************************************************************************
Accuracy: 0.80655
F1 Score: 0.82821
Precision: 0.87847
Recall: 0.80655
Accuracy: 0.77215
F1 Score: 0.79997
Precision: 0.87183
Recall: 0.77215

View File

@ -33,7 +33,7 @@ BATCH_SIZE = 32
# %%
# construct the target id list
data_path = '../../../biomedical_data_import/bc2gm_train.csv'
data_path = '../../../../biomedical_data_import/bc2gm_train.csv'
train_df = pd.read_csv(data_path, skipinitialspace=True)
entity_ids = train_df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))
@ -62,6 +62,13 @@ def preprocess_text(text):
return text
def is_int_string(s):
try:
int(s)
return True
except ValueError:
return False
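# Illustrative examples (sketch, not part of the diff): is_int_string guards the
# int() cast below for any entity_id that is not a plain integer string.
# is_int_string("42")   -> True
# is_int_string("abc")  -> False   # such rows are skipped in process_df_to_dict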
# outputs a list of dictionaries
@ -72,9 +79,12 @@ def preprocess_text(text):
def process_df_to_dict(df):
output_list = []
for _, row in df.iterrows():
row_id = row['entity_id']
if not is_int_string(row_id):
continue
row_id = int(row_id)
desc = row['mention']
desc = preprocess_text(desc)
row_id = row['entity_id']
element = {
'text' : desc,
'labels': label2id[row_id], # ensure labels starts from 0
@ -86,7 +96,7 @@ def process_df_to_dict(df):
def create_dataset():
# train
data_path = '../../../biomedical_data_import/bc2gm_test.csv'
data_path = '../../../../biomedical_data_import/bc2gm_test.csv'
test_df = pd.read_csv(data_path, skipinitialspace=True)

View File

@ -51,7 +51,7 @@ SHUFFLES=0 # 0 shuffles means it does not re-sample
# We want to map the entity_id to a consecutive set of id's
# import training file
data_path = '../../../biomedical_data_import/bc2gm_train.csv'
data_path = '../../biomedical_data_import/bc2gm_train.csv'
train_df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use pattern, we use the real thing and property
entity_ids = train_df['entity_id'].to_list()
@ -240,7 +240,7 @@ def process_df_to_dict(df):
def create_dataset():
# train
data_path = '../../../biomedical_data_import/bc2gm_train.csv'
data_path = '../../biomedical_data_import/bc2gm_train.csv'
train_df = pd.read_csv(data_path, skipinitialspace=True)
@ -266,6 +266,7 @@ def train():
# model_checkpoint = 'prajjwal1/bert-small'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, return_tensors="pt", clean_up_tokenization_spaces=True)
# max_length = 120
# given a dataset entry, run it through the tokenizer
def preprocess_function(example):

View File

@ -1,6 +1,6 @@
*******************************************************************************
Accuracy: 0.15093
F1 Score: 0.14063
Precision: 0.15594
Recall: 0.15093
Accuracy: 0.76047
F1 Score: 0.78441
Precision: 0.85810
Recall: 0.76047

View File

@ -54,9 +54,9 @@ set_seed(42)
# %%
# PARAMETERS
SAMPLES=20
SHUFFLES=5
AMPLIFY_FACTOR=5
SAMPLES=50
SHUFFLES=3
AMPLIFY_FACTOR=10
# %%
###################################################

View File

@ -1,6 +1,6 @@
*******************************************************************************
Accuracy: 0.76958
F1 Score: 0.79382
Precision: 0.88705
Recall: 0.76958
Accuracy: 0.77614
F1 Score: 0.80037
Precision: 0.89156
Recall: 0.77614

View File

@ -1,6 +1,6 @@
*******************************************************************************
Accuracy: 0.80689
F1 Score: 0.82527
Precision: 0.89684
Recall: 0.80689
Accuracy: 0.80033
F1 Score: 0.81484
Precision: 0.87456
Recall: 0.80033

View File

@ -78,7 +78,7 @@ def process_df_to_dict(df):
index = row['entity_id']
element = {
'text' : desc,
'label': label2id[index], # ensure labels starts from 0
'labels': label2id[index], # ensure labels starts from 0
}
output_list.append(element)
@ -144,9 +144,7 @@ def test():
# there is no need to create a separate 'labels'
model_inputs = tokenizer(
input,
max_length=max_length,
# truncation=True,
padding='max_length'
truncation=True,
)
return model_inputs
@ -160,7 +158,7 @@ def test():
)
datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
# datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
# %% temp
# tokenized_datasets['train'].rename_columns()
@ -168,7 +166,7 @@ def test():
# %%
# create data collator
# data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# %%
# compute metrics
@ -197,13 +195,13 @@ def test():
actual_labels = []
dataloader = DataLoader(datasets, batch_size=BATCH_SIZE, shuffle=False)
dataloader = DataLoader(datasets, batch_size=BATCH_SIZE, collate_fn=data_collator ,shuffle=False)
for batch in tqdm(dataloader):
# Inference in batches
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
# save labels too
actual_labels.extend(batch['label'])
actual_labels.extend(batch['labels'])
# Move to GPU if available

View File

@ -306,7 +306,7 @@ def corrupt_string(sentence, corruption_probability=0.01):
# each element maps input to output
# input: tag_description
# output: class label
label_flag_list = []
# label_flag_list = []
def process_df_to_dict(df):
output_list = []
@ -331,13 +331,14 @@ def process_df_to_dict(df):
for _ in range(10):
element = {
'text': parent_desc,
'label': label2id[index],
'labels': label2id[index],
}
output_list.append(element)
# check if label is in label_flag_list
if index not in label_flag_list:
# if index not in label_flag_list:
if False:
entity_name = row['entity_name']
# add the "entity_name" label as a mention
@ -452,7 +453,7 @@ def train():
model_checkpoint = "distilbert/distilbert-base-uncased"
# model_checkpoint = 'google-bert/bert-base-cased'
# model_checkpoint = 'prajjwal1/bert-small'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, return_tensors="pt", clean_up_tokenization_spaces=True)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, clean_up_tokenization_spaces=True)
# given a dataset entry, run it through the tokenizer
@ -475,6 +476,9 @@ def train():
remove_columns="text",
)
# tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
# %% temp
# tokenized_datasets['train'].rename_columns()
@ -525,7 +529,7 @@ def train():
per_device_eval_batch_size=64,
auto_find_batch_size=False,
ddp_find_unused_parameters=False,
weight_decay=0.01,
weight_decay=0.02,
save_total_limit=1,
num_train_epochs=40,
warmup_steps=400,
@ -538,7 +542,7 @@ def train():
trainer = Trainer(
model,
training_args,
train_dataset=tokenized_datasets["train"],
train_dataset=tokenized_datasets['train'],
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,

tackle_container/.gitignore (vendored, new file, 2 lines)
View File

@ -0,0 +1,2 @@
__pycache__
checkpoint

View File

@ -0,0 +1,218 @@
import os, random
from collections import defaultdict
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp
from transformers import AutoTokenizer, AutoModel
from data import generate_train_entity_sets
from tqdm import tqdm ### requires ipywidgets==7.7.1; the newest version doesn't work
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import logging
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup():
dist.destroy_process_group()
def train_dataloader(vocab_entity_id_mentions, num_sample_per_class, rank, world_size, batch_size=32, pin_memory=True, num_workers=8):
dataset = generate_train_entity_sets(vocab_entity_id_mentions, entity_id_name=None, group_size=num_sample_per_class, anchor=False)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
return DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers, drop_last=False, shuffle=False, sampler=sampler)
def test_dataloader(test_mentions, batch_size=32):
return DataLoader(test_mentions, batch_size=batch_size, shuffle=False)
def train(rank, epoch, epochs, train_dataloader, model, optimizer, tokenizer, margin):
# DEVICE = torch.device(f"cuda:{dist.get_rank()}")
DEVICE = torch.device(f'cuda:{rank}')
model.train()
epoch_loss, epoch_len = [epoch], [epoch]
for groups in tqdm(train_dataloader, desc =f'Training batches on {DEVICE}'):
groups[0][:] = zip(*groups[0][::-1])
x, y = [], []
for mention, label in zip(groups[0], groups[1]):
mention = [m for m in mention if m != 'PAD']
x.extend(mention)
y.extend([label.item()]*len(mention))
optimizer.zero_grad()
inputs = tokenizer(x, padding=True, return_tensors='pt')
inputs = inputs.to(DEVICE)
cls = model(inputs)
# cls = torch.nn.functional.normalize(cls) ## normalize cls embedding before computing loss, didn't work
# cls = torch.nn.Dropout(p= 0.25)(cls) ## add dropout, didn't work
# loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=True)
# loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=True)
if epoch < epochs / 2:
# if epoch // (epochs / 4) % 2 == 0: ## various ways of alternating batch all and batch hard, no obvious advantage
# if (epoch // 10) % 2 == 0: ## various ways of alternating batch all and batch hard, no obvious advantage
loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
else:
loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
#### tried circle loss, no obvious advantage
loss.backward()
optimizer.step()
# logging.info(f'{epoch} {len(x)} {loss.item()}')
epoch_loss.append(loss.item())
epoch_len.append(len(x))
# del inputs, cls, loss
# torch.cuda.empty_cache()
logging.info(f'{DEVICE}{epoch_len}')
logging.info(f'{DEVICE}{epoch_loss}')
def check_label(predicted_cui: str, golden_cui: str) -> int:
"""
Some composite annotations do not consider order,
so return 1 if any CUI in the predicted (composite or single) CUI matches a CUI in the golden CUI;
otherwise return 0.
"""
return int(len(set(predicted_cui.replace('+', '|').split("|")).intersection(set(golden_cui.replace('+', '|').split("|"))))>0)
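# Worked examples with hypothetical CUIs (sketch, not part of the diff):
# check_label('C001|C002', 'C002')  -> 1   # shared CUI, order/position ignored
# check_label('C001+C003', 'C002')  -> 0   # no CUI in common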
def getEmbeddings(mentions, model, tokenizer, DEVICE, batch_size=200):
model.to(DEVICE)
model.eval()
dataloader = DataLoader(mentions, batch_size, shuffle=False)
embeddings = np.empty((0, 768), np.float32)
with torch.no_grad():
for mentions in tqdm(dataloader, desc ='Getting embeddings'):
inputs = tokenizer(mentions, padding=True, return_tensors='pt')
inputs = inputs.to(DEVICE)
cls = model(inputs)
embeddings = np.append(embeddings, cls.detach().cpu().numpy(), axis=0)
# del inputs, cls
# torch.cuda.empty_cache()
return embeddings
def eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, model, tokenizer):
DEVICE = torch.device(f'cuda:{rank}')
vocab_embeddings = getEmbeddings(vocab_mentions, model, tokenizer, DEVICE)
test_embeddings = getEmbeddings(test_mentions, model, tokenizer, DEVICE)
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(vocab_embeddings, vocab_ids)
n_neighbors = [1, 3, 5, 10]
res = []
for n in n_neighbors:
distances, indices = knn.kneighbors(test_embeddings, n_neighbors=n)
num = 0
for gold_cui, idx in zip(test_cuis, indices):
candidates = [id_to_cui[vocab_ids[i]] for i in idx]
for c in candidates:
if check_label(c, gold_cui):
num += 1
break
res.append(num / len(test_cuis))
# print(f'Top-{n:<3} accuracy: {num / len(test_cuis)}')
return res
# print(np.min(distances), np.max(distances))
def save_checkpoint(model, res, epoch, dataName):
logging.info(f'Saving model {epoch} {res} ')
torch.save(model.state_dict(), './checkpoints/'+dataName+'.pt')
class Model(nn.Module):
def __init__(self,MODEL_NAME):
super(Model, self).__init__()
self.model = AutoModel.from_pretrained(MODEL_NAME)
def forward(self, inputs):
outputs = self.model(**inputs)
cls = outputs.last_hidden_state[:,0,:]
return cls
def main(rank, world_size, config):
print(f"Running main(**args) on rank {rank}.")
setup(rank, world_size)
dataName = config['DEFAULT']['dataName']
logging.basicConfig(format='%(asctime)s %(message)s', filename=config['train']['ckt_path']+dataName+'.log', filemode='a', level=logging.INFO)
vocab = defaultdict(set)
with open('./data/biomedical/'+dataName+'/'+config['train']['dictionary']) as f:
for line in f:
vocab[line.strip().split('||')[0]].add(line.strip().split('||')[1].lower())
cui_to_id, id_to_cui = {}, {}
vocab_entity_id_mentions = {}
for id, cui in enumerate(vocab):
cui_to_id[cui] = id
id_to_cui[id] = cui
for cui, mention in vocab.items():
vocab_entity_id_mentions[cui_to_id[cui]] = mention
vocab_mentions, vocab_ids = [], []
for id, mentions in vocab_entity_id_mentions.items():
vocab_mentions.extend(mentions)
vocab_ids.extend([id]*len(mentions))
test_mentions, test_cuis = [], []
with open('./data/biomedical/'+dataName+'/'+config['train']['test_set']+'/0.concept') as f:
for line in f:
test_cuis.append(line.strip().split('||')[-1])
test_mentions.append(line.strip().split('||')[-2].lower())
num_sample_per_class = int(config['data']['group_size']) # samples in each group
batch_size = int(config['train']['batch_size']) # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
margin = int(config['model']['margin'])
epochs = int(config['train']['epochs'])
lr = float(config['train']['lr'])
MODEL_NAME = config['model']['model_name']
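# Worked example with illustrative values (the real ones come from config.ini):
# group_size=10 and batch_size=32 give a triplet-loss batch of up to
# 32 * 10 = 320 mentions; 'PAD' fillers are removed again inside train().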
trainDataLoader = train_dataloader(vocab_entity_id_mentions, num_sample_per_class, rank, world_size, batch_size, pin_memory=False, num_workers=0)
# test_dataloader = test_dataloader(test_mentions, batch_size=200)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = Model(MODEL_NAME).to(rank)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
ddp_model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)
best = 0
if rank == 0:
logging.info(f'epochs:{epochs} group_size:{num_sample_per_class} batch_size:{batch_size} %num:1 device:{torch.cuda.get_device_name()} count:{torch.cuda.device_count()} base:{MODEL_NAME}' )
for epoch in tqdm(range(epochs)):
trainDataLoader.sampler.set_epoch(epoch)
train(rank, epoch, epochs, trainDataLoader, ddp_model, optimizer, tokenizer, margin)
# if rank == 0 and epoch % 2 == 0:
if rank == 0:
res = eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, ddp_model.module, tokenizer)
logging.info(f'{epoch} {res}')
if res[0] > best:
best = res[0]
save_checkpoint(ddp_model.module, res, epoch, dataName)
dist.barrier()
cleanup()
if __name__ == '__main__':
import configparser
config = configparser.ConfigParser()
config.read('config.ini')
world_size = torch.cuda.device_count()
print(f"You have {world_size} GPUs.")
mp.spawn(
main,
args=(world_size, config),
nprocs=world_size,
join=True
)

tackle_container/data.py (new file, 35 lines)
View File

@ -0,0 +1,35 @@
import random
def generate_train_entity_sets(entity_id_mentions, entity_id_name=None, group_size=10, anchor=False):
# split entity mentions into groups
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
entity_sets = []
if anchor:
for id, mentions in entity_id_mentions.items():
mentions = list(mentions)
random.shuffle(mentions)
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
entity_sets.extend(anchor_positive)
else:
for id, mentions in entity_id_mentions.items():
if entity_id_name:
group = list(set([entity_id_name[id]] + mentions))
else:
group = list(mentions)
if len(group) == 1:
group.append(group[0])
group.extend((group_size-len(group))%group_size * ['PAD'])
random.shuffle(group)
positives = [(group[i:i + group_size], id) for i in range(0, len(group), group_size)]
entity_sets.extend(positives)
return entity_sets
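# Worked example (sketch): with group_size=10, an entity with 17 mentions is
# padded with (10 - 17) % 10 = 3 'PAD' strings to 20 items, i.e. two groups of
# 10; the 'PAD' entries are filtered back out by the training loop.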
def batchGenerator(data, batch_size):
for i in range(0, len(data), batch_size):
batch = data[i:i+batch_size]
x, y = [], []
for t in batch:
t[0] = [e for e in t[0] if e != 'PAD']
x.extend(t[0])
y.extend([t[1]]*len(t[0]))
yield x, y

View File

@ -0,0 +1,106 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import gc
# %%
# Step 2: Load the state dictionary
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
# state_dict = torch.load('./checkpoint/siamese.pt')
state_dict = torch.load('./checkpoint/siamese_simple.pt')
# Step 3: Apply the state dictionary to the model
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()
# %%
with open('../esAppMod/tca_entities.json', 'r') as file:
entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
with open('../esAppMod/train.json', 'r') as file:
train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
# %%
with open('../esAppMod/infer.json', 'r') as file:
test = json.load(file)
x_test = [d['mention'] for _, d in test['data'].items()]
y_test = [d['entity_id'] for _, d in test['data'].items()]
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
def batch_list(data, batch_size):
"""Yield successive n-sized chunks from data."""
for i in range(0, len(data), batch_size):
yield data[i:i + batch_size]
batches = batch_list(train_entities, 64)
embedding_list = []
for batch in batches:
inputs = tokenizer(batch, padding=True, return_tensors='pt')
outputs = model(
input_ids=inputs['input_ids'].to(DEVICE),
attention_mask=inputs['attention_mask'].to(DEVICE)
)
output = outputs.last_hidden_state[:,0,:]
output = output.detach().cpu().numpy()
embedding_list.append(output)
cls = np.concatenate(embedding_list)
# %%
gc.collect()
torch.cuda.empty_cache()
# %%
batches = batch_list(x_test, 64)
embedding_list = []
for batch in batches:
inputs = tokenizer(batch, padding=True, return_tensors='pt')
outputs = model(
input_ids=inputs['input_ids'].to(DEVICE),
attention_mask=inputs['attention_mask'].to(DEVICE)
)
output = outputs.last_hidden_state[:,0,:]
output = output.detach().cpu().numpy()
embedding_list.append(output)
cls_test = np.concatenate(embedding_list)
# %%
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
n_neighbors = [1, 3, 5, 10]
with open("results/output.txt", "w") as f:
for n in n_neighbors:
distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
num = 0
for a,b in zip(y_test, indices):
b = [labels[i] for i in b]
if a in b:
num += 1
print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
print(np.min(distances), np.max(distances), file=f)
# %%

View File

@ -0,0 +1,92 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
# split entity mentions into groups
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
entity_sets = []
if anchor:
for id, mentions in entity_id_mentions.items():
random.shuffle(mentions)
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
entity_sets.extend(anchor_positive)
else:
for id, mentions in entity_id_mentions.items():
group = list(set([entity_id_name[id]] + mentions))
random.shuffle(group)
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
entity_sets.extend(positives)
return entity_sets
def batchGenerator(data, batch_size):
for i in range(0, len(data), batch_size):
batch = data[i:i+batch_size]
x, y = [], []
for t in batch:
x.extend(t[0])
y.extend([t[1]]*len(t[0]))
yield x, y
with open('../esAppMod/tca_entities.json', 'r') as file:
entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
with open('../esAppMod/train.json', 'r') as file:
train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
num_sample_per_class = 10 # samples in each group
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model.to(DEVICE)
model.train()
losses = []
for epoch in tqdm(range(epochs)):
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
random.shuffle(data)
for x, y in batchGenerator(data, batch_size):
# print(len(x), len(y), end='-->')
optimizer.zero_grad()
inputs = tokenizer(x, padding=True, return_tensors='pt')
inputs = inputs.to(DEVICE)
outputs = model(**inputs)
cls = outputs.last_hidden_state[:,0,:]
# for the first half of training, use batch-all triplet loss (the easier regime)
if epoch < epochs / 2:
loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
# for the second half of training, switch to batch-hard triplet loss
else:
loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
loss.backward()
optimizer.step()
# print(epoch, loss)
losses.append(loss)
del inputs, outputs, cls, loss
torch.cuda.empty_cache()
torch.save(model.state_dict(), './checkpoint/siamese_simple.pt')

View File

@ -0,0 +1,242 @@
# %%
import torch
import json
import random
import numpy as np
import pandas as pd
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
# parallel utilities
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import os
# %%
with open('../esAppMod/tca_entities.json', 'r') as file:
entities = json.load(file)
# produces a dictionary map from entity_id to entity_name
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
with open('../esAppMod/train.json', 'r') as file:
train = json.load(file)
# map from entity_id to list of mentions
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
# map from entity_id to list of entity_names
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
# split entity mentions into groups
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
entity_sets = []
if anchor:
for id, mentions in entity_id_mentions.items():
random.shuffle(mentions)
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
entity_sets.extend(anchor_positive)
else:
for id, mentions in entity_id_mentions.items():
group = list(set([entity_id_name[id]] + mentions))
random.shuffle(group)
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
entity_sets.extend(positives)
# entity sets are always ([list of mentions], id)
# to convert it to dataset form, we will just use a dataframe
id_mention_pairs = []
for entity in entity_sets:
entity_id = entity[1]
for mention in entity[0]:
id_mention_pairs.append({
'entity_id': entity_id,
'mention': mention
})
df = pd.DataFrame(id_mention_pairs)
return df
# def batchGenerator(data, batch_size):
# for i in range(0, len(data), batch_size):
# batch = data[i:i+batch_size]
# x, y = [], []
# for t in batch:
# x.extend(t[0])
# y.extend([t[1]]*len(t[0]))
# yield x, y
class CustomDataset(Dataset):
def __init__(self, df):
self.data = df # data should be preprocessed if necessary before being passed here
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
# Return the data and label as tuples
entry = self.data.iloc[idx]
x = entry['mention']
y = entry['entity_id']
return x,y
# %%
num_sample_per_class = 10 # samples in each group
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' #'bert-base-cased' 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# model.to(DEVICE)
# model.train()
#
# losses = []
# %%
# for epoch in tqdm(range(epochs)):
# data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
# dataset = CustomDataset(data)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# for x, y in dataloader:
# # print(len(x), len(y), end='-->')
# optimizer.zero_grad()
# inputs = tokenizer(x, padding=True, return_tensors='pt')
# inputs = inputs.to(DEVICE)
# outputs = model(**inputs)
# cls = outputs.last_hidden_state[:,0,:]
# # for training less than half the time, train on easy
# if epoch < epochs / 2:
# loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
# # for training after half the time, train on hard
# else:
# loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False)
# loss.backward()
# optimizer.step()
# # print(epoch, loss)
# losses.append(loss)
# del inputs, outputs, cls, loss
# torch.cuda.empty_cache()
#
# torch.save(model.state_dict(), './checkpoint/siamese.pt')
# %%
def save_checkpoint(model, optimizer, epoch, path, rank):
if rank == 0: # Only save on the master process
# Save only the underlying model's state_dict, not the DDP wrapper
torch.save({
'model_state_dict': model.module.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch
}, path)
# %%
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def reduce_mean(tensor, nprocs):
"""
Reduce and average a tensor across all processes via all_reduce;
the resulting tensor is identical on every process.
"""
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= nprocs
return rt
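# Worked example (sketch): with world_size=2 and per-rank losses 1.0 and 3.0,
# all_reduce sums them to 4.0 on both ranks and the division yields 2.0 everywhere.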
def train(rank, world_size):
setup(rank, world_size)
# Setup model, DataLoader with DistributedSampler
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.cuda(rank)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# initialize progress bar
# Initialize tqdm on the master process
if torch.distributed.get_rank() == 0: # Only print from the master process
pbar = tqdm(total=epochs, desc='batch progress')
for epoch in range(epochs):
total_loss = 0.0
num_batches = 0
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
train_dataset = CustomDataset(data)
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=train_sampler)
device = torch.device(f"cuda:{rank}")
train_sampler.set_epoch(epoch)
for data, targets in train_loader:
# data, targets = data.cuda(rank), targets.cuda(rank)
optimizer.zero_grad()
inputs = tokenizer(data, padding=True, return_tensors='pt')
inputs = inputs.to(device)
outputs = model(**inputs)
cls = outputs.last_hidden_state[:,0,:]
# for the first half of training, use batch-all triplet loss (the easier regime)
if epoch < epochs / 2:
loss, _ = batch_all_triplet_loss(targets.to(device), cls, margin, squared=False)
# for the second half of training, switch to batch-hard triplet loss
else:
loss = batch_hard_triplet_loss(targets.to(device), cls, margin, squared=False)
loss.backward()
optimizer.step()
# Reduce and average the loss across all processes
reduced_loss = reduce_mean(loss, world_size)
total_loss += reduced_loss.item()
num_batches += 1
# print(epoch, loss)
# losses.append(loss)
del inputs, outputs, cls, loss
torch.cuda.empty_cache()
dist.barrier()
# Close tqdm bar on master process
if torch.distributed.get_rank() == 0: # Only print from the master process
pbar.update(1)
epoch_loss = total_loss / num_batches
tqdm.write(f'loss: {epoch_loss}')
if torch.distributed.get_rank() == 0: # Only print from the master process
pbar.close()
path = './checkpoint/siamese.pt'
torch.save(model.module.state_dict(), path)
cleanup()
if __name__ == '__main__':
# Set the number of processes to the number of GPUs available
world_size = torch.cuda.device_count()
# Use torch.multiprocessing.spawn to launch the processes
torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)

tackle_container/loss.py (new file, 186 lines)
View File

@ -0,0 +1,186 @@
# standard functions for computing triplet loss, code borrowed from
# https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py
import torch
import torch.nn.functional as F
def _pairwise_distances(embeddings, squared=False):
"""Compute the 2D matrix of distances between all the embeddings.
Args:
embeddings: tensor of shape (batch_size, embed_dim)
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
If false, output is the pairwise euclidean distance matrix.
Returns:
pairwise_distances: tensor of shape (batch_size, batch_size)
"""
dot_product = torch.matmul(embeddings, embeddings.t())
# Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
# This also provides more numerical stability (the diagonal of the result will be exactly 0).
# shape (batch_size,)
square_norm = torch.diag(dot_product)
# Compute the pairwise distance matrix as we have:
# ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
# shape (batch_size, batch_size)
distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)
# Because of computation errors, some distances might be negative so we put everything >= 0.0
distances[distances < 0] = 0
if not squared:
# Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
# we need to add a small epsilon where distances == 0.0
mask = distances.eq(0).float()
distances = distances + mask * 1e-16
distances = (1.0 -mask) * torch.sqrt(distances)
return distances
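# Optional sanity check (a sketch, not part of the borrowed file): the norm
# expansion above should match torch.cdist for plain Euclidean distances.
# emb = torch.randn(8, 16)
# assert torch.allclose(_pairwise_distances(emb, squared=False),
#                       torch.cdist(emb, emb), atol=1e-4)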
def _get_triplet_mask(labels):
"""Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
A triplet (i, j, k) is valid if:
- i, j, k are distinct
- labels[i] == labels[j] and labels[i] != labels[k]
Args:
labels: tf.int32 `Tensor` with shape [batch_size]
"""
# Check that i, j and k are distinct
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
indices_not_equal = ~indices_equal
i_not_equal_j = indices_not_equal.unsqueeze(2)
i_not_equal_k = indices_not_equal.unsqueeze(1)
j_not_equal_k = indices_not_equal.unsqueeze(0)
distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k
label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
i_equal_j = label_equal.unsqueeze(2)
i_equal_k = label_equal.unsqueeze(1)
valid_labels = ~i_equal_k & i_equal_j
return valid_labels & distinct_indices
def _get_anchor_positive_triplet_mask(labels):
"""Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label.
Args:
labels: tf.int32 `Tensor` with shape [batch_size]
Returns:
mask: tf.bool `Tensor` with shape [batch_size, batch_size]
"""
# Check that i and j are distinct
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
indices_not_equal = ~indices_equal
# Check if labels[i] == labels[j]
# Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
return labels_equal & indices_not_equal
def _get_anchor_negative_triplet_mask(labels):
"""Return a 2D mask where mask[a, n] is True iff a and n have distinct labels.
Args:
labels: tf.int32 `Tensor` with shape [batch_size]
Returns:
mask: tf.bool `Tensor` with shape [batch_size, batch_size]
"""
# Check if labels[i] != labels[k]
# Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
return ~(labels.unsqueeze(0) == labels.unsqueeze(1))
# Cell
def batch_hard_triplet_loss(labels, embeddings, margin, squared=False):
"""Build the triplet loss over a batch of embeddings.
For each anchor, we get the hardest positive and hardest negative to form a triplet.
Args:
labels: labels of the batch, of size (batch_size,)
embeddings: tensor of shape (batch_size, embed_dim)
margin: margin for triplet loss
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
If false, output is the pairwise euclidean distance matrix.
Returns:
triplet_loss: scalar tensor containing the triplet loss
"""
# Get the pairwise distance matrix
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
# For each anchor, get the hardest positive
# First, we need to get a mask for every valid positive (they should have same label)
mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()
# We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
anchor_positive_dist = mask_anchor_positive * pairwise_dist
# shape (batch_size, 1)
hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)
# For each anchor, get the hardest negative
# First, we need to get a mask for every valid negative (they should have different labels)
mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()
# We add the maximum value in each row to the invalid negatives (label(a) == label(n))
max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
# shape (batch_size,)
hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)
# Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
tl = hardest_positive_dist - hardest_negative_dist + margin
tl = F.relu(tl)
triplet_loss = tl.mean()
return triplet_loss
# Cell
def batch_all_triplet_loss(labels, embeddings, margin, squared=False):
"""Build the triplet loss over a batch of embeddings.
We generate all the valid triplets and average the loss over the positive ones.
Args:
labels: labels of the batch, of size (batch_size,)
embeddings: tensor of shape (batch_size, embed_dim)
margin: margin for triplet loss
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
If false, output is the pairwise euclidean distance matrix.
Returns:
triplet_loss: scalar tensor containing the triplet loss
"""
# Get the pairwise distance matrix
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
anchor_positive_dist = pairwise_dist.unsqueeze(2)
anchor_negative_dist = pairwise_dist.unsqueeze(1)
# Compute a 3D tensor of size (batch_size, batch_size, batch_size)
# triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k
# Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
# and the 2nd (batch_size, 1, batch_size)
triplet_loss = anchor_positive_dist - anchor_negative_dist + margin
# Put to zero the invalid triplets
# (where label(a) != label(p) or label(n) == label(a) or a == p)
mask = _get_triplet_mask(labels)
triplet_loss = mask.float() * triplet_loss
# Remove negative losses (i.e. the easy triplets)
triplet_loss = F.relu(triplet_loss)
# Count number of positive triplets (where triplet_loss > 0)
valid_triplets = triplet_loss[triplet_loss > 1e-16]
num_positive_triplets = valid_triplets.size(0)
num_valid_triplets = mask.sum()
fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)
# Get final mean triplet loss over the positive valid triplets
triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
return triplet_loss, fraction_positive_triplets
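# Minimal usage sketch (illustrative shapes, not part of the borrowed file):
# labels = torch.tensor([0, 0, 1, 1, 2, 2])
# embeddings = torch.randn(6, 768)
# loss_all, frac_pos = batch_all_triplet_loss(labels, embeddings, margin=1.0)
# loss_hard = batch_hard_triplet_loss(labels, embeddings, margin=1.0)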

View File

@ -0,0 +1,5 @@
Top-1 accuracy: 0.6974169741697417
Top-3 accuracy: 0.8126281262812628
Top-5 accuracy: 0.8413284132841329
Top-10 accuracy: 0.8720787207872078
0.005117357 0.74772596

View File

@ -0,0 +1,5 @@
Top-1 accuracy: 0.8019680196801968
Top-3 accuracy: 0.8901189011890119
Top-5 accuracy: 0.9085690856908569
Top-10 accuracy: 0.9249692496924969
0.0 0.7323234

View File

@ -0,0 +1,5 @@
Top-1 accuracy: 0.8163181631816319
Top-3 accuracy: 0.8987289872898729
Top-5 accuracy: 0.9167691676916769
Top-10 accuracy: 0.9356293562935629
0.0 0.7410505

View File

@ -0,0 +1,5 @@
Top-1 accuracy: 0.7908979089790897
Top-3 accuracy: 0.8888888888888888
Top-5 accuracy: 0.914309143091431
Top-10 accuracy: 0.931119311193112
0.0 0.7351225