diff --git a/analysis/error_analysis_esAppMod.py b/analysis/error_analysis_esAppMod.py
index 0df5bb4..f445df8 100644
--- a/analysis/error_analysis_esAppMod.py
+++ b/analysis/error_analysis_esAppMod.py
@@ -23,7 +23,7 @@ for _, row in entity_df.iterrows():
 train_df.sort_values(by=['entity_id']).to_markdown('out.md')
 # %%
-data_path = '../train/class_bert_augmentation/prediction/exports/result.csv'
+data_path = '../esAppMod_train/class_bert_augmentation/prediction/exports/result.csv'
 prediction_df = pd.read_csv(data_path)
 predicted_entity_list = []
diff --git a/biomedical_train/bc2gm/augmentation/dynamic_train.py b/biomedical_train/bc2gm/augmentation/dynamic_train.py
index b5d5b24..ec2bb89 100644
--- a/biomedical_train/bc2gm/augmentation/dynamic_train.py
+++ b/biomedical_train/bc2gm/augmentation/dynamic_train.py
@@ -25,7 +25,6 @@ from transformers import (
 import evaluate
 import numpy as np
 import pandas as pd
-import math
 from functools import partial
 import warnings
@@ -55,14 +54,14 @@ set_seed(42)
 # %%
 # PARAMETERS
 SAMPLES=20
-SHUFFLES=5
-AMPLIFY_FACTOR=5
+SHUFFLES=3
+AMPLIFY_FACTOR=3
 # %%
 ###################################################
 # import code
 # import training file
-data_path = '../../esAppMod_data_import/train.csv'
+data_path = '../../../biomedical_data_import/bc2gm_train.csv'
 df = pd.read_csv(data_path, skipinitialspace=True)
 # rather than use pattern, we use the real thing and property
 entity_ids = df['entity_id'].to_list()
@@ -192,6 +191,8 @@ def augment_data(df):
     return new_df
+
+
 ###############################################################
 # regeneration code
 # %%
@@ -265,13 +266,20 @@ class DynamicDataset(Dataset):
 # %%
 class RegenerateDatasetCallback(TrainerCallback):
-    def __init__(self, dataset):
+    def __init__(self, dataset, every_n_epochs=2):
+        """
+        Args:
+            dataset: The dataset instance that supports regeneration.
+            every_n_epochs (int): Number of epochs to wait before regenerating the dataset.
+        """
         self.dataset = dataset
+        self.every_n_epochs = every_n_epochs
     def on_epoch_begin(self, args, state, control, **kwargs):
-        print(f"Epoch {int(math.ceil(state.epoch + 1))}: Regenerating dataset")
-        self.dataset.regenerate_data()
-
+        # Check if the current epoch is a multiple of `every_n_epochs`
+        if (state.epoch + 1) % self.every_n_epochs == 0:
+            print(f"Epoch {int(state.epoch + 1)}: Regenerating dataset...")
+            self.dataset.regenerate_data()
 # %%
@@ -310,11 +318,11 @@ def train():
     # Define the callback
-    lean_df = df.drop(columns=['entity_name'])
-    dynamic_dataset = DynamicDataset(df = lean_df, sample_size_per_class=10, tokenizer=tokenizer)
+    # lean_df = df.drop(columns=['entity_name'])
+    dynamic_dataset = DynamicDataset(df = df, sample_size_per_class=SAMPLES, tokenizer=tokenizer)
     # create the regeneration callback
-    regeneration_callback = RegenerateDatasetCallback(dynamic_dataset)
+    regeneration_callback = RegenerateDatasetCallback(dynamic_dataset, every_n_epochs=2)
     # compute metrics
     metric = evaluate.load("accuracy")
@@ -346,18 +354,17 @@ def train():
         eval_strategy="no",
         logging_dir="tensorboard-log",
         logging_strategy="epoch",
-        save_strategy="steps",
-        save_steps=500,
+        # save_strategy="epoch",
         load_best_model_at_end=False,
-        learning_rate=5e-5,
-        per_device_train_batch_size=64,
-        # per_device_eval_batch_size=64,
+        learning_rate=1e-4,
+        per_device_train_batch_size=256,
+        # per_device_eval_batch_size=256,
         auto_find_batch_size=False,
         ddp_find_unused_parameters=False,
         weight_decay=0.01,
         save_total_limit=1,
-        num_train_epochs=120,
-        warmup_steps=400,
+        num_train_epochs=80,
+        warmup_steps=200,
         bf16=True,
         push_to_hub=False,
         remove_unused_columns=False,
diff --git a/biomedical_train/bc2gm/augmentation/prediction/output.txt b/biomedical_train/bc2gm/augmentation/prediction/output.txt
index 93e7dd1..3dd4296 100644
--- a/biomedical_train/bc2gm/augmentation/prediction/output.txt
+++ b/biomedical_train/bc2gm/augmentation/prediction/output.txt
@@ -1,6 +1,6 @@
 *******************************************************************************
-Accuracy: 0.80655
-F1 Score: 0.82821
-Precision: 0.87847
-Recall: 0.80655
+Accuracy: 0.77215
+F1 Score: 0.79997
+Precision: 0.87183
+Recall: 0.77215
diff --git a/biomedical_train/bc2gm/augmentation/prediction/predict.py b/biomedical_train/bc2gm/augmentation/prediction/predict.py
index a7e1b62..19cdb35 100644
--- a/biomedical_train/bc2gm/augmentation/prediction/predict.py
+++ b/biomedical_train/bc2gm/augmentation/prediction/predict.py
@@ -33,7 +33,7 @@ BATCH_SIZE = 32
 # %%
 # construct the target id list
-data_path = '../../../biomedical_data_import/bc2gm_train.csv'
+data_path = '../../../../biomedical_data_import/bc2gm_train.csv'
 train_df = pd.read_csv(data_path, skipinitialspace=True)
 entity_ids = train_df['entity_id'].to_list()
 target_id_list = sorted(list(set(entity_ids)))
@@ -62,6 +62,13 @@ def preprocess_text(text):
     return text
+def is_int_string(s):
+    try:
+        int(s)
+        return True
+    except ValueError:
+        return False
+
 # outputs a list of dictionaries
@@ -72,9 +79,12 @@ def process_df_to_dict(df):
     output_list = []
     for _, row in df.iterrows():
+        row_id = row['entity_id']
+        if not is_int_string(row_id):
+            continue
+        row_id = int(row_id)
         desc = row['mention']
         desc = preprocess_text(desc)
-        row_id = row['entity_id']
         element = {
             'text' : desc,
             'labels': label2id[row_id], # ensure labels starts from 0
         }
         output_list.append(element)
@@ -86,7 +96,7 @@ def create_dataset():
     # train
-    data_path = '../../../biomedical_data_import/bc2gm_test.csv'
+    data_path = '../../../../biomedical_data_import/bc2gm_test.csv'
     test_df = pd.read_csv(data_path, skipinitialspace=True)
diff --git a/biomedical_train/bc2gm/augmentation/train.py b/biomedical_train/bc2gm/augmentation/train.py
index 1e413f5..614de64 100644
--- a/biomedical_train/bc2gm/augmentation/train.py
+++ b/biomedical_train/bc2gm/augmentation/train.py
@@ -51,7 +51,7 @@ SHUFFLES=0 # 0 shuffles means it does not re-sample
 # We want to map the entity_id to a consecutive set of id's
 # import training file
-data_path = '../../../biomedical_data_import/bc2gm_train.csv'
+data_path = '../../biomedical_data_import/bc2gm_train.csv'
 train_df = pd.read_csv(data_path, skipinitialspace=True)
 # rather than use pattern, we use the real thing and property
 entity_ids = train_df['entity_id'].to_list()
@@ -240,7 +240,7 @@ def process_df_to_dict(df):
 def create_dataset():
     # train
-    data_path = '../../../biomedical_data_import/bc2gm_train.csv'
+    data_path = '../../biomedical_data_import/bc2gm_train.csv'
     train_df = pd.read_csv(data_path, skipinitialspace=True)
@@ -266,6 +266,7 @@ def train():
     # model_checkpoint = 'prajjwal1/bert-small'
     tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, return_tensors="pt", clean_up_tokenization_spaces=True)
+    # max_length = 120
     # given a dataset entry, run it through the tokenizer
     def preprocess_function(example):
diff --git a/biomedical_train/bc2gm/simple/prediction/output.txt b/biomedical_train/bc2gm/simple/prediction/output.txt
index 7811f07..e684dc3 100644
--- a/biomedical_train/bc2gm/simple/prediction/output.txt
+++ b/biomedical_train/bc2gm/simple/prediction/output.txt
@@ -1,6 +1,6 @@
 *******************************************************************************
-Accuracy: 0.15093
-F1 Score: 0.14063
-Precision: 0.15594
-Recall: 0.15093
+Accuracy: 0.76047
+F1 Score: 0.78441
+Precision: 0.85810
+Recall: 0.76047
diff --git a/esAppMod_train/augmentation/dynamic_train.py b/esAppMod_train/augmentation/dynamic_train.py
index a75d98c..9d21be0 100644
--- a/esAppMod_train/augmentation/dynamic_train.py
+++ b/esAppMod_train/augmentation/dynamic_train.py
@@ -54,9 +54,9 @@ set_seed(42)
 # %%
 # PARAMETERS
-SAMPLES=20
-SHUFFLES=5
-AMPLIFY_FACTOR=5
+SAMPLES=50
+SHUFFLES=3
+AMPLIFY_FACTOR=10
 # %%
 ###################################################
diff --git a/esAppMod_train/augmentation/prediction/output.txt b/esAppMod_train/augmentation/prediction/output.txt
index 5b098e0..759db7e 100644
--- a/esAppMod_train/augmentation/prediction/output.txt
+++ b/esAppMod_train/augmentation/prediction/output.txt
@@ -1,6 +1,6 @@
 *******************************************************************************
-Accuracy: 0.76958
-F1 Score: 0.79382
-Precision: 0.88705
-Recall: 0.76958
+Accuracy: 0.77614
+F1 Score: 0.80037
+Precision: 0.89156
+Recall: 0.77614
diff --git a/esAppMod_train/class_bert_augmentation/prediction/output.txt b/esAppMod_train/class_bert_augmentation/prediction/output.txt
index d13147d..49e225a 100644
--- a/esAppMod_train/class_bert_augmentation/prediction/output.txt
+++ b/esAppMod_train/class_bert_augmentation/prediction/output.txt
@@ -1,6 +1,6 @@
 *******************************************************************************
-Accuracy: 0.80689
-F1 Score: 0.82527
-Precision: 0.89684
-Recall: 0.80689
+Accuracy: 0.80033
+F1 Score: 0.81484
+Precision: 0.87456
+Recall: 0.80033
diff --git a/esAppMod_train/class_bert_augmentation/prediction/predict.py b/esAppMod_train/class_bert_augmentation/prediction/predict.py
index 12b1954..1d7562a 100644
---
a/esAppMod_train/class_bert_augmentation/prediction/predict.py +++ b/esAppMod_train/class_bert_augmentation/prediction/predict.py @@ -78,7 +78,7 @@ def process_df_to_dict(df): index = row['entity_id'] element = { 'text' : desc, - 'label': label2id[index], # ensure labels starts from 0 + 'labels': label2id[index], # ensure labels starts from 0 } output_list.append(element) @@ -144,9 +144,7 @@ def test(): # there is no need to create a separate 'labels' model_inputs = tokenizer( input, - max_length=max_length, - # truncation=True, - padding='max_length' + truncation=True, ) return model_inputs @@ -160,7 +158,7 @@ def test(): ) - datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label']) + # datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label']) # %% temp # tokenized_datasets['train'].rename_columns() @@ -168,7 +166,7 @@ def test(): # %% # create data collator - # data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length") + data_collator = DataCollatorWithPadding(tokenizer=tokenizer) # %% # compute metrics @@ -197,13 +195,13 @@ def test(): actual_labels = [] - dataloader = DataLoader(datasets, batch_size=BATCH_SIZE, shuffle=False) + dataloader = DataLoader(datasets, batch_size=BATCH_SIZE, collate_fn=data_collator ,shuffle=False) for batch in tqdm(dataloader): # Inference in batches input_ids = batch['input_ids'] attention_mask = batch['attention_mask'] # save labels too - actual_labels.extend(batch['label']) + actual_labels.extend(batch['labels']) # Move to GPU if available diff --git a/esAppMod_train/class_bert_augmentation/train.py b/esAppMod_train/class_bert_augmentation/train.py index 4344bf6..e088cd8 100644 --- a/esAppMod_train/class_bert_augmentation/train.py +++ b/esAppMod_train/class_bert_augmentation/train.py @@ -306,7 +306,7 @@ def corrupt_string(sentence, corruption_probability=0.01): # each element maps input to output # input: tag_description # output: class label -label_flag_list = [] +# label_flag_list = [] def process_df_to_dict(df): output_list = [] @@ -331,13 +331,14 @@ def process_df_to_dict(df): for _ in range(10): element = { 'text': parent_desc, - 'label': label2id[index], + 'labels': label2id[index], } output_list.append(element) # check if label is in label_flag_list - if index not in label_flag_list: + # if index not in label_flag_list: + if False: entity_name = row['entity_name'] # add the "entity_name" label as a mention @@ -452,7 +453,7 @@ def train(): model_checkpoint = "distilbert/distilbert-base-uncased" # model_checkpoint = 'google-bert/bert-base-cased' # model_checkpoint = 'prajjwal1/bert-small' - tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, return_tensors="pt", clean_up_tokenization_spaces=True) + tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, clean_up_tokenization_spaces=True) # given a dataset entry, run it through the tokenizer @@ -475,6 +476,9 @@ def train(): remove_columns="text", ) + # tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels']) + + # %% temp # tokenized_datasets['train'].rename_columns() @@ -525,7 +529,7 @@ def train(): per_device_eval_batch_size=64, auto_find_batch_size=False, ddp_find_unused_parameters=False, - weight_decay=0.01, + weight_decay=0.02, save_total_limit=1, num_train_epochs=40, warmup_steps=400, @@ -538,7 +542,7 @@ def train(): trainer = Trainer( model, training_args, - train_dataset=tokenized_datasets["train"], + train_dataset=tokenized_datasets['train'], tokenizer=tokenizer, 
data_collator=data_collator, compute_metrics=compute_metrics, diff --git a/tackle_container/.gitignore b/tackle_container/.gitignore new file mode 100644 index 0000000..fd7e5dc --- /dev/null +++ b/tackle_container/.gitignore @@ -0,0 +1,2 @@ +__pycache__ +checkpoint \ No newline at end of file diff --git a/tackle_container/biomedical_train.py b/tackle_container/biomedical_train.py new file mode 100644 index 0000000..9893625 --- /dev/null +++ b/tackle_container/biomedical_train.py @@ -0,0 +1,218 @@ +import os, random +from collections import defaultdict + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data.distributed import DistributedSampler +import torch.multiprocessing as mp + +from transformers import AutoTokenizer, AutoModel + +from data import generate_train_entity_sets + +from tqdm import tqdm ### need to use ipywidgets==7.7.1 the newest version doesn't work +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +import numpy as np +import logging + + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + +def cleanup(): + dist.destroy_process_group() + +def train_dataloader(vocab_entity_id_mentions, num_sample_per_class, rank, world_size, batch_size=32, pin_memory=True, num_workers=8): + dataset = generate_train_entity_sets(vocab_entity_id_mentions, entity_id_name=None, group_size=num_sample_per_class, anchor=False) + sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False) + return DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers, drop_last=False, shuffle=False, sampler=sampler) + +def test_dataloader(test_mentions, batch_size=32): + return DataLoader(test_mentions, batch_size=batch_size, shuffle=False) + +def train(rank, epoch, epochs, train_dataloader, model, optimizer, tokenizer, margin): + # DEVICE = torch.device(f"cuda:{dist.get_rank()}") + DEVICE = torch.device(f'cuda:{rank}') + model.train() + epoch_loss, epoch_len = [epoch], [epoch] + + for groups in tqdm(train_dataloader, desc =f'Training batches on {DEVICE}'): + groups[0][:] = zip(*groups[0][::-1]) + x, y = [], [] + for mention, label in zip(groups[0], groups[1]): + mention = [m for m in mention if m != 'PAD'] + x.extend(mention) + y.extend([label.item()]*len(mention)) + + optimizer.zero_grad() + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs = inputs.to(DEVICE) + cls = model(inputs) +# cls = torch.nn.functional.normalize(cls) ## normalize cls embedding before computing loss, didn't work +# cls = torch.nn.Dropout(p= 0.25)(cls) ## add dropout, didn't work + +# loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=True) +# loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=True) + + if epoch < epochs / 2: +# if epoch // (epochs / 4) % 2 == 0: ## various ways of alternating batch all and batch hard, no obvious advantage +# if (epoch // 10) % 2 == 0: ## various ways of alternating batch all and batch hard, no obvious advantage + loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False) + else: + loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, 
squared=False) + + #### tried circle loss, no obvious advantage + + loss.backward() + optimizer.step() + # logging.info(f'{epoch} {len(x)} {loss.item()}') + epoch_loss.append(loss.item()) + epoch_len.append(len(x)) + # del inputs, cls, loss + # torch.cuda.empty_cache() + logging.info(f'{DEVICE}{epoch_len}') + logging.info(f'{DEVICE}{epoch_loss}') + +def check_label(predicted_cui: str, golden_cui: str) -> int: + """ + Some composite annotation didn't consider orders + So, set label '1' if any cui is matched within composite cui (or single cui) + Otherwise, set label '0' + """ + return int(len(set(predicted_cui.replace('+', '|').split("|")).intersection(set(golden_cui.replace('+', '|').split("|"))))>0) + +def getEmbeddings(mentions, model, tokenizer, DEVICE, batch_size=200): + model.to(DEVICE) + model.eval() + dataloader = DataLoader(mentions, batch_size, shuffle=False) + embeddings = np.empty((0, 768), np.float32) + with torch.no_grad(): + for mentions in tqdm(dataloader, desc ='Getting embeddings'): + inputs = tokenizer(mentions, padding=True, return_tensors='pt') + inputs = inputs.to(DEVICE) + cls = model(inputs) + embeddings = np.append(embeddings, cls.detach().cpu().numpy(), axis=0) + # del inputs, cls + # torch.cuda.empty_cache() + return embeddings + +def eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, model, tokenizer): + DEVICE = torch.device(f'cuda:{rank}') + vocab_embeddings = getEmbeddings(vocab_mentions, model, tokenizer, DEVICE) + test_embeddings = getEmbeddings(test_mentions, model, tokenizer, DEVICE) + + knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(vocab_embeddings, vocab_ids) + n_neighbors = [1, 3, 5, 10] + res = [] + for n in n_neighbors: + distances, indices = knn.kneighbors(test_embeddings, n_neighbors=n) + num = 0 + for gold_cui, idx in zip(test_cuis, indices): + candidates = [id_to_cui[vocab_ids[i]] for i in idx] + for c in candidates: + if check_label(c, gold_cui): + num += 1 + break + res.append(num / len(test_cuis)) + # print(f'Top-{n:<3} accuracy: {num / len(test_cuis)}') + return res + # print(np.min(distances), np.max(distances)) + +def save_checkpoint(model, res, epoch, dataName): + logging.info(f'Saving model {epoch} {res} ') + torch.save(model.state_dict(), './checkpoints/'+dataName+'.pt') + +class Model(nn.Module): + def __init__(self,MODEL_NAME): + super(Model, self).__init__() + self.model = AutoModel.from_pretrained(MODEL_NAME) + + def forward(self, inputs): + outputs = self.model(**inputs) + cls = outputs.last_hidden_state[:,0,:] + return cls + +def main(rank, world_size, config): + + print(f"Running main(**args) on rank {rank}.") + setup(rank, world_size) + + dataName = config['DEFAULT']['dataName'] + logging.basicConfig(format='%(asctime)s %(message)s', filename=config['train']['ckt_path']+dataName+'.log', filemode='a', level=logging.INFO) + + vocab = defaultdict(set) + with open('./data/biomedical/'+dataName+'/'+config['train']['dictionary']) as f: + for line in f: + vocab[line.strip().split('||')[0]].add(line.strip().split('||')[1].lower()) + + cui_to_id, id_to_cui = {}, {} + vocab_entity_id_mentions = {} + for id, cui in enumerate(vocab): + cui_to_id[cui] = id + id_to_cui[id] = cui + for cui, mention in vocab.items(): + vocab_entity_id_mentions[cui_to_id[cui]] = mention + + vocab_mentions, vocab_ids = [], [] + for id, mentions in vocab_entity_id_mentions.items(): + vocab_mentions.extend(mentions) + vocab_ids.extend([id]*len(mentions)) + + test_mentions, test_cuis = [], [] + with 
open('./data/biomedical/'+dataName+'/'+config['train']['test_set']+'/0.concept') as f: + for line in f: + test_cuis.append(line.strip().split('||')[-1]) + test_mentions.append(line.strip().split('||')[-2].lower()) + + num_sample_per_class = int(config['data']['group_size']) # samples in each group + batch_size = int(config['train']['batch_size']) # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class + margin = int(config['model']['margin']) + epochs = int(config['train']['epochs']) + lr = float(config['train']['lr']) + MODEL_NAME = config['model']['model_name'] + + trainDataLoader = train_dataloader(vocab_entity_id_mentions, num_sample_per_class, rank, world_size, batch_size, pin_memory=False, num_workers=0) + # test_dataloader = test_dataloader(test_mentions, batch_size=200) + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + model = Model(MODEL_NAME).to(rank) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr) + ddp_model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True) + + best = 0 + if rank == 0: + logging.info(f'epochs:{epochs} group_size:{num_sample_per_class} batch_size:{batch_size} %num:1 device:{torch.cuda.get_device_name()} count:{torch.cuda.device_count()} base:{MODEL_NAME}' ) + + for epoch in tqdm(range(epochs)): + trainDataLoader.sampler.set_epoch(epoch) + train(rank, epoch, epochs, trainDataLoader, ddp_model, optimizer, tokenizer, margin) + # if rank == 0 and epoch % 2 == 0: + if rank == 0: + res = eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, ddp_model.module, tokenizer) + logging.info(f'{epoch} {res}') + if res[0] > best: + best = res[0] + save_checkpoint(ddp_model.module, res, epoch, dataName) + dist.barrier() + cleanup() + +if __name__ == '__main__': + import configparser + config = configparser.ConfigParser() + config.read('config.ini') + world_size = torch.cuda.device_count() + print(f"You have {world_size} GPUs.") + mp.spawn( + main, + args=(world_size, config), + nprocs=world_size, + join=True + ) \ No newline at end of file diff --git a/tackle_container/data.py b/tackle_container/data.py new file mode 100644 index 0000000..6f055f5 --- /dev/null +++ b/tackle_container/data.py @@ -0,0 +1,35 @@ +import random +def generate_train_entity_sets(entity_id_mentions, entity_id_name=None, group_size=10, anchor=False): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + mentions = list(mentions) + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + if entity_id_name: + group = list(set([entity_id_name[id]] + mentions)) + else: + group = list(mentions) + if len(group) == 1: + group.append(group[0]) + group.extend((group_size-len(group))%group_size * ['PAD']) + random.shuffle(group) + positives = [(group[i:i + group_size], id) for i in range(0, len(group), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + t[0] = [e for e in t[0] if e != 'PAD'] + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y \ No 
newline at end of file diff --git a/tackle_container/esAppMod_infer.py b/tackle_container/esAppMod_infer.py new file mode 100644 index 0000000..115dbff --- /dev/null +++ b/tackle_container/esAppMod_infer.py @@ -0,0 +1,106 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import gc + +# %% +# Step 2: Load the state dictionary +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) + +# state_dict = torch.load('./checkpoint/siamese.pt') +state_dict = torch.load('./checkpoint/siamese_simple.pt') + +# Step 3: Apply the state dictionary to the model +model.load_state_dict(state_dict) +model.to(DEVICE) +model.eval() + +# %% +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + + +# %% +with open('../esAppMod/infer.json', 'r') as file: + test = json.load(file) +x_test = [d['mention'] for _, d in test['data'].items()] +y_test = [d['entity_id'] for _, d in test['data'].items()] +train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys()) + +def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + +batches = batch_list(train_entities, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls = np.concatenate(embedding_list) +# %% +gc.collect() +torch.cuda.empty_cache() + +# %% + +batches = batch_list(x_test, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls_test = np.concatenate(embedding_list) + + +# %% +knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels) +n_neighbors = [1, 3, 5, 10] + + +with open("results/output.txt", "w") as f: + for n in n_neighbors: + distances, indices = knn.kneighbors(cls_test, n_neighbors=n) + num = 0 + for a,b in zip(y_test, indices): + b = [labels[i] for i in b] + if a in b: + num += 1 + print(f'Top-{n:<3} accuracy: {num / 
len(y_test)}', file=f) + print(np.min(distances), np.max(distances), file=f) + +# %% diff --git a/tackle_container/esAppMod_train.py b/tackle_container/esAppMod_train.py new file mode 100644 index 0000000..c5b4953 --- /dev/null +++ b/tackle_container/esAppMod_train.py @@ -0,0 +1,92 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + + +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5) + +model.to(DEVICE) +model.train() + +losses = [] + +for epoch in tqdm(range(epochs)): + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x, y in batchGenerator(data, batch_size): + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs = inputs.to(DEVICE) + outputs = model(**inputs) + cls = outputs.last_hidden_state[:,0,:] + # for training less than half the time, train on easy + if epoch < epochs / 2: + loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False) + # for training after half the time, train on hard + else: + loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, 
squared=False) + loss.backward() + optimizer.step() + # print(epoch, loss) + losses.append(loss) + del inputs, outputs, cls, loss + torch.cuda.empty_cache() + +torch.save(model.state_dict(), './checkpoint/siamese_simple.pt') diff --git a/tackle_container/esAppMod_train_ddp.py b/tackle_container/esAppMod_train_ddp.py new file mode 100644 index 0000000..442c219 --- /dev/null +++ b/tackle_container/esAppMod_train_ddp.py @@ -0,0 +1,242 @@ +# %% +import torch +import json +import random +import numpy as np +import pandas as pd +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +# parallel utilities +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader, DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP +import os + +# %% +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +# produces a dictionary map from entity_id to entity_name +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +# map from entity_id to list of mentions +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +# map from entity_id to list of entity_names +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + + # entity sets are always ([list of mentions], id) + # to convert it to dataset form, we will just use a dataframe + id_mention_pairs = [] + for entity in entity_sets: + entity_id = entity[1] + for mention in entity[0]: + id_mention_pairs.append({ + 'entity_id': entity_id, + 'mention': mention + }) + df = pd.DataFrame(id_mention_pairs) + return df + +# def batchGenerator(data, batch_size): +# for i in range(0, len(data), batch_size): +# batch = data[i:i+batch_size] +# x, y = [], [] +# for t in batch: +# x.extend(t[0]) +# y.extend([t[1]]*len(t[0])) +# yield x, y + + +class CustomDataset(Dataset): + def __init__(self, df): + self.data = df # data should be preprocessed if necessary before being passed here + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + # Return the data and label as tuples + entry = self.data.iloc[idx] + x = entry['mention'] + y = entry['entity_id'] + return x,y + + + + +# %% + +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 +DEVICE = 
torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' #'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5) + +# model.to(DEVICE) +# model.train() +# +# losses = [] + +# %% +# for epoch in tqdm(range(epochs)): +# data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) +# dataset = CustomDataset(data) +# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) +# for x, y in dataloader: +# # print(len(x), len(y), end='-->') +# optimizer.zero_grad() +# inputs = tokenizer(x, padding=True, return_tensors='pt') +# inputs = inputs.to(DEVICE) +# outputs = model(**inputs) +# cls = outputs.last_hidden_state[:,0,:] +# # for training less than half the time, train on easy +# if epoch < epochs / 2: +# loss, _ = batch_all_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False) +# # for training after half the time, train on hard +# else: +# loss = batch_hard_triplet_loss(torch.tensor(y).to(DEVICE), cls, margin, squared=False) +# loss.backward() +# optimizer.step() +# # print(epoch, loss) +# losses.append(loss) +# del inputs, outputs, cls, loss +# torch.cuda.empty_cache() +# +# torch.save(model.state_dict(), './checkpoint/siamese.pt') +# %% + +def save_checkpoint(model, optimizer, epoch, path, rank): + if rank == 0: # Only save on the master process + # Save only the underlying model's state_dict, not the DDP wrapper + torch.save({ + 'model_state_dict': model.module.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'epoch': epoch + }, path) + +# %% +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + dist.init_process_group("nccl", rank=rank, world_size=world_size) + +def cleanup(): + dist.destroy_process_group() + +def reduce_mean(tensor, nprocs): + """ + Reduces and averages the tensor across all processes. + This function reduces a tensor from all processes to all processes. + The resulting tensor is identical in all processes. 
+ """ + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + rt /= nprocs + return rt + +def train(rank, world_size): + setup(rank, world_size) + + # Setup model, DataLoader with DistributedSampler + + model = AutoModel.from_pretrained(MODEL_NAME) + model = model.cuda(rank) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5) + + # initialize progress bar + # Initialize tqdm on the master process + if torch.distributed.get_rank() == 0: # Only print from the master process + pbar = tqdm(total=epochs, desc='batch progress') + + + + for epoch in range(epochs): + total_loss = 0.0 + num_batches = 0 + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + train_dataset = CustomDataset(data) + train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True) + train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=train_sampler) + + device = torch.device(f"cuda:{rank}") + train_sampler.set_epoch(epoch) + for data, targets in train_loader: + # data, targets = data.cuda(rank), targets.cuda(rank) + optimizer.zero_grad() + inputs = tokenizer(data, padding=True, return_tensors='pt') + inputs = inputs.to(device) + outputs = model(**inputs) + cls = outputs.last_hidden_state[:,0,:] + # for training less than half the time, train on easy + if epoch < epochs / 2: + loss, _ = batch_all_triplet_loss(targets.to(device), cls, margin, squared=False) + # for training after half the time, train on hard + else: + loss = batch_hard_triplet_loss(targets.to(device), cls, margin, squared=False) + + + loss.backward() + optimizer.step() + + # Reduce and average the loss across all processes + reduced_loss = reduce_mean(loss, world_size) + total_loss += reduced_loss.item() + num_batches += 1 + + # print(epoch, loss) + # losses.append(loss) + del inputs, outputs, cls, loss + torch.cuda.empty_cache() + + dist.barrier() + # Close tqdm bar on master process + if torch.distributed.get_rank() == 0: # Only print from the master process + pbar.update(epoch) + epoch_loss = total_loss / num_batches + tqdm.write(f'loss: {epoch_loss}') + + if torch.distributed.get_rank() == 0: # Only print from the master process + pbar.close() + path = './checkpoint/siamese.pt' + torch.save(model.module.state_dict(), path) + + cleanup() + + +if __name__ == '__main__': + # Set the number of processes to the number of GPUs available + world_size = torch.cuda.device_count() + # Use torch.multiprocessing.spawn to launch the processes + torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True) \ No newline at end of file diff --git a/tackle_container/loss.py b/tackle_container/loss.py new file mode 100644 index 0000000..581a1bb --- /dev/null +++ b/tackle_container/loss.py @@ -0,0 +1,186 @@ +# stardard functionalities for computing triplet loss, borrow code from +# https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py +import torch +import torch.nn.functional as F +def _pairwise_distances(embeddings, squared=False): + """Compute the 2D matrix of distances between all the embeddings. + Args: + embeddings: tensor of shape (batch_size, embed_dim) + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. 
+ Returns: + pairwise_distances: tensor of shape (batch_size, batch_size) + """ + dot_product = torch.matmul(embeddings, embeddings.t()) + + # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`. + # This also provides more numerical stability (the diagonal of the result will be exactly 0). + # shape (batch_size,) + square_norm = torch.diag(dot_product) + + # Compute the pairwise distance matrix as we have: + # ||a - b||^2 = ||a||^2 - 2 + ||b||^2 + # shape (batch_size, batch_size) + distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1) + + # Because of computation errors, some distances might be negative so we put everything >= 0.0 + distances[distances < 0] = 0 + + if not squared: + # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal) + # we need to add a small epsilon where distances == 0.0 + mask = distances.eq(0).float() + distances = distances + mask * 1e-16 + + distances = (1.0 -mask) * torch.sqrt(distances) + + return distances + +def _get_triplet_mask(labels): + """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid. + A triplet (i, j, k) is valid if: + - i, j, k are distinct + - labels[i] == labels[j] and labels[i] != labels[k] + Args: + labels: tf.int32 `Tensor` with shape [batch_size] + """ + # Check that i, j and k are distinct + indices_equal = torch.eye(labels.size(0), device=labels.device).bool() + indices_not_equal = ~indices_equal + i_not_equal_j = indices_not_equal.unsqueeze(2) + i_not_equal_k = indices_not_equal.unsqueeze(1) + j_not_equal_k = indices_not_equal.unsqueeze(0) + + distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k + + + label_equal = labels.unsqueeze(0) == labels.unsqueeze(1) + i_equal_j = label_equal.unsqueeze(2) + i_equal_k = label_equal.unsqueeze(1) + + valid_labels = ~i_equal_k & i_equal_j + + return valid_labels & distinct_indices + + +def _get_anchor_positive_triplet_mask(labels): + """Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label. + Args: + labels: tf.int32 `Tensor` with shape [batch_size] + Returns: + mask: tf.bool `Tensor` with shape [batch_size, batch_size] + """ + # Check that i and j are distinct + indices_equal = torch.eye(labels.size(0), device=labels.device).bool() + indices_not_equal = ~indices_equal + + # Check if labels[i] == labels[j] + # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1) + labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1) + + return labels_equal & indices_not_equal + + +def _get_anchor_negative_triplet_mask(labels): + """Return a 2D mask where mask[a, n] is True iff a and n have distinct labels. + Args: + labels: tf.int32 `Tensor` with shape [batch_size] + Returns: + mask: tf.bool `Tensor` with shape [batch_size, batch_size] + """ + # Check if labels[i] != labels[k] + # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1) + + return ~(labels.unsqueeze(0) == labels.unsqueeze(1)) + + +# Cell +def batch_hard_triplet_loss(labels, embeddings, margin, squared=False): + """Build the triplet loss over a batch of embeddings. + For each anchor, we get the hardest positive and hardest negative to form a triplet. + Args: + labels: labels of the batch, of size (batch_size,) + embeddings: tensor of shape (batch_size, embed_dim) + margin: margin for triplet loss + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. 
+ If false, output is the pairwise euclidean distance matrix. + Returns: + triplet_loss: scalar tensor containing the triplet loss + """ + # Get the pairwise distance matrix + pairwise_dist = _pairwise_distances(embeddings, squared=squared) + + # For each anchor, get the hardest positive + # First, we need to get a mask for every valid positive (they should have same label) + mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float() + + # We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p)) + anchor_positive_dist = mask_anchor_positive * pairwise_dist + + # shape (batch_size, 1) + hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True) + + # For each anchor, get the hardest negative + # First, we need to get a mask for every valid negative (they should have different labels) + mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float() + + # We add the maximum value in each row to the invalid negatives (label(a) == label(n)) + max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True) + anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative) + + # shape (batch_size,) + hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True) + + # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss + tl = hardest_positive_dist - hardest_negative_dist + margin + tl = F.relu(tl) + triplet_loss = tl.mean() + + return triplet_loss + +# Cell +def batch_all_triplet_loss(labels, embeddings, margin, squared=False): + """Build the triplet loss over a batch of embeddings. + We generate all the valid triplets and average the loss over the positive ones. + Args: + labels: labels of the batch, of size (batch_size,) + embeddings: tensor of shape (batch_size, embed_dim) + margin: margin for triplet loss + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. + Returns: + triplet_loss: scalar tensor containing the triplet loss + """ + # Get the pairwise distance matrix + pairwise_dist = _pairwise_distances(embeddings, squared=squared) + + anchor_positive_dist = pairwise_dist.unsqueeze(2) + anchor_negative_dist = pairwise_dist.unsqueeze(1) + + # Compute a 3D tensor of size (batch_size, batch_size, batch_size) + # triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k + # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1) + # and the 2nd (batch_size, 1, batch_size) + triplet_loss = anchor_positive_dist - anchor_negative_dist + margin + + + + # Put to zero the invalid triplets + # (where label(a) != label(p) or label(n) == label(a) or a == p) + mask = _get_triplet_mask(labels) + triplet_loss = mask.float() * triplet_loss + + # Remove negative losses (i.e. 
the easy triplets)
+    triplet_loss = F.relu(triplet_loss)
+
+    # Count number of positive triplets (where triplet_loss > 0)
+    valid_triplets = triplet_loss[triplet_loss > 1e-16]
+    num_positive_triplets = valid_triplets.size(0)
+    num_valid_triplets = mask.sum()
+
+    fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)
+
+    # Get final mean triplet loss over the positive valid triplets
+    triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
+
+    return triplet_loss, fraction_positive_triplets
\ No newline at end of file
diff --git a/tackle_container/results/bert-base.txt b/tackle_container/results/bert-base.txt
new file mode 100644
index 0000000..fe44b57
--- /dev/null
+++ b/tackle_container/results/bert-base.txt
@@ -0,0 +1,5 @@
+Top-1 accuracy: 0.6974169741697417
+Top-3 accuracy: 0.8126281262812628
+Top-5 accuracy: 0.8413284132841329
+Top-10 accuracy: 0.8720787207872078
+0.005117357 0.74772596
diff --git a/tackle_container/results/bert-small_1.txt b/tackle_container/results/bert-small_1.txt
new file mode 100644
index 0000000..9ca9c53
--- /dev/null
+++ b/tackle_container/results/bert-small_1.txt
@@ -0,0 +1,5 @@
+Top-1 accuracy: 0.8019680196801968
+Top-3 accuracy: 0.8901189011890119
+Top-5 accuracy: 0.9085690856908569
+Top-10 accuracy: 0.9249692496924969
+0.0 0.7323234
diff --git a/tackle_container/results/bert-small_2.txt b/tackle_container/results/bert-small_2.txt
new file mode 100644
index 0000000..94f4979
--- /dev/null
+++ b/tackle_container/results/bert-small_2.txt
@@ -0,0 +1,5 @@
+Top-1 accuracy: 0.8163181631816319
+Top-3 accuracy: 0.8987289872898729
+Top-5 accuracy: 0.9167691676916769
+Top-10 accuracy: 0.9356293562935629
+0.0 0.7410505
diff --git a/tackle_container/results/bert-small_3.txt b/tackle_container/results/bert-small_3.txt
new file mode 100644
index 0000000..5f013cc
--- /dev/null
+++ b/tackle_container/results/bert-small_3.txt
@@ -0,0 +1,5 @@
+Top-1 accuracy: 0.7908979089790897
+Top-3 accuracy: 0.8888888888888888
+Top-5 accuracy: 0.914309143091431
+Top-10 accuracy: 0.931119311193112
+0.0 0.7351225