From 182760b7a23c95260dea7f7268aaefd18094355a Mon Sep 17 00:00:00 2001 From: Richard Wong Date: Thu, 23 Jan 2025 20:52:55 +0900 Subject: [PATCH] added experiments with triplet loss and augmentations - includes experiments on character-level bert --- analysis/error_analysis_baseline.py | 56 + analysis/graph_top1_curves.py | 59 + cosines_with_augmentations/baseline_infer.py | 125 ++ cosines_with_augmentations/baseline_train.py | 277 +++ .../{esAppMod_infer.py => classify_infer.py} | 2 +- .../{classify.py => classify_logits.py} | 4 +- cosines_with_augmentations/classify_train.py | 316 +++ cosines_with_augmentations/esAppMod_train.py | 49 +- .../esAppMod_train_with_classification.py | 16 +- cosines_with_augmentations/hybrid_infer.py | 124 ++ cosines_with_augmentations/hybrid_train.py | 315 +++ .../results/classify.csv | 1740 ++++++++--------- cosines_with_augmentations/results/output.txt | 10 +- cosines_with_augmentations/study_sampling.py | 270 +++ .../understanding_loss.py | 378 ++++ experimental/.gitignore | 4 + experimental/character_bert_train.py | 577 ++++++ experimental/loss.py | 288 +++ experimental/pretrain_character_bert_train.py | 574 ++++++ .../.gitignore | 4 + .../baseline_infer.py | 131 ++ .../baseline_train.py | 400 ++++ .../batch_all_train.py | 424 ++++ .../character_bert_train.py | 561 ++++++ .../classify_infer.py | 124 ++ .../classify_logits.py | 258 +++ .../classify_train.py | 315 +++ .../esAppMod_train.py | 277 +++ .../esAppMod_train_with_classification.py | 315 +++ .../hybrid_infer.py | 124 ++ .../hybrid_train.py | 433 ++++ loss_comparisons_with_augmentations/loss.py | 288 +++ .../.gitignore | 4 + .../baseline_infer.py | 132 ++ .../baseline_train.py | 382 ++++ .../classify_infer.py | 124 ++ .../classify_logits.py | 258 +++ .../classify_train.py | 316 +++ .../esAppMod_train.py | 277 +++ .../esAppMod_train_with_classification.py | 315 +++ .../hybrid_infer.py | 124 ++ .../hybrid_train.py | 315 +++ loss_comparisons_without_augmentation/loss.py | 193 ++ reference_code/character_bert_train.py | 460 +++++ 44 files changed, 10834 insertions(+), 904 deletions(-) create mode 100644 analysis/error_analysis_baseline.py create mode 100644 analysis/graph_top1_curves.py create mode 100644 cosines_with_augmentations/baseline_infer.py create mode 100644 cosines_with_augmentations/baseline_train.py rename cosines_with_augmentations/{esAppMod_infer.py => classify_infer.py} (100%) rename cosines_with_augmentations/{classify.py => classify_logits.py} (97%) create mode 100644 cosines_with_augmentations/classify_train.py create mode 100644 cosines_with_augmentations/hybrid_infer.py create mode 100644 cosines_with_augmentations/hybrid_train.py create mode 100644 cosines_with_augmentations/study_sampling.py create mode 100644 cosines_with_augmentations/understanding_loss.py create mode 100644 experimental/.gitignore create mode 100644 experimental/character_bert_train.py create mode 100644 experimental/loss.py create mode 100644 experimental/pretrain_character_bert_train.py create mode 100644 loss_comparisons_with_augmentations/.gitignore create mode 100644 loss_comparisons_with_augmentations/baseline_infer.py create mode 100644 loss_comparisons_with_augmentations/baseline_train.py create mode 100644 loss_comparisons_with_augmentations/batch_all_train.py create mode 100644 loss_comparisons_with_augmentations/character_bert_train.py create mode 100644 loss_comparisons_with_augmentations/classify_infer.py create mode 100644 loss_comparisons_with_augmentations/classify_logits.py create mode 100644 
loss_comparisons_with_augmentations/classify_train.py create mode 100644 loss_comparisons_with_augmentations/esAppMod_train.py create mode 100644 loss_comparisons_with_augmentations/esAppMod_train_with_classification.py create mode 100644 loss_comparisons_with_augmentations/hybrid_infer.py create mode 100644 loss_comparisons_with_augmentations/hybrid_train.py create mode 100644 loss_comparisons_with_augmentations/loss.py create mode 100644 loss_comparisons_without_augmentation/.gitignore create mode 100644 loss_comparisons_without_augmentation/baseline_infer.py create mode 100644 loss_comparisons_without_augmentation/baseline_train.py create mode 100644 loss_comparisons_without_augmentation/classify_infer.py create mode 100644 loss_comparisons_without_augmentation/classify_logits.py create mode 100644 loss_comparisons_without_augmentation/classify_train.py create mode 100644 loss_comparisons_without_augmentation/esAppMod_train.py create mode 100644 loss_comparisons_without_augmentation/esAppMod_train_with_classification.py create mode 100644 loss_comparisons_without_augmentation/hybrid_infer.py create mode 100644 loss_comparisons_without_augmentation/hybrid_train.py create mode 100644 loss_comparisons_without_augmentation/loss.py create mode 100644 reference_code/character_bert_train.py diff --git a/analysis/error_analysis_baseline.py b/analysis/error_analysis_baseline.py new file mode 100644 index 0000000..ba61825 --- /dev/null +++ b/analysis/error_analysis_baseline.py @@ -0,0 +1,56 @@ +# %% +import pandas as pd +import json + +# %% +data_path = '../loss_comparisons_without_augmentation/results/predictions.txt' +df = pd.read_csv(data_path, header=None) +df = df.rename(columns={0: 'actual', 1: 'predicted'}) + +# %% +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +df['predicted_name'] = df['predicted'].map(all_entity_id_name) + +# %% +# import test file +data_path = '../esAppMod_data_import/test.csv' +# data_path = '../esAppMod_data_import/parent_test.csv' +test_df = pd.read_csv(data_path) + + + +# %% +df_out = pd.concat([test_df,df], axis=1) + +# %% +mask1 = (df['predicted'] != df['actual']) +# %% + +print(df_out[mask1].sort_values(by=['entity_id']).to_markdown()) +# %% + +data_path = '../loss_comparisons_with_augmentations/results/predictions.txt' +df2 = pd.read_csv(data_path, header=None) +df2 = df2.rename(columns={0: 'actual', 1: 'predicted'}) +mask2 = df2['actual'] != df2['predicted'] + + +# %% +# i want to find entries that were: +# - correct in mask1 +# - wrong in mask2 +mask_left = ~mask1 & mask2 + +predicted_entity = df2['predicted'].map(all_entity_id_name) +df_out = pd.concat([test_df,df2, predicted_entity], axis=1) +print(df_out[mask_left].sort_values(by=['entity_id']).to_markdown()) +# %% diff --git a/analysis/graph_top1_curves.py b/analysis/graph_top1_curves.py new file mode 100644 index 0000000..434277c --- /dev/null +++ b/analysis/graph_top1_curves.py @@ -0,0 +1,59 @@ +# %% +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np + +# %% +data_path = 
'../loss_comparisons_without_augmentation/top1_curves/baseline_output.txt' +df = pd.read_csv(data_path, header=None) +y = df[0] +plt.plot(y) + +# Find the max value +max_y = np.max(y) # Max value +max_x = np.argmax(y) # x value corresponding to the max y +# Annotate the max value on the plot +# plt.annotate(f'Max: {max_y:.5f}', # Text to display +# xy=(max_x, max_y), # Point to annotate +# xytext=(max_x+0.7, max_y-0.3), # Location of text +# arrowprops=dict(facecolor='black',arrowstyle='->'), +# bbox=dict(boxstyle="round,pad=0.3", edgecolor='black', facecolor='yellow')) + + +# data_path = '../experimental/top1_curves/character_output.txt' +# df = pd.read_csv(data_path, header=None) +# y = df[0] +# plt.plot(y) +# max_y = np.max(y) # Max value +# max_x = np.argmax(y) # x value corresponding to the max y +# # Annotate the max value on the plot +# plt.annotate(f'Max: {max_y:.5f}', # Text to display +# xy=(max_x, max_y), # Point to annotate +# xytext=(max_x+0.7, max_y-0.2), # Location of text +# arrowprops=dict(facecolor='black',arrowstyle='->'), +# bbox=dict(boxstyle="round,pad=0.3", edgecolor='black', facecolor='yellow')) + +data_path = '../experimental/top1_curves/character_knn.txt' +df = pd.read_csv(data_path, header=None) +y = df[0] +plt.plot(y) +max_y = np.max(y) # Max value +max_x = np.argmax(y) # x value corresponding to the max y +# Annotate the max value on the plot +plt.annotate(f'Max: {max_y:.5f}', # Text to display + xy=(max_x, max_y), # Point to annotate + xytext=(max_x+0.7, max_y-0.4), # Location of text + arrowprops=dict(facecolor='black',arrowstyle='->'), + bbox=dict(boxstyle="round,pad=0.3", edgecolor='black', facecolor='yellow')) + +plt.ylim(0.4,1) + +# data_path = '../loss_comparisons_with_augmentations/top1_curves/smooth_output.txt' +# df = pd.read_csv(data_path, header=None) +# plt.plot(df[0]) + + + + + +# %% diff --git a/cosines_with_augmentations/baseline_infer.py b/cosines_with_augmentations/baseline_infer.py new file mode 100644 index 0000000..7ed2c67 --- /dev/null +++ b/cosines_with_augmentations/baseline_infer.py @@ -0,0 +1,125 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import re +import gc + +# %% +# Step 2: Load the state dictionary +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' +# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) + +# state_dict = torch.load('./checkpoint/siamese.pt') +# state_dict = torch.load('./checkpoint/siamese_simple.pt') +# state_dict = torch.load('./checkpoint/classification.pt') +state_dict = torch.load('./checkpoint/baseline.pt') +# params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name} + +# %% +# Step 3: Apply the state dictionary to the model +model.load_state_dict(state_dict) +model.to(DEVICE) +model.eval() + +# %% +def preprocess_text(text): + # 1. 
Make all lowercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + +# %% +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + + + +# %% +with open('../esAppMod/infer.json', 'r') as file: + test = json.load(file) +x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()] +y_test = [d['entity_id'] for _, d in test['data'].items()] +train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys()) +train_entities = [preprocess_text(element) for element in train_entities] + +def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + +batches = batch_list(train_entities, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls = np.concatenate(embedding_list) +# %% +gc.collect() +torch.cuda.empty_cache() + +# %% + +batches = batch_list(x_test, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls_test = np.concatenate(embedding_list) + + +# %% +knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels) +n_neighbors = [1, 3, 5, 10] + + +with open("results/output.txt", "w") as f: + for n in n_neighbors: + distances, indices = knn.kneighbors(cls_test, n_neighbors=n) + num = 0 + for a,b in zip(y_test, indices): + b = [labels[i] for i in b] + if a in b: + num += 1 + print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f) + print(np.min(distances), np.max(distances), file=f) + +# %% diff --git a/cosines_with_augmentations/baseline_train.py b/cosines_with_augmentations/baseline_train.py new file mode 100644 index 0000000..72d4089 --- /dev/null +++ b/cosines_with_augmentations/baseline_train.py @@ -0,0 +1,277 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim + + +# %% +SHUFFLES=0 +AMPLIFY_FACTOR=0 +LEARNING_RATE=1e-5 + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in 
entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # all augmentations disabled + # # add shuffled strings + # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + # for desc in processed_descs: + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add corrupted strings + # desc = corrupt_string(parent_desc, corruption_probability=0.01) + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add example with stripped non-alphanumerics + # desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # short sequence amplifier + # # short sequences are rare, and we must compensate by including more examples + # # also, short sequence don't usually get affected by shuffle + # words = parent_desc.split() + # word_count = len(words) + # if word_count <= 2: + # for _ in range(AMPLIFY_FACTOR): + # output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + 
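# [Editor's sketch — illustrative only, not part of this patch.]
# baseline_train.py imports batch_all_triplet_loss and batch_hard_triplet_loss from
# loss.py, whose implementation is not shown in the hunks above. The training loop
# later in this file uses the batch-all form for the first half of the epochs ("easy")
# and the batch-hard form for the second half ("hard"). A minimal PyTorch sketch of
# the batch-hard variant, assuming the same (labels, embeddings, margin, squared)
# signature and Euclidean distances, might look like this:
import torch

def batch_hard_triplet_loss_sketch(labels, embeddings, margin, squared=False):
    # pairwise squared Euclidean distances: ||a - b||^2 = ||a||^2 - 2*a.b + ||b||^2
    dot = embeddings @ embeddings.t()
    sq_norms = dot.diagonal()
    dist = (sq_norms.unsqueeze(0) - 2.0 * dot + sq_norms.unsqueeze(1)).clamp(min=0.0)
    if not squared:
        dist = (dist + 1e-16).sqrt()

    same_label = labels.unsqueeze(0) == labels.unsqueeze(1)   # [B, B] same-class mask
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    pos_mask = same_label & ~eye                              # valid anchor-positive pairs
    neg_mask = ~same_label                                    # valid anchor-negative pairs

    # hardest positive: farthest sample with the same label as the anchor
    hardest_pos = (dist * pos_mask).max(dim=1).values
    # hardest negative: closest sample with a different label (non-negatives are
    # pushed out of the min by adding the row-wise maximum distance)
    max_dist = dist.max(dim=1, keepdim=True).values
    hardest_neg = (dist + max_dist * (~neg_mask)).min(dim=1).values

    return torch.relu(hardest_pos - hardest_neg + margin).mean()

# The batch-all variant instead averages relu(d(a, p) - d(a, n) + margin) over every
# valid (anchor, positive, negative) triplet with a non-zero loss, a gentler signal
# early in training, which appears to be why the loop starts with it before switching
# to batch-hard mining.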
+def make_entity_id_name(df): + entity_id_name = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 +DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu') +MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +model.to(DEVICE) +model.train() + +losses = [] + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + augmented_df = augment_data(df) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs.to(DEVICE) + outputs = model(**inputs) + cls = outputs.last_hidden_state[:,0,:] + # for training less than half the time, train on easy + y = torch.tensor(y).to(DEVICE) + if epoch < epochs / 2: + loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # for training after half the time, train on hard + else: + loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + del x, y, outputs, cls, loss + torch.cuda.empty_cache() + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {total_loss/batch_number}') + # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}") + if epoch % 5 == 0: + torch.save(model.state_dict(), './checkpoint/baseline.pt') + + +torch.save(model.state_dict(), './checkpoint/baseline.pt') +# %% diff --git a/cosines_with_augmentations/esAppMod_infer.py b/cosines_with_augmentations/classify_infer.py similarity index 100% rename from cosines_with_augmentations/esAppMod_infer.py rename to cosines_with_augmentations/classify_infer.py index 0d8d60e..3109739 100644 --- a/cosines_with_augmentations/esAppMod_infer.py +++ b/cosines_with_augmentations/classify_infer.py @@ -15,8 +15,8 @@ import gc # Step 2: Load the state dictionary DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') # MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' -# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' +# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModel.from_pretrained(MODEL_NAME) diff --git 
a/cosines_with_augmentations/classify.py b/cosines_with_augmentations/classify_logits.py similarity index 97% rename from cosines_with_augmentations/classify.py rename to cosines_with_augmentations/classify_logits.py index 889aa81..d09558b 100644 --- a/cosines_with_augmentations/classify.py +++ b/cosines_with_augmentations/classify_logits.py @@ -109,7 +109,9 @@ def test(): # prepare tokenizer - MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + # MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + # MODEL_NAME = 'distilbert-base-cased' + MODEL_NAME = 'prajjwal1/bert-small' tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, return_tensors="pt", clean_up_tokenization_spaces=True) # Define additional special tokens # additional_special_tokens = ["", "", "", "", "", "", "", "", ""] diff --git a/cosines_with_augmentations/classify_train.py b/cosines_with_augmentations/classify_train.py new file mode 100644 index 0000000..77ddb70 --- /dev/null +++ b/cosines_with_augmentations/classify_train.py @@ -0,0 +1,316 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F + +# %% +SHUFFLES=0 +AMPLIFY_FACTOR=0 +LEARNING_RATE=1e-5 + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = 
sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # disable augmentations + # # add shuffled strings + # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + # for desc in processed_descs: + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add corrupted strings + # desc = corrupt_string(parent_desc, corruption_probability=0.01) + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add example with stripped non-alphanumerics + # desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # short sequence amplifier + # # short sequences are rare, and we must compensate by including more examples + # # also, short sequence don't usually get affected by shuffle + # words = parent_desc.split() + # word_count = len(words) + # if word_count <= 2: + # for _ in range(AMPLIFY_FACTOR): + # 
output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def make_entity_id_name(df): + entity_id_name = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +bert_model = AutoModel.from_pretrained(MODEL_NAME) + +class BertForClassificationAndTriplet(nn.Module): + def __init__(self, bert_model, num_classes): + super().__init__() + self.bert = bert_model + self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes) + + def forward(self, input_ids, attention_mask=None): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token + logits = self.classifier(cls_embeddings) + return cls_embeddings, logits + +model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id)) + +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +model.to(DEVICE) +model.train() + +losses = [] + +def linear_decay(epoch, max_epochs, initial_lr, final_lr): + """ Calculate the linearly decayed learning rate. 
""" + return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr) + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + # lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6) + + # # Update optimizer's learning rate + # for param_group in optimizer.param_groups: + # param_group['lr'] = lr + + augmented_df = augment_data(df) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs.to(DEVICE) + cls, logits = model( + input_ids=inputs['input_ids'], + attention_mask=inputs['attention_mask'] + ) + # for training less than half the time, train on easy + labels = y + labels = [label2id[element] for element in labels] + labels = torch.tensor(labels).to(DEVICE) + # y = torch.tensor(y).to(DEVICE) + + + class_loss = F.cross_entropy(logits, labels) + + # if epoch < epochs / 2: + # triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # # for training after half the time, train on hard + # else: + # triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + + loss = class_loss # + triplet_loss + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + del x, y, cls, logits, loss + torch.cuda.empty_cache() + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {total_loss/batch_number}') + # print(f"Epoch {epoch+1}: lr={lr}") + if epoch % 5 == 0: + # torch.save(model.bert.state_dict(), './checkpoint/classification.pt') + torch.save(model.state_dict(), './checkpoint/classification.pt') + + +# torch.save(model.bert.state_dict(), './checkpoint/classification.pt') +torch.save(model.state_dict(), './checkpoint/classification.pt') +# %% diff --git a/cosines_with_augmentations/esAppMod_train.py b/cosines_with_augmentations/esAppMod_train.py index 88ceb6a..535a5c2 100644 --- a/cosines_with_augmentations/esAppMod_train.py +++ b/cosines_with_augmentations/esAppMod_train.py @@ -163,30 +163,31 @@ def augment_data(df): # add basic example output_list.append(create_example(index, parent_desc, entity_name)) - # add shuffled strings - processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) - for desc in processed_descs: - if (desc != parent_desc): - output_list.append(create_example(index, desc, entity_name)) + # all augmentations disabled + # # add shuffled strings + # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + # for desc in processed_descs: + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) - # add corrupted strings - desc = corrupt_string(parent_desc, corruption_probability=0.01) - if (desc != parent_desc): - output_list.append(create_example(index, desc, entity_name)) + # # add corrupted strings + # desc = corrupt_string(parent_desc, corruption_probability=0.01) + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) - # add example with stripped non-alphanumerics - desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces - if (desc != parent_desc): - output_list.append(create_example(index, desc, entity_name)) + # # add example with stripped 
non-alphanumerics + # desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) - # short sequence amplifier - # short sequences are rare, and we must compensate by including more examples - # also, short sequence don't usually get affected by shuffle - words = parent_desc.split() - word_count = len(words) - if word_count <= 2: - for _ in range(AMPLIFY_FACTOR): - output_list.append(create_example(index, desc, entity_name)) + # # short sequence amplifier + # # short sequences are rare, and we must compensate by including more examples + # # also, short sequence don't usually get affected by shuffle + # words = parent_desc.split() + # word_count = len(words) + # if word_count <= 2: + # for _ in range(AMPLIFY_FACTOR): + # output_list.append(create_example(index, desc, entity_name)) new_df = pd.DataFrame(output_list) return new_df @@ -215,7 +216,7 @@ num_sample_per_class = 10 # samples in each group batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class margin = 2 epochs = 200 -DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu') # MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' @@ -269,8 +270,8 @@ for epoch in tqdm(range(epochs)): print(f'epoch loss: {total_loss/batch_number}') # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}") if epoch % 5 == 0: - torch.save(model.state_dict(), './checkpoint/siamese_simple.pt') + torch.save(model.state_dict(), './checkpoint/baseline.pt') -torch.save(model.state_dict(), './checkpoint/siamese_simple.pt') +torch.save(model.state_dict(), './checkpoint/baseline.pt') # %% diff --git a/cosines_with_augmentations/esAppMod_train_with_classification.py b/cosines_with_augmentations/esAppMod_train_with_classification.py index 70984e3..97864ef 100644 --- a/cosines_with_augmentations/esAppMod_train_with_classification.py +++ b/cosines_with_augmentations/esAppMod_train_with_classification.py @@ -217,8 +217,8 @@ batch_size = 16 # number of groups, effective batch_size for computing triplet l margin = 2 epochs = 200 DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') -# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' -MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' +MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) bert_model = AutoModel.from_pretrained(MODEL_NAME) @@ -245,12 +245,22 @@ model.train() losses = [] +def linear_decay(epoch, max_epochs, initial_lr, final_lr): + """ Calculate the linearly decayed learning rate. 
""" + return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr) + for epoch in tqdm(range(epochs)): total_loss = 0.0 batch_number = 0 + lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6) + + # Update optimizer's learning rate + for param_group in optimizer.param_groups: + param_group['lr'] = lr + augmented_df = augment_data(df) train_entity_id_mentions = make_entity_id_mentions(augmented_df) train_entity_id_name = make_entity_id_name(augmented_df) @@ -294,7 +304,7 @@ for epoch in tqdm(range(epochs)): # scheduler.step() # Update the learning rate print(f'epoch loss: {total_loss/batch_number}') - # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}") + print(f"Epoch {epoch+1}: lr={lr}") if epoch % 5 == 0: # torch.save(model.bert.state_dict(), './checkpoint/classification.pt') torch.save(model.state_dict(), './checkpoint/classification.pt') diff --git a/cosines_with_augmentations/hybrid_infer.py b/cosines_with_augmentations/hybrid_infer.py new file mode 100644 index 0000000..94d69b5 --- /dev/null +++ b/cosines_with_augmentations/hybrid_infer.py @@ -0,0 +1,124 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import re +import gc + +# %% +# Step 2: Load the state dictionary +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' +# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) + +# state_dict = torch.load('./checkpoint/siamese.pt') +# state_dict = torch.load('./checkpoint/siamese_simple.pt') +state_dict = torch.load('./checkpoint/hybrid.pt') +params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name} + +# %% +# Step 3: Apply the state dictionary to the model +model.load_state_dict(params_dict) +model.to(DEVICE) +model.eval() + +# %% +def preprocess_text(text): + # 1. 
Make all lowercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + +# %% +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + + + +# %% +with open('../esAppMod/infer.json', 'r') as file: + test = json.load(file) +x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()] +y_test = [d['entity_id'] for _, d in test['data'].items()] +train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys()) +train_entities = [preprocess_text(element) for element in train_entities] + +def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + +batches = batch_list(train_entities, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls = np.concatenate(embedding_list) +# %% +gc.collect() +torch.cuda.empty_cache() + +# %% + +batches = batch_list(x_test, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls_test = np.concatenate(embedding_list) + + +# %% +knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels) +n_neighbors = [1, 3, 5, 10] + + +with open("results/output.txt", "w") as f: + for n in n_neighbors: + distances, indices = knn.kneighbors(cls_test, n_neighbors=n) + num = 0 + for a,b in zip(y_test, indices): + b = [labels[i] for i in b] + if a in b: + num += 1 + print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f) + print(np.min(distances), np.max(distances), file=f) + +# %% diff --git a/cosines_with_augmentations/hybrid_train.py b/cosines_with_augmentations/hybrid_train.py new file mode 100644 index 0000000..74e8aa2 --- /dev/null +++ b/cosines_with_augmentations/hybrid_train.py @@ -0,0 +1,315 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F + +# %% +SHUFFLES=0 +AMPLIFY_FACTOR=0 +LEARNING_RATE=1e-5 + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + 
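    # [Editor's note — hypothetical values for illustration, not part of this patch.]
    # Ignoring the random shuffle: with group_size=3, entity_id_name[7] = 'tomcat' and
    # entity_id_mentions[7] = ['tomcat8', 'tc server', 'apache tomcat', 'tomcat 9.0'],
    # anchor=True yields
    #     (['tomcat', 'tomcat8', 'tc server', 'apache tomcat'], 7)
    #     (['tomcat', 'tomcat 9.0'], 7)
    # i.e. mentions are chunked into groups of at most group_size and each group is
    # prefixed with the canonical entity name, so every group carries an anchor for the
    # triplet loss; with anchor=False no name is prepended and the groups are built
    # from the mentions alone.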
entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # # add shuffled strings + # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + # for desc in processed_descs: + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add corrupted strings + # desc = corrupt_string(parent_desc, corruption_probability=0.01) + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add example with stripped non-alphanumerics + # desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # short sequence amplifier + # # short sequences are rare, and we must compensate by including more examples + # # also, short sequence don't usually get affected by shuffle + # words = parent_desc.split() + # word_count = len(words) + # if word_count <= 2: + # for _ in range(AMPLIFY_FACTOR): + # output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def make_entity_id_name(df): + 
entity_id_name = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +bert_model = AutoModel.from_pretrained(MODEL_NAME) + +class BertForClassificationAndTriplet(nn.Module): + def __init__(self, bert_model, num_classes): + super().__init__() + self.bert = bert_model + self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes) + + def forward(self, input_ids, attention_mask=None): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token + logits = self.classifier(cls_embeddings) + return cls_embeddings, logits + +model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id)) + +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +model.to(DEVICE) +model.train() + +losses = [] + +def linear_decay(epoch, max_epochs, initial_lr, final_lr): + """ Calculate the linearly decayed learning rate. """ + return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr) + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + # lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6) + + # # Update optimizer's learning rate + # for param_group in optimizer.param_groups: + # param_group['lr'] = lr + + augmented_df = augment_data(df) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs.to(DEVICE) + cls, logits = model( + input_ids=inputs['input_ids'], + attention_mask=inputs['attention_mask'] + ) + # for training less than half the time, train on easy + labels = y + labels = [label2id[element] for element in labels] + labels = torch.tensor(labels).to(DEVICE) + y = torch.tensor(y).to(DEVICE) + + + class_loss = F.cross_entropy(logits, labels) + + if epoch < epochs / 2: + triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # for training after half the time, train on hard + else: + triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + + loss = class_loss + triplet_loss + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + del x, y, cls, logits, loss + torch.cuda.empty_cache() + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {total_loss/batch_number}') + # print(f"Epoch {epoch+1}: lr={lr}") + if epoch % 5 == 0: + # 
torch.save(model.bert.state_dict(), './checkpoint/classification.pt') + torch.save(model.state_dict(), './checkpoint/hybrid.pt') + + +# torch.save(model.bert.state_dict(), './checkpoint/classification.pt') +torch.save(model.state_dict(), './checkpoint/hybrid.pt') +# %% diff --git a/cosines_with_augmentations/results/classify.csv b/cosines_with_augmentations/results/classify.csv index ff97a0c..4a1e1b0 100644 --- a/cosines_with_augmentations/results/classify.csv +++ b/cosines_with_augmentations/results/classify.csv @@ -41,19 +41,19 @@ class_prediction 497 394 394 -301 -355 +521 +130 485 -300 +453 486 -383 -299 -299 -498 +1 +593 +2 498 592 592 -438 +592 +568 592 592 592 @@ -68,12 +68,12 @@ class_prediction 4 4 418 -418 +580 5 -498 -81 -557 -626 +488 +6 +501 +501 418 7 8 @@ -85,19 +85,19 @@ class_prediction 259 259 259 -46 +470 +259 259 259 259 259 259 259 -576 259 259 9 375 -516 +10 11 12 12 @@ -133,17 +133,17 @@ class_prediction 260 260 376 -296 -498 -498 -557 +376 +130 +13 +223 261 14 299 299 -306 +15 301 -320 +486 600 600 600 @@ -159,7 +159,7 @@ class_prediction 600 600 600 -585 +486 600 600 600 @@ -167,15 +167,15 @@ class_prediction 600 600 600 -536 +547 600 600 600 600 600 600 -320 -487 +486 +402 600 600 600 @@ -183,15 +183,15 @@ class_prediction 600 600 17 -377 -517 +18 +19 20 20 20 589 302 -504 -428 +133 +542 21 307 582 @@ -201,72 +201,72 @@ class_prediction 306 306 306 -306 +434 306 22 23 23 -111 +23 24 25 307 307 307 307 -99 -111 -420 +28 +29 +29 29 30 30 30 -552 -296 30 30 -522 +30 +30 +535 563 563 30 -273 563 563 -500 563 -111 +564 +563 +564 563 563 563 32 32 -580 +538 30 -166 +455 309 594 594 36 36 36 -445 +36 37 37 37 311 37 -93 -99 -93 -99 -104 -93 -93 -99 -99 +603 +40 +40 +40 +40 +103 +40 +40 +40 296 41 42 -55 +42 312 312 312 @@ -275,9 +275,9 @@ class_prediction 520 43 43 +445 43 -43 -43 +93 43 43 520 @@ -285,8 +285,8 @@ class_prediction 43 43 43 -593 -43 +577 +93 43 43 43 @@ -301,25 +301,25 @@ class_prediction 43 43 383 +470 43 43 43 -43 -43 +470 503 -157 44 -157 -157 +44 +44 +44 45 -383 -383 +384 +314 456 585 626 -517 +460 +48 48 -383 49 49 49 @@ -328,37 +328,37 @@ class_prediction 316 596 596 -593 +318 +319 +319 +319 319 320 -320 -320 -320 -99 +43 422 51 -111 +52 52 322 -593 +53 263 263 55 55 +56 +583 +58 +58 583 -449 -449 -449 -449 449 59 293 -449 -449 -174 -449 +60 +60 +61 +61 57 -522 +543 522 327 327 @@ -367,39 +367,39 @@ class_prediction 62 62 457 -593 -445 +63 +461 64 -102 64 -445 -536 -520 -300 -99 -51 -51 +64 +65 +65 +295 +523 +294 +541 +541 328 328 265 -425 -265 -265 -285 -285 -522 -522 -522 -285 -43 -285 -285 -522 -445 -103 -438 424 -592 +265 +265 +285 +285 +265 +265 +265 +285 +265 +285 +285 +265 +134 +690 +690 +424 +329 330 330 67 @@ -407,91 +407,91 @@ class_prediction 68 68 68 -368 +609 68 -585 +458 572 -604 +459 70 70 -445 -102 -617 -609 +486 73 -609 -445 -525 -103 +73 +581 +73 +102 +73 +73 +592 +329 +355 424 -351 -424 -351 -103 -93 +155 +296 605 605 605 -93 605 605 -301 605 +605 +605 +605 +606 +459 +76 +76 +462 +462 +43 +462 +77 +520 +462 +520 +462 +462 +462 +462 +461 604 604 +604 +520 +520 +604 121 43 -76 -443 -43 593 -657 -520 -576 -520 -581 -593 -580 -443 -442 +603 604 -604 -604 -43 -520 -604 -43 -520 -370 -562 -604 -504 +462 463 463 -300 -251 -285 -443 +224 +79 +80 +589 593 81 81 -248 81 81 +81 +81 +109 +81 +81 +81 +81 +81 +109 +81 248 81 -248 -81 -248 -248 -81 -81 -81 -81 -248 -81 -316 -285 +82 +320 +83 609 609 609 @@ -504,7 +504,13 @@ class_prediction 609 609 609 -497 +107 +609 +408 +107 +609 +609 +609 609 609 107 @@ -542,35 +548,29 @@ class_prediction 609 
609 609 -609 -609 -609 -609 -609 -609 107 -609 -609 489 489 -99 -617 +489 +489 +107 +490 355 84 85 85 -106 -106 -530 +86 +86 +87 +87 +87 +87 +87 +87 +87 +87 +87 87 -593 -593 -530 -593 -530 -593 -593 -530 88 296 296 @@ -586,35 +586,35 @@ class_prediction 584 584 584 -390 +334 589 -590 -584 584 584 584 584 +334 +334 333 333 -482 -333 -482 -482 -482 -482 +506 +396 +506 +506 +506 +506 378 378 -383 -327 -327 +379 +380 +380 381 382 383 383 383 383 -323 -55 +385 +385 593 388 388 @@ -624,26 +624,26 @@ class_prediction 388 388 388 -584 +388 388 388 388 1 333 333 -334 -390 +442 +389 334 334 390 390 391 -333 335 335 335 335 -589 +336 +397 583 394 333 @@ -655,41 +655,41 @@ class_prediction 397 397 398 -398 +399 398 402 402 403 390 -507 +385 589 589 589 406 406 -443 +407 408 409 -111 +593 409 411 -411 -411 +412 +412 413 413 413 -589 -520 -268 +336 +415 268 268 +434 536 -268 -268 -268 -268 -268 -617 +493 +493 +493 +493 +493 +670 268 268 268 @@ -700,22 +700,22 @@ class_prediction 268 492 338 -334 -334 -4 -4 -111 +339 +339 +91 +91 +91 92 +593 +576 +593 +576 +593 +593 +593 576 576 -576 -576 -576 -576 -576 -576 -576 -576 +593 576 427 427 @@ -748,10 +748,10 @@ class_prediction 427 427 428 -576 +477 428 428 -576 +437 429 429 429 @@ -762,40 +762,40 @@ class_prediction 430 431 431 -576 432 432 +593 453 593 593 +593 +593 +593 +593 +432 576 593 593 432 -432 -432 +610 432 593 432 -134 -437 -296 -432 -432 +593 593 432 432 593 -576 +593 432 -434 -576 +432 +593 432 433 -248 -434 +433 434 434 +268 434 434 434 @@ -808,17 +808,17 @@ class_prediction 268 434 434 +434 +434 +434 268 434 434 434 434 434 -434 -434 -434 -434 -576 +268 +432 434 434 434 @@ -849,7 +849,7 @@ class_prediction 434 434 434 -434 +593 434 434 434 @@ -864,7 +864,7 @@ class_prediction 434 434 434 -434 +268 434 434 434 @@ -892,9 +892,9 @@ class_prediction 434 434 434 -268 434 434 +593 434 434 434 @@ -911,15 +911,15 @@ class_prediction 434 268 434 -434 -576 +268 +355 434 434 268 434 434 434 -576 +434 434 434 434 @@ -929,11 +929,11 @@ class_prediction 434 434 434 +268 +268 434 434 -576 -434 -434 +388 434 434 434 @@ -955,7 +955,7 @@ class_prediction 434 434 434 -268 +434 434 434 434 @@ -982,7 +982,7 @@ class_prediction 434 434 434 -434 +268 434 434 434 @@ -1005,21 +1005,11 @@ class_prediction 431 435 435 -437 -435 -576 -437 -435 -431 435 435 -437 +593 435 435 -296 -435 -431 -431 431 435 435 @@ -1028,16 +1018,26 @@ class_prediction 435 435 435 -576 +435 +435 +435 +435 +435 +435 +435 +435 +435 +435 +435 431 -437 -435 -437 435 435 -443 435 -576 +435 +435 +435 +435 +435 435 436 436 @@ -1053,12 +1053,12 @@ class_prediction 93 93 93 -375 +271 438 -584 +364 95 95 -383 +96 97 98 99 @@ -1079,44 +1079,44 @@ class_prediction 103 104 105 -121 +468 106 107 107 107 108 110 -603 -603 -603 -576 -603 -603 -603 -589 +110 +110 +110 +437 +110 +110 +110 +438 110 111 111 111 112 -390 +112 107 -497 -576 -497 +113 +113 +113 117 -596 +665 114 114 -522 -593 +677 +115 406 -300 +115 115 116 -621 +587 116 -621 +176 117 117 117 @@ -1151,7 +1151,8 @@ class_prediction 581 581 581 -43 +93 +121 581 581 581 @@ -1190,6 +1191,7 @@ class_prediction 581 581 581 +121 581 581 581 @@ -1198,6 +1200,7 @@ class_prediction 581 581 581 +572 581 581 581 @@ -1206,7 +1209,6 @@ class_prediction 581 581 581 -296 581 581 581 @@ -1223,12 +1225,14 @@ class_prediction 581 581 581 +121 581 581 581 581 581 581 +121 581 581 581 @@ -1252,7 +1256,6 @@ class_prediction 581 581 581 -525 581 581 581 @@ -1269,58 +1272,55 @@ class_prediction 581 581 581 -581 -581 -593 121 121 466 -104 -581 -581 +467 +581 +581 +468 +581 
+581 +468 581 581 581 +468 577 -581 -581 -581 -556 520 -520 -581 +469 581 470 470 581 470 -581 +121 470 470 -111 -111 471 -111 -111 -581 -581 +471 +471 +121 +471 +472 +473 +473 +473 +473 +473 +473 +473 +333 +579 +473 473 -581 -581 -581 -581 -581 -383 -383 -581 -496 441 -593 +441 +441 443 441 441 441 -593 441 122 122 @@ -1333,22 +1333,22 @@ class_prediction 122 123 123 -497 -568 -263 -597 +297 +272 +272 +438 273 -111 -561 -576 -576 -576 +272 +124 274 274 -388 -587 +274 +274 +274 +342 +342 125 -507 +437 507 507 507 @@ -1370,35 +1370,36 @@ class_prediction 344 344 344 -128 -449 +126 +197 442 442 442 442 -84 -368 +383 +130 131 -134 -107 +610 +536 +610 +259 +453 134 134 -333 +610 +610 134 +610 134 -134 -535 -541 -134 -134 -594 -134 +610 +593 600 -134 +132 343 -88 +107 343 134 +303 134 134 134 @@ -1409,13 +1410,15 @@ class_prediction 134 134 134 -134 -134 +583 134 134 134 445 -593 +134 +134 +134 +298 134 134 134 @@ -1428,12 +1431,7 @@ class_prediction 134 134 134 -134 -134 -134 -594 -134 -134 +525 134 134 134 @@ -1444,6 +1442,7 @@ class_prediction 134 134 134 +436 134 134 134 @@ -1452,6 +1451,7 @@ class_prediction 134 134 580 +580 134 134 134 @@ -1462,6 +1462,10 @@ class_prediction 134 134 134 +443 +134 +134 +333 134 134 134 @@ -1470,162 +1474,158 @@ class_prediction 134 134 134 +333 +134 +134 +134 +134 +437 134 134 134 134 134 +432 +383 134 134 -134 -134 -134 -134 -134 -134 -593 -134 -134 -134 -134 -134 -335 +474 475 -516 -383 -383 -383 -296 -134 -134 -593 +135 +136 +333 +136 +298 +298 +298 +138 139 134 -134 +610 141 -443 -443 -443 +593 +593 +593 143 143 -602 +447 143 -574 +144 +145 +146 134 -134 -134 -134 -134 -300 -503 +567 +567 +370 +148 276 495 149 416 -579 -536 +20 +309 151 -134 +593 152 585 348 348 417 -356 +349 153 586 154 -55 +365 155 351 351 351 352 -134 +572 352 -134 +572 352 352 352 352 352 352 -504 +56 157 157 157 157 158 158 +295 158 -158 -158 +148 159 353 353 -594 -161 -161 -161 -161 -589 -161 +160 +162 +162 161 +162 +162 +162 +162 163 587 587 587 -248 +396 165 -562 -167 +166 +166 168 169 278 356 171 122 -121 -174 -496 -431 +359 174 +301 +173 174 +476 477 437 -445 -593 +175 +175 279 -174 -174 -174 +296 +104 +177 178 178 178 179 360 361 -589 +45 180 603 -306 +281 603 603 603 581 603 -296 182 182 -623 +182 +183 184 -368 +185 186 -301 -581 -572 -593 -448 +187 +478 +478 +189 +281 281 281 281 @@ -1636,26 +1636,26 @@ class_prediction 190 190 190 -178 -579 +479 +480 190 191 -418 -443 -117 -593 -593 +194 +195 +196 +197 +198 199 200 -593 +201 202 203 -107 +204 205 206 -593 +207 208 -209 +443 210 211 212 @@ -1666,25 +1666,25 @@ class_prediction 215 217 217 -111 -84 +64 +219 481 218 219 -99 +220 222 366 -572 +648 366 366 -84 223 +460 223 224 224 -390 -390 -260 +282 +282 +282 445 445 445 @@ -1745,32 +1745,32 @@ class_prediction 447 447 448 -448 -448 -448 -134 -448 -438 -448 +168 +168 +168 593 -425 +448 +168 +448 +690 +690 448 593 448 448 448 134 -593 +448 +168 +448 +168 +448 +168 +448 +168 448 448 -448 -438 -438 -448 -448 -448 -448 -134 +168 448 449 449 @@ -1783,7 +1783,7 @@ class_prediction 449 449 449 -593 +449 449 449 449 @@ -1798,9 +1798,9 @@ class_prediction 368 369 369 -574 +226 +593 593 -425 228 370 370 @@ -1808,8 +1808,8 @@ class_prediction 370 370 370 -370 -370 +301 +371 370 370 111 @@ -1826,80 +1826,80 @@ class_prediction 568 568 568 -425 +559 111 -104 -568 +229 +230 231 -316 +318 321 -665 -600 -579 -596 -390 -589 +453 +233 +486 +547 +283 +283 +284 +284 +284 +284 284 -285 -285 -285 -285 -285 -285 -285 -285 -285 -285 -285 -603 284 -285 -285 -285 -285 -285 -285 -285 -285 -285 -285 
-285 -285 -285 -285 285 284 285 +284 285 +284 +333 +284 +284 +284 +284 +284 +284 +284 +284 +284 +284 +284 +284 285 +284 +284 +284 +284 +284 +284 285 +284 +284 +284 +284 285 -285 -285 -285 -316 +284 284 431 285 590 -445 -285 -445 +590 285 +284 285 +284 285 431 -285 +234 486 285 -285 -285 -285 -285 -285 -285 -285 -285 +286 +286 +286 +286 +286 +286 +286 +287 237 580 580 @@ -1912,38 +1912,38 @@ class_prediction 580 580 609 -268 239 +452 580 452 580 +451 +580 +487 580 580 580 580 580 580 -580 -580 -580 -560 +30 580 580 580 580 +451 +580 +451 580 580 580 +242 580 580 580 -443 -580 -580 -580 -580 -580 +451 580 +451 580 580 580 @@ -1966,28 +1966,276 @@ class_prediction 452 452 452 -99 +276 452 452 452 452 +580 +580 +101 +452 +580 +452 +296 +580 +452 +452 +580 +580 +452 +452 +452 +452 +580 +580 +452 +452 +583 +452 +580 +452 +452 +452 +580 +452 +452 +452 +452 +452 +452 +452 +452 +486 +580 +334 +452 +580 +452 +452 +452 +452 +452 +452 +273 +580 +452 +580 +580 +452 +579 +580 +452 +452 +580 +452 +580 +409 +452 +452 +580 +452 +580 +452 +452 +452 +580 +452 +276 +452 +580 +409 +451 +334 +452 +452 +452 +452 +452 +276 +580 +276 +452 +452 +452 +452 +452 +276 +452 +409 +452 +580 +580 +452 +443 +580 +452 +452 +452 +452 +452 +452 +452 +452 +452 +452 +55 +452 +452 +452 +276 +273 +452 +452 +296 +580 +296 +276 +452 +452 +452 +452 +580 +334 +276 +580 +580 +580 +452 +452 +452 +242 +452 +580 +452 +559 +452 +437 +452 +452 +242 +276 +580 +452 +580 +452 +580 +580 +452 +580 +452 +452 +580 +580 +409 +452 +580 +452 +452 +452 +486 +320 +580 +452 +452 +452 +452 +580 +452 +580 +452 +452 +580 +452 +580 +242 +580 +568 +111 +452 +452 +452 +242 +452 +296 +452 +242 +580 +452 +296 +580 +580 +580 +580 +452 +452 +452 +452 +452 +452 +452 +452 +452 +452 +452 +452 +580 +242 +452 +452 +452 +452 +579 +580 +452 +452 +452 +580 +580 +580 +452 +452 +580 +581 +580 +559 +452 +486 +452 +452 +580 +452 +452 +579 +434 +452 +452 +452 +452 +388 +452 +452 +452 +242 +452 +452 +452 355 -452 -306 -452 -452 -452 -452 -580 -452 -452 -580 580 452 452 452 580 +452 580 +452 +452 +296 +452 +452 +452 +274 +452 580 452 452 @@ -1997,314 +2245,66 @@ class_prediction 452 452 452 +242 +452 +580 +580 +580 +580 +580 +452 +452 +452 +451 +242 580 452 452 452 452 452 -452 +580 452 452 334 580 580 452 -580 -452 -452 -452 -452 -452 -452 -452 -580 -452 -580 -580 -452 -452 -580 -452 -452 -580 -452 -580 -580 -452 -452 -580 -452 -580 -452 -452 -452 -580 -452 -452 -452 -580 -580 -452 -580 -452 -452 -452 -452 -452 -99 -580 -99 -452 -452 -452 -452 -452 -593 -452 -452 -452 -580 -580 -452 -452 -580 -452 -452 -452 -452 -452 -452 -452 -452 -452 -452 -580 -452 -452 -452 -43 -453 -452 -452 -452 -580 -306 -452 -452 -452 -452 -452 -580 -583 -452 -580 -43 -452 -452 -452 -452 -242 -452 -580 -580 -452 -452 -580 -452 -452 -242 -452 -580 -452 -580 -452 -580 -580 -452 -306 -452 -452 -580 -580 -452 -452 -580 -452 -580 -452 -596 -316 -103 -452 -452 -452 -452 -580 452 580 452 452 580 452 -580 -522 -580 -306 -242 -452 -580 -452 -355 -452 -580 -452 -522 -580 -452 -296 -580 -580 -580 -580 +334 452 452 452 452 452 452 -452 -452 -452 -452 -452 -452 -580 -107 -452 -452 -452 -452 -452 -580 -452 -452 +333 452 580 580 580 452 -452 580 -580 -580 -452 -452 -580 -452 -452 -580 -452 -452 -452 -452 -452 452 452 452 583 452 -452 -452 -242 -452 -452 -452 -355 -580 -452 -452 -452 -580 -452 -580 -452 -580 -296 -452 -452 -452 -580 -452 -580 -452 -452 -580 -452 -580 -580 -580 -452 -242 -452 -580 -580 -580 -580 -580 -452 -452 -452 -452 -522 -580 -452 -452 -452 -452 -452 -580 -452 -452 -580 -580 -580 -580 -452 -580 
-452 -452 -580 -452 -580 -452 -452 -452 -452 -452 -452 -580 -452 -580 -580 -580 -452 -580 -452 -452 -452 -580 -452 580 580 452 452 452 242 -242 +111 242 452 -242 +322 580 452 452 @@ -2321,23 +2321,23 @@ class_prediction 452 452 452 -12 +240 241 242 243 -167 -111 +244 +374 247 247 247 248 -425 +249 134 289 250 250 250 -597 +580 251 252 253 @@ -2346,70 +2346,70 @@ class_prediction 256 601 511 -445 +693 512 -442 -659 +689 +513 +515 515 -303 516 -661 +516 517 517 +518 517 -517 -520 -355 +519 +519 520 520 520 521 -596 +521 522 522 -596 523 -498 +523 +524 524 525 525 526 -84 +361 +529 529 -512 530 530 531 -547 -572 +531 +561 572 574 538 -43 -276 -383 +111 +539 +568 540 -547 +545 545 547 547 553 -609 +122 557 557 558 558 562 -591 +441 591 597 +198 593 593 -593 -370 -111 -306 -442 +595 +567 +595 +595 597 598 572 @@ -2417,24 +2417,24 @@ class_prediction 579 579 579 -22 -504 +653 +342 661 -516 -43 -520 -520 +661 +666 667 520 +668 +667 +675 +675 675 675 675 -134 -134 675 301 43 -84 +693 694 694 -512 +698 diff --git a/cosines_with_augmentations/results/output.txt b/cosines_with_augmentations/results/output.txt index fdc60dc..0fcd810 100644 --- a/cosines_with_augmentations/results/output.txt +++ b/cosines_with_augmentations/results/output.txt @@ -1,5 +1,5 @@ -Top-1 accuracy: 0.8257482574825749 -Top-3 accuracy: 0.9106191061910619 -Top-5 accuracy: 0.9261992619926199 -Top-10 accuracy: 0.942189421894219 -0.0 0.82121104 +Top-1 accuracy: 0.8072980729807298 +Top-3 accuracy: 0.8946289462894629 +Top-5 accuracy: 0.9040590405904059 +Top-10 accuracy: 0.924149241492415 +0.0 0.7571934 diff --git a/cosines_with_augmentations/study_sampling.py b/cosines_with_augmentations/study_sampling.py new file mode 100644 index 0000000..dd23da2 --- /dev/null +++ b/cosines_with_augmentations/study_sampling.py @@ -0,0 +1,270 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +# from loss import batch_all_triplet_loss, batch_hard_triplet_loss +import loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim +import torch.nn.functional as F + + +# %% +SHUFFLES=3 +AMPLIFY_FACTOR=3 +LEARNING_RATE=1e-5 +DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for 
t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # add shuffled strings + processed_descs = shuffle_text(parent_desc, 
n_shuffles=SHUFFLES) + for desc in processed_descs: + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add corrupted strings + desc = corrupt_string(parent_desc, corruption_probability=0.01) + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add example with stripped non-alphanumerics + desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # short sequence amplifier + # short sequences are rare, and we must compensate by including more examples + # also, short sequence don't usually get affected by shuffle + words = parent_desc.split() + word_count = len(words) + if word_count <= 2: + for _ in range(AMPLIFY_FACTOR): + output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def make_entity_id_name(df): + entity_id_name = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +model.to(DEVICE) +model.train() + +losses = [] + +# %% +augmented_df = augment_data(df) +train_entity_id_mentions = make_entity_id_mentions(augmented_df) +train_entity_id_name = make_entity_id_name(augmented_df) + +data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) +random.shuffle(data) + + +# %% +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]) + yield x, y + + + +# simulate 1 epoch +y_accumulator = [] +augmented_df = augment_data(df) +train_entity_id_mentions = make_entity_id_mentions(augmented_df) +train_entity_id_name = make_entity_id_name(augmented_df) + +data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) +random.shuffle(data) +for x,y in batchGenerator(data, batch_size): + y_accumulator.append(y) + + +# %% +y_accumulator + +# %% diff --git a/cosines_with_augmentations/understanding_loss.py b/cosines_with_augmentations/understanding_loss.py new file mode 100644 index 0000000..e4b6092 --- /dev/null +++ b/cosines_with_augmentations/understanding_loss.py @@ -0,0 +1,378 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +# from loss import batch_all_triplet_loss, batch_hard_triplet_loss +import loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm 
import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim +import torch.nn.functional as F + + +# %% +SHUFFLES=0 +AMPLIFY_FACTOR=0 +LEARNING_RATE=1e-5 +DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # add shuffled strings + processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + for desc in processed_descs: + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add corrupted strings + desc = corrupt_string(parent_desc, corruption_probability=0.01) + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add example with stripped non-alphanumerics + desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # short sequence amplifier + # short sequences are rare, and we must compensate by including more examples + # also, short sequence don't usually get affected by shuffle + words = parent_desc.split() + word_count = len(words) + if word_count <= 2: + for _ in range(AMPLIFY_FACTOR): + output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def make_entity_id_name(df): + entity_id_name = {} + entity_id_list = 
list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +model.to(DEVICE) +model.train() + +losses = [] + +# %% +augmented_df = augment_data(df) +train_entity_id_mentions = make_entity_id_mentions(augmented_df) +train_entity_id_name = make_entity_id_name(augmented_df) + +data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) +random.shuffle(data) + + +# %% +x, y = next(iter(batchGenerator(data, batch_size))) + +# %% +inputs = tokenizer(x, padding=True, return_tensors='pt') +inputs.to(DEVICE) +outputs = model(**inputs) +cls = outputs.last_hidden_state[:,0,:] +# for training less than half the time, train on easy +y = torch.tensor(y).to(DEVICE) + +# %% +def _pairwise_distances(embeddings, squared=False): + """Compute the 2D matrix of distances between all the embeddings. + Args: + embeddings: tensor of shape (batch_size, embed_dim) + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. + Returns: + pairwise_distances: tensor of shape (batch_size, batch_size) + """ + dot_product = torch.matmul(embeddings, embeddings.t()) + + # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`. + # This also provides more numerical stability (the diagonal of the result will be exactly 0). 
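+    # In other words, dot_product[i][j] = e_i . e_j, so its diagonal holds the squared
+    # norms (e.g. for e_i = [3, 4] the diagonal entry is 3*3 + 4*4 = 25 = ||e_i||^2),
+    # which feed the expansion ||a - b||^2 = ||a||^2 - 2*(a . b) + ||b||^2 used just below.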
+ # shape (batch_size,) + square_norm = torch.diag(dot_product) + + # Compute the pairwise distance matrix as we have: + # ||a - b||^2 = ||a||^2 - 2 + ||b||^2 + # shape (batch_size, batch_size) + distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1) + + # Apply a lower bound to distances to ensure they are non-negative and avoid tiny negative numbers due to computation errors + distances = torch.clamp(distances, min=0.0) + + if not squared: + # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal) + # we need to add a small epsilon where distances == 0.0 + epsilon = 1e-16 + mask = (distances < epsilon).float() + distances = distances + mask * epsilon + + distances = (1.0 - mask) * torch.sqrt(distances) + + return distances + +# %% +embeddings = cls +squared = False + +# %% + +# Get the pairwise distance matrix +pairwise_dist = loss._pairwise_distances(embeddings, squared=squared) # 96x96 + +anchor_positive_dist = pairwise_dist.unsqueeze(2) # 96x96x1 +anchor_negative_dist = pairwise_dist.unsqueeze(1) # 96x1x96 + +# Compute a 3D tensor of size (batch_size, batch_size, batch_size) +# triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k +# every (i,j) pairwise distance - every (i,k) pairwise distance +# fixing for i, we get (i,j) - (i,k), for every j and k, which is 96x96 + +# Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1) +# and the 2nd (batch_size, 1, batch_size) +# remember that broadcasting is repeating the other axis n-times +# this broadcasting trick is to get every possible triple combination +triplet_loss = anchor_positive_dist - anchor_negative_dist + margin + +# triplet_loss 96x96x96 + +# %% +labels = y + +# %% + +# Put to zero the invalid triplets +# (where label(a) != label(p) or label(n) == label(a) or a == p) +mask = loss._get_triplet_mask(labels) +triplet_loss = mask.float() * triplet_loss + +# Remove negative losses (i.e. 
the easy triplets) +triplet_loss = F.relu(triplet_loss) + +# Count number of positive triplets (where triplet_loss > 0) +valid_triplets = triplet_loss[triplet_loss > 1e-16] +num_positive_triplets = valid_triplets.size(0) +num_valid_triplets = mask.sum() + +fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16) + +# Get final mean triplet loss over the positive valid triplets +triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16) + +# %% +# %% +loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + +# %% +loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + +# %% +# Check that i, j and k are distinct +# create an identity matrix of size 96 +indices_equal = torch.eye(labels.size(0), device=labels.device).bool() + +# %% +indices_not_equal = ~indices_equal +i_not_equal_j = indices_not_equal.unsqueeze(2) # [96,96,1] +i_not_equal_k = indices_not_equal.unsqueeze(1) # [96,1,96] +j_not_equal_k = indices_not_equal.unsqueeze(0) # [1,96,96] + + +# %% +# eliminate any combination that uses the diagonal values (aka sharing same values) +distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k + + +# %% +label_equal = labels.unsqueeze(0) == labels.unsqueeze(1) +# label_equal is a 96x96 matrix showing where 2 labels equate + +# perform the same unsqueeze to 1 and 2 axis and broadcast to get all possible combinations +# note that we have 96 elements, but we want all (i,j,k) combinations from these 96 elements +i_equal_j = label_equal.unsqueeze(2) +i_equal_k = label_equal.unsqueeze(1) + +# ~i_equal_k means that it checks for non-equality between i and k +# i_equal_j checks for equality between i and j +# we want (i,j) to be the same label, (i,k) to be different labels +valid_labels = ~i_equal_k & i_equal_j + +# %% +final_mask = distinct_indices & valid_labels + diff --git a/experimental/.gitignore b/experimental/.gitignore new file mode 100644 index 0000000..423e37c --- /dev/null +++ b/experimental/.gitignore @@ -0,0 +1,4 @@ +__pycache__ +checkpoint +results +top1_curves \ No newline at end of file diff --git a/experimental/character_bert_train.py b/experimental/character_bert_train.py new file mode 100644 index 0000000..20f5e4f --- /dev/null +++ b/experimental/character_bert_train.py @@ -0,0 +1,577 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import BertTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F +from sklearn.metrics import accuracy_score +from transformers import get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup + +torch.set_float32_matmul_precision('high') + +def set_seed(seed): + """ + Set the random seed for reproducibility. 
+ """ + random.seed(seed) # Python random module + np.random.seed(seed) # NumPy random + torch.manual_seed(seed) # PyTorch CPU + torch.cuda.manual_seed(seed) # PyTorch GPU + torch.cuda.manual_seed_all(seed) # If using multiple GPUs + torch.backends.cudnn.deterministic = True # Ensure deterministic behavior + torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility + +set_seed(42) + + + + + +# %% +SHUFFLES=1 +AMPLIFY_FACTOR=1 +CORRUPT=0.00 +LEARNING_RATE=1e-6 +DEVICE = torch.device('cuda:2') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +# %% +EVAL_FILE="top1_curves/batch_output.txt" +with open(EVAL_FILE, "w") as f: + pass + +EVAL_FILE_KNN="top1_curves/batch_knn.txt" +with open(EVAL_FILE_KNN, "w") as f: + pass + + + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # # add shuffled strings + # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + # for desc in processed_descs: + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # add corrupted strings + desc = corrupt_string(parent_desc, corruption_probability=CORRUPT) + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add example with stripped non-alphanumerics + desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # # short sequence amplifier + # # short sequences are rare, and we must compensate by including more examples + # # also, short sequence don't usually get affected by shuffle + # words = parent_desc.split() + # word_count = len(words) + # if word_count <= 2: + # for _ in range(AMPLIFY_FACTOR): + # output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + +# def sample_from_df(df, sample_size_per_class=5): +# sampled_df = (df.groupby("entity_id")[['entity_id', 'mention', 'entity_name']] # explicit give column names +# .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x)))) +# .reset_index(drop=True)) +# +# return sampled_df + + + +# %% +def 
make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def make_entity_id_name(df): + entity_id_name = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + +# %% +# evaluation +def run_evaluation_logit(model, tokenizer): + def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + with open('../esAppMod/tca_entities.json', 'r') as file: + eval_entities = json.load(file) + eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()} + + with open('../esAppMod/train.json', 'r') as file: + eval_train = json.load(file) + eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()} + eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()} + + with open('../esAppMod/infer.json', 'r') as file: + eval_test = json.load(file) + x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()] + y_test = [d['entity_id'] for _, d in eval_test['data'].items()] + eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys()) + eval_train_entities = [preprocess_text(element) for element in eval_train_entities] + + def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + + batches = batch_list(x_test, 64) + + pred_labels = [] + for batch in batches: + # Inference in batches + inputs, attn_mask = tokenizer.encode(batch) + inputs = inputs.to(DEVICE) + attn_mask = attn_mask.to(DEVICE) + with torch.no_grad(): + _, logits = model(inputs, attn_mask) + predicted_class_ids = logits.argmax(dim=1).to("cpu") + pred_labels.extend(predicted_class_ids) + + + pred_labels = [tensor.item() for tensor in pred_labels] + + # %% + labels = [label2id[element] for element in y_test] + with open(EVAL_FILE, "a") as f: + # only compute top-1 + accuracy = accuracy_score(labels, pred_labels) + print(f'{accuracy}', file=f) + + + + + +def run_evaluation_knn(model, tokenizer): + def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + with open('../esAppMod/tca_entities.json', 'r') as file: + eval_entities = json.load(file) + eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()} + + with open('../esAppMod/train.json', 'r') as file: + eval_train = json.load(file) + eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()} + eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()} + + with open('../esAppMod/infer.json', 'r') as file: + eval_test = json.load(file) + x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()] + y_test = [d['entity_id'] for _, d in eval_test['data'].items()] + eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys()) + eval_train_entities = [preprocess_text(element) for element in eval_train_entities] + + def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + + batches = batch_list(eval_train_entities, 64) + + embedding_list = [] + for batch in batches: + inputs, attn_mask = tokenizer.encode(batch) + inputs = inputs.to(DEVICE) + attn_mask = attn_mask.to(DEVICE) + outputs = model(inputs, attn_mask) + output_slice = outputs[:,0,:] + output_slice = output_slice.detach().cpu().numpy() + embedding_list.append(output_slice) + + cls = np.concatenate(embedding_list) + + batches = batch_list(x_test, 64) + + embedding_list = [] + for batch in batches: + inputs, attn_mask = tokenizer.encode(batch) + inputs = inputs.to(DEVICE) + attn_mask = attn_mask.to(DEVICE) + outputs = model(inputs, attn_mask) + output_slice = outputs[:,0,:] + output_slice = output_slice.detach().cpu().numpy() + embedding_list.append(output_slice) + + cls_test = np.concatenate(embedding_list) + + knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, eval_labels) + + + with open(EVAL_FILE_KNN, "a") as f: + # only compute top-1 + distances, indices = knn.kneighbors(cls_test, n_neighbors=1) + num = 0 + for a,b in zip(y_test, indices): + b = [eval_labels[i] for i in b] + if a in b: + num += 1 + print(f'{num / len(y_test)}', file=f) + + +# %% +class CharacterTransformer(nn.Module): + def __init__(self, num_chars, d_model=256, nhead=4, num_encoder_layers=4): + super(CharacterTransformer, self).__init__() + self.char_embedding = nn.Embedding(num_chars, d_model) + encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True) + self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers) + + def forward(self, input, attention_mask): + # input: (batch_size, seq_len) + embeddings = self.char_embedding(input) # (batch_size, seq_len, d_model) + # embeddings = embeddings.permute(1, 0, 2) # (seq_len, batch_size, d_model) + output = self.transformer_encoder(embeddings, src_key_padding_mask=attention_mask) + # output = output.permute(1, 0, 2) # (batch_size, seq_len, d_model) + return output + +class ASCIITokenizer: + def __init__(self, pad_token='\0'): + # Initialize the tokenizer with ASCII characters. + # ASCII characters range from 0 to 127. 
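+        # e.g. chr(65) == 'A' maps to id 65 and chr(97) == 'a' maps to id 97; the
+        # default pad_token '\0' maps to id 0.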
+ self.char_to_id = {chr(i): i for i in range(128)} + self.id_to_char = {i: chr(i) for i in range(128)} + self.pad_token = pad_token + + def encode(self, text_list): + """Encode a text string into a list of ASCII IDs and generate attention masks.""" + output_list = [] + max_length = 0 + # First pass to find the maximum length and encode the texts + for text in text_list: + text = self.pad_token + text # Prepend pad_token to each text + output = [self.char_to_id.get(char, self.pad_token) for char in text] + output_list.append(output) + if len(output) > max_length: + max_length = len(output) + + # Second pass to pad the sequences to the maximum length and create masks + padded_list = [] + attention_masks = [] + for output in output_list: + # we cannot mask the first token + attention_mask = [0] + [0] * (len(output) - 1) + [1] * (max_length - len(output)) # 1s for real tokens, 0s for padding + output = self.pad(output, max_length) + padded_list.append(output) + attention_masks.append(attention_mask) + + return torch.tensor(padded_list, dtype=torch.long), torch.tensor(attention_masks, dtype=torch.bool) + + + def decode(self, ids_list): + """Decode a list of ASCII IDs back into a text string.""" + output_list = [] + for ids in ids_list: + output = ''.join(self.id_to_char.get(id, '') for id in ids if id in self.id_to_char) + output_list.append(output) + return output_list + + def pad(self, output, max_length): + """Pad the output list with ASCII ID for space or another padding character to the maximum length.""" + return output + [self.char_to_id.get(self.pad_token)] * (max_length - len(output)) +# %% +tokenizer = ASCIITokenizer() +# # Example text +# text = ["Hello, world! This is cool", "Hello, world!"] +# # Encode the text +# encoded = tokenizer.encode(text) +# print("Encoded:", encoded) +# +# # Decode the encoded IDs +# decoded = tokenizer.decode(encoded.numpy()) +# print("Decoded:", decoded) + +# %% +# Example usage +bert_model = CharacterTransformer(num_chars=128) # Assuming ASCII characters + +class BertForClassificationAndTriplet(nn.Module): + def __init__(self, bert_model, num_classes): + super().__init__() + self.bert = bert_model + self.classifier = nn.Linear(bert_model.char_embedding.embedding_dim, num_classes) + + def forward(self, input_ids, attention_mask=None): + outputs = self.bert(input_ids, attention_mask) + cls_embeddings = outputs[:, 0, :] # CLS token + logits = self.classifier(cls_embeddings) + return cls_embeddings, logits + +model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id)) + + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 + +# model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True) +# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) +# tokenizer = BertTokenizer.from_pretrained(MODEL_NAME) +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# num_warmup_steps=100 +# total_steps = epochs * (1126/64) +# scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps, total_steps, lr_end=5e-6) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) +# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.9, cooldown=5, verbose=True) + + +# %% +state_dict = torch.load('./checkpoint/pretrained_character_bert.pt') +state_dict = {key.replace('_orig_mod.', ''): 
value for key, value in state_dict.items()} +model.load_state_dict(state_dict) +model.to(DEVICE) +model.train() + +losses = [] + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + if epoch % 10 == 0: + augmented_df = augment_data(df) + # sampled_df = sample_from_df(augmented_df, sample_size_per_class=num_sample_per_class) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs, attn_mask = tokenizer.encode(x) + inputs = inputs.to(DEVICE) + attn_mask = attn_mask.to(DEVICE) + cls, logits = model(inputs, attn_mask) + + # labels = y + # labels = [label2id[element] for element in labels] + # labels = torch.tensor(labels).to(DEVICE) + + # loss = F.cross_entropy(logits, labels) + + + # for training less than half the time, train on easy + y = torch.tensor(y).to(DEVICE) + if epoch < epochs / 2: + loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # for training after half the time, train on hard + else: + loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + loss.backward() + # scheduler.step() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + # del x, y, outputs, cls, loss + # torch.cuda.empty_cache() + epoch_loss = total_loss/batch_number + + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {epoch_loss}') + if (epoch % 1 == 0): + model.eval() + with torch.no_grad(): + run_evaluation_logit(model=model, tokenizer=tokenizer) + run_evaluation_knn(model=model.bert, tokenizer=tokenizer) + # run evaluation on test data + model.train() + + + # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}") + if (epoch % 100 == 0) and (epoch > 100): + torch.save(model.state_dict(), './checkpoint/character_bert.pt') + + + +torch.save(model.state_dict(), './checkpoint/character_bert_final.pt') +# %% diff --git a/experimental/loss.py b/experimental/loss.py new file mode 100644 index 0000000..e226f48 --- /dev/null +++ b/experimental/loss.py @@ -0,0 +1,288 @@ +# stardard functionalities for computing triplet loss, borrow code from +# https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py +import torch +import torch.nn.functional as F +def _pairwise_distances(embeddings, squared=False): + """Compute the 2D matrix of distances between all the embeddings. + Args: + embeddings: tensor of shape (batch_size, embed_dim) + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. + Returns: + pairwise_distances: tensor of shape (batch_size, batch_size) + """ + dot_product = torch.matmul(embeddings, embeddings.t()) + + # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`. + # This also provides more numerical stability (the diagonal of the result will be exactly 0). 
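+    # Because the same dot_product supplies both the norms and the cross term, the
+    # diagonal of the distance matrix works out to n_i - 2*n_i + n_i = 0 exactly,
+    # rather than a tiny floating-point residue from separately computed norms.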
+ # shape (batch_size,) + square_norm = torch.diag(dot_product) + + # Compute the pairwise distance matrix as we have: + # ||a - b||^2 = ||a||^2 - 2 + ||b||^2 + # shape (batch_size, batch_size) + distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1) + + # Because of computation errors, some distances might be negative so we put everything >= 0.0 + distances[distances < 0] = 0 + + if not squared: + # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal) + # we need to add a small epsilon where distances == 0.0 + mask = distances.eq(0).float() + distances = distances + mask * 1e-16 + + distances = (1.0 -mask) * torch.sqrt(distances) + + return distances + +# def _pairwise_distances(embeddings, squared=False): +# embeddings = F.normalize(embeddings, p=2, dim=1) +# dot_product = torch.matmul(embeddings, embeddings.t()) +# cosine_distance = 1 - dot_product +# return cosine_distance + + + +def _get_triplet_mask(labels): + """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid. + A triplet (i, j, k) is valid if: + - i, j, k are distinct + - labels[i] == labels[j] and labels[i] != labels[k] + Args: + labels: tf.int32 `Tensor` with shape [batch_size] + """ + # Check that i, j and k are distinct + indices_equal = torch.eye(labels.size(0), device=labels.device).bool() + indices_not_equal = ~indices_equal + i_not_equal_j = indices_not_equal.unsqueeze(2) + i_not_equal_k = indices_not_equal.unsqueeze(1) + j_not_equal_k = indices_not_equal.unsqueeze(0) + + distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k + + + label_equal = labels.unsqueeze(0) == labels.unsqueeze(1) + i_equal_j = label_equal.unsqueeze(2) + i_equal_k = label_equal.unsqueeze(1) + + valid_labels = ~i_equal_k & i_equal_j + + return valid_labels & distinct_indices + + +def _get_anchor_positive_triplet_mask(labels): + """Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label. + Args: + labels: tf.int32 `Tensor` with shape [batch_size] + Returns: + mask: tf.bool `Tensor` with shape [batch_size, batch_size] + """ + # Check that i and j are distinct + indices_equal = torch.eye(labels.size(0), device=labels.device).bool() + indices_not_equal = ~indices_equal + + # Check if labels[i] == labels[j] + # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1) + labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1) + + return labels_equal & indices_not_equal + + +def _get_anchor_negative_triplet_mask(labels): + """Return a 2D mask where mask[a, n] is True iff a and n have distinct labels. + Args: + labels: tf.int32 `Tensor` with shape [batch_size] + Returns: + mask: tf.bool `Tensor` with shape [batch_size, batch_size] + """ + # Check if labels[i] != labels[k] + # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1) + + return ~(labels.unsqueeze(0) == labels.unsqueeze(1)) + + +# Cell +def batch_hard_triplet_loss(labels, embeddings, margin, squared=False): + """Build the triplet loss over a batch of embeddings. + For each anchor, we get the hardest positive and hardest negative to form a triplet. + Args: + labels: labels of the batch, of size (batch_size,) + embeddings: tensor of shape (batch_size, embed_dim) + margin: margin for triplet loss + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. 
+ Returns: + triplet_loss: scalar tensor containing the triplet loss + """ + # Get the pairwise distance matrix + pairwise_dist = _pairwise_distances(embeddings, squared=squared) + + # For each anchor, get the hardest positive + # First, we need to get a mask for every valid positive (they should have same label) + mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float() + + # We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p)) + anchor_positive_dist = mask_anchor_positive * pairwise_dist + + # shape (batch_size, 1) + hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True) + + # For each anchor, get the hardest negative + # First, we need to get a mask for every valid negative (they should have different labels) + mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float() + + # We add the maximum value in each row to the invalid negatives (label(a) == label(n)) + max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True) + anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative) + + # shape (batch_size,) + hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True) + + # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss + tl = hardest_positive_dist - hardest_negative_dist + margin + tl = F.relu(tl) + triplet_loss = tl.mean() + + return triplet_loss + +# Cell +def batch_all_triplet_loss(labels, embeddings, margin, squared=False): + """Build the triplet loss over a batch of embeddings. + We generate all the valid triplets and average the loss over the positive ones. + Args: + labels: labels of the batch, of size (batch_size,) + embeddings: tensor of shape (batch_size, embed_dim) + margin: margin for triplet loss + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. + Returns: + triplet_loss: scalar tensor containing the triplet loss + """ + # Get the pairwise distance matrix + pairwise_dist = _pairwise_distances(embeddings, squared=squared) + + anchor_positive_dist = pairwise_dist.unsqueeze(2) + anchor_negative_dist = pairwise_dist.unsqueeze(1) + + # Compute a 3D tensor of size (batch_size, batch_size, batch_size) + # triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k + # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1) + # and the 2nd (batch_size, 1, batch_size) + triplet_loss = anchor_positive_dist - anchor_negative_dist + margin + + + + # Put to zero the invalid triplets + # (where label(a) != label(p) or label(n) == label(a) or a == p) + mask = _get_triplet_mask(labels) + triplet_loss = mask.float() * triplet_loss + + # Remove negative losses (i.e. the easy triplets) + triplet_loss = F.relu(triplet_loss) + + # Count number of positive triplets (where triplet_loss > 0) + valid_triplets = triplet_loss[triplet_loss > 1e-16] + num_positive_triplets = valid_triplets.size(0) + num_valid_triplets = mask.sum() + + fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16) + + # Get final mean triplet loss over the positive valid triplets + triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16) + + return triplet_loss, fraction_positive_triplets + +def batch_all_soft_margin_triplet_loss(labels, embeddings, squared=False): + """Build the triplet loss over a batch of embeddings. 
+ We generate all the valid triplets and average the loss over the positive ones. + Args: + labels: labels of the batch, of size (batch_size,) + embeddings: tensor of shape (batch_size, embed_dim) + margin: margin for triplet loss + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. + Returns: + triplet_loss: scalar tensor containing the triplet loss + """ + # Get the pairwise distance matrix + pairwise_dist = _pairwise_distances(embeddings, squared=squared) + + anchor_positive_dist = pairwise_dist.unsqueeze(2) + anchor_negative_dist = pairwise_dist.unsqueeze(1) + + # Compute a 3D tensor of size (batch_size, batch_size, batch_size) + # triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k + # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1) + # and the 2nd (batch_size, 1, batch_size) + triplet_loss = anchor_positive_dist - anchor_negative_dist + + # Apply exponential and log + triplet_loss = torch.log(1 + torch.exp(triplet_loss)) + + # Put to zero the invalid triplets + # (where label(a) != label(p) or label(n) == label(a) or a == p) + mask = _get_triplet_mask(labels) + triplet_loss = mask.float() * triplet_loss + + # Remove negative losses (i.e. the easy triplets) + # triplet_loss = F.relu(triplet_loss) + + # Count number of positive triplets (where triplet_loss > 0) + valid_triplets = triplet_loss[triplet_loss > 1e-16] + num_positive_triplets = valid_triplets.size(0) + num_valid_triplets = mask.sum() + + fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16) + + # Get final mean triplet loss over the positive valid triplets + triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16) + + return triplet_loss, fraction_positive_triplets + + + +def batch_hard_soft_margin_triplet_loss(labels, embeddings, squared=False): + """Build the triplet loss over a batch of embeddings. + For each anchor, we get the hardest positive and hardest negative to form a triplet. + Args: + labels: labels of the batch, of size (batch_size,) + embeddings: tensor of shape (batch_size, embed_dim) + margin: margin for triplet loss + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. 
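+    Note:
+        Unlike batch_hard_triplet_loss above, this variant takes no margin
+        argument; the hinge max(0, d_ap - d_an + margin) is replaced by the
+        soft-margin term log(1 + exp(d_ap - d_an)).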
+ Returns: + triplet_loss: scalar tensor containing the triplet loss + """ + # Get the pairwise distance matrix + pairwise_dist = _pairwise_distances(embeddings, squared=squared) + + # For each anchor, get the hardest positive + # First, we need to get a mask for every valid positive (they should have same label) + mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float() + + # We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p)) + anchor_positive_dist = mask_anchor_positive * pairwise_dist + + # shape (batch_size, 1) + hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True) + + # For each anchor, get the hardest negative + # First, we need to get a mask for every valid negative (they should have different labels) + mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float() + + # We add the maximum value in each row to the invalid negatives (label(a) == label(n)) + max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True) + anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative) + + # shape (batch_size,) + hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True) + + # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss + tl = hardest_positive_dist - hardest_negative_dist + # Apply exponential and log + triplet_loss = torch.log(1 + torch.exp(tl)) + + triplet_loss = triplet_loss.mean() + + return triplet_loss diff --git a/experimental/pretrain_character_bert_train.py b/experimental/pretrain_character_bert_train.py new file mode 100644 index 0000000..9cbcee9 --- /dev/null +++ b/experimental/pretrain_character_bert_train.py @@ -0,0 +1,574 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import BertTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F +from sklearn.metrics import accuracy_score +from transformers import get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup + +torch.set_float32_matmul_precision('high') + +def set_seed(seed): + """ + Set the random seed for reproducibility. 
+ """ + random.seed(seed) # Python random module + np.random.seed(seed) # NumPy random + torch.manual_seed(seed) # PyTorch CPU + torch.cuda.manual_seed(seed) # PyTorch GPU + torch.cuda.manual_seed_all(seed) # If using multiple GPUs + torch.backends.cudnn.deterministic = True # Ensure deterministic behavior + torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility + +set_seed(42) + + + + + +# %% +SHUFFLES=1 +AMPLIFY_FACTOR=1 +CORRUPT=0.1 +LEARNING_RATE=1e-5 +DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +# %% +EVAL_FILE="top1_curves/character_output.txt" +with open(EVAL_FILE, "w") as f: + pass + +EVAL_FILE_KNN="top1_curves/character_knn.txt" +with open(EVAL_FILE_KNN, "w") as f: + pass + + + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # # add shuffled strings + # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + # for desc in processed_descs: + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add corrupted strings + # desc = corrupt_string(parent_desc, corruption_probability=CORRUPT) + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add example with stripped non-alphanumerics + # desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # short sequence amplifier + # # short sequences are rare, and we must compensate by including more examples + # # also, short sequence don't usually get affected by shuffle + # words = parent_desc.split() + # word_count = len(words) + # if word_count <= 2: + # for _ in range(AMPLIFY_FACTOR): + # output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + +# def sample_from_df(df, sample_size_per_class=5): +# sampled_df = (df.groupby("entity_id")[['entity_id', 'mention', 'entity_name']] # explicit give column names +# .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x)))) +# .reset_index(drop=True)) +# +# return sampled_df + + + +# %% 
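+# Toy sanity check of the augmentation helpers above (illustrative only; the
+# mention below is made up).  As currently written, augment_data() keeps just
+# the cleaned base mention, while shuffle_text() and corrupt_string() remain
+# available for the commented-out variants.  Left disabled so the RNG stream
+# of the actual training run below is untouched.
+RUN_AUGMENTATION_DEMO = False
+if RUN_AUGMENTATION_DEMO:
+    example = preprocess_text('IBM WebSphere Application Server 8.5')
+    print(shuffle_text(example, n_shuffles=2))                  # original + 2 word-order shuffles
+    print(corrupt_string(example, corruption_probability=0.5))  # random character deletes/swaps
+    toy_df = pd.DataFrame([{'entity_id': 0, 'mention': example,
+                            'entity_name': 'WebSphere Application Server'}])
+    print(augment_data(toy_df))                                 # one cleaned row per input row
+
+# %%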
+def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def make_entity_id_name(df): + entity_id_name = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + +# %% +# evaluation +def run_evaluation_logit(model, tokenizer): + def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + with open('../esAppMod/tca_entities.json', 'r') as file: + eval_entities = json.load(file) + eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()} + + with open('../esAppMod/train.json', 'r') as file: + eval_train = json.load(file) + eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()} + eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()} + + with open('../esAppMod/infer.json', 'r') as file: + eval_test = json.load(file) + x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()] + y_test = [d['entity_id'] for _, d in eval_test['data'].items()] + eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys()) + eval_train_entities = [preprocess_text(element) for element in eval_train_entities] + + def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + + batches = batch_list(x_test, 64) + + pred_labels = [] + for batch in batches: + # Inference in batches + inputs, attn_mask = tokenizer.encode(batch) + inputs = inputs.to(DEVICE) + attn_mask = attn_mask.to(DEVICE) + with torch.no_grad(): + _, logits = model(inputs, attn_mask) + predicted_class_ids = logits.argmax(dim=1).to("cpu") + pred_labels.extend(predicted_class_ids) + + + pred_labels = [tensor.item() for tensor in pred_labels] + + # %% + labels = [label2id[element] for element in y_test] + with open(EVAL_FILE, "a") as f: + # only compute top-1 + accuracy = accuracy_score(labels, pred_labels) + print(f'{accuracy}', file=f) + + + + + +def run_evaluation_knn(model, tokenizer): + def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + with open('../esAppMod/tca_entities.json', 'r') as file: + eval_entities = json.load(file) + eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()} + + with open('../esAppMod/train.json', 'r') as file: + eval_train = json.load(file) + eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()} + eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()} + + with open('../esAppMod/infer.json', 'r') as file: + eval_test = json.load(file) + x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()] + y_test = [d['entity_id'] for _, d in eval_test['data'].items()] + eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys()) + eval_train_entities = [preprocess_text(element) for element in eval_train_entities] + + def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + + batches = batch_list(eval_train_entities, 64) + + embedding_list = [] + for batch in batches: + inputs, attn_mask = tokenizer.encode(batch) + inputs = inputs.to(DEVICE) + attn_mask = attn_mask.to(DEVICE) + outputs = model(inputs, attn_mask) + output_slice = outputs[:,0,:] + output_slice = output_slice.detach().cpu().numpy() + embedding_list.append(output_slice) + + cls = np.concatenate(embedding_list) + + batches = batch_list(x_test, 64) + + embedding_list = [] + for batch in batches: + inputs, attn_mask = tokenizer.encode(batch) + inputs = inputs.to(DEVICE) + attn_mask = attn_mask.to(DEVICE) + outputs = model(inputs, attn_mask) + output_slice = outputs[:,0,:] + output_slice = output_slice.detach().cpu().numpy() + embedding_list.append(output_slice) + + cls_test = np.concatenate(embedding_list) + + knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, eval_labels) + + + with open(EVAL_FILE_KNN, "a") as f: + # only compute top-1 + distances, indices = knn.kneighbors(cls_test, n_neighbors=1) + num = 0 + for a,b in zip(y_test, indices): + b = [eval_labels[i] for i in b] + if a in b: + num += 1 + print(f'{num / len(y_test)}', file=f) + + +# %% +class CharacterTransformer(nn.Module): + def __init__(self, num_chars, d_model=128, nhead=4, num_encoder_layers=2): + super(CharacterTransformer, self).__init__() + self.char_embedding = nn.Embedding(num_chars, d_model) + encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True) + self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers) + + def forward(self, input, attention_mask): + # input: (batch_size, seq_len) + embeddings = self.char_embedding(input) # (batch_size, seq_len, d_model) + # embeddings = embeddings.permute(1, 0, 2) # (seq_len, batch_size, d_model) + output = self.transformer_encoder(embeddings, src_key_padding_mask=attention_mask) + # output = output.permute(1, 0, 2) # (batch_size, seq_len, d_model) + return output + +# %% + +class ASCIITokenizer: + def __init__(self, pad_token='\0'): + # Initialize the tokenizer with ASCII characters. + # ASCII characters range from 0 to 127. 
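+        # Design notes for this tokenizer (as used by CharacterTransformer):
+        # - The pad token '\0' maps to ID 0, and encode() also prepends it to
+        #   every sequence, so position 0 acts as a CLS-like slot whose output
+        #   embedding is later read off for classification / triplet loss.
+        # - encode() returns a boolean mask with True at PADDING positions and
+        #   False at real characters, because it is fed to nn.TransformerEncoder
+        #   as src_key_padding_mask (True = ignore).  This is the opposite of
+        #   the HuggingFace attention_mask convention.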
+ self.char_to_id = {chr(i): i for i in range(128)} + self.id_to_char = {i: chr(i) for i in range(128)} + self.pad_token = pad_token + + + def encode(self, text_list): + """Encode a text string into a list of ASCII IDs and generate attention masks.""" + output_list = [] + max_length = 0 + # First pass to find the maximum length and encode the texts + for text in text_list: + text = self.pad_token + text # Prepend pad_token to each text + output = [self.char_to_id.get(char, self.pad_token) for char in text] + output_list.append(output) + if len(output) > max_length: + max_length = len(output) + + # Second pass to pad the sequences to the maximum length and create masks + padded_list = [] + attention_masks = [] + for output in output_list: + # first element is not masked + attention_mask = [0] + [0] * (len(output) - 1) + [1] * (max_length - len(output)) # 1s for real tokens, 0s for padding + output = self.pad(output, max_length) + padded_list.append(output) + attention_masks.append(attention_mask) + + return torch.tensor(padded_list, dtype=torch.long), torch.tensor(attention_masks, dtype=torch.bool) + + def decode(self, ids_list): + """Decode a list of ASCII IDs back into a text string.""" + output_list = [] + for ids in ids_list: + output = ''.join(self.id_to_char.get(id, '') for id in ids if id in self.id_to_char) + output_list.append(output) + return output_list + + def pad(self, output, max_length): + """Pad the output list with ASCII ID for space or another padding character to the maximum length.""" + return output + [self.char_to_id.get(self.pad_token)] * (max_length - len(output)) +# %% +tokenizer = ASCIITokenizer() +# # Example text +# text = ["Hello, world! This is cool", "Hello, world!"] +# # Encode the text +# encoded = tokenizer.encode(text) +# print("Encoded:", encoded) +# +# # Decode the encoded IDs +# decoded = tokenizer.decode(encoded.numpy()) +# print("Decoded:", decoded) + +# %% +# Example usage +bert_model = CharacterTransformer(num_chars=128) # Assuming ASCII characters + +class BertForClassificationAndTriplet(nn.Module): + def __init__(self, bert_model, num_classes): + super().__init__() + self.bert = bert_model + self.classifier = nn.Linear(bert_model.char_embedding.embedding_dim, num_classes) + + def forward(self, input_ids, attention_mask=None): + outputs = self.bert(input_ids, attention_mask) + cls_embeddings = outputs[:, 0, :] # CLS token + logits = self.classifier(cls_embeddings) + return cls_embeddings, logits + +model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id)) + + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 32 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 5000 + +# model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True) +# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) +# tokenizer = BertTokenizer.from_pretrained(MODEL_NAME) +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# num_warmup_steps=100 +# total_steps = epochs * (1126/64) +# scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps, total_steps, lr_end=5e-6) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) +# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.9, cooldown=5, verbose=True) + + +# %% +state_dict = torch.load('./checkpoint/pretrained_character_bert_final.pt') +state_dict = {key.replace('_orig_mod.', 
''): value for key, value in state_dict.items()} +model.load_state_dict(state_dict) +model.to(DEVICE) +model.train() + +losses = [] + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + if epoch % 1 == 0: + augmented_df = augment_data(df) + # sampled_df = sample_from_df(augmented_df, sample_size_per_class=num_sample_per_class) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs, attn_mask = tokenizer.encode(x) + inputs = inputs.to(DEVICE) + attn_mask = attn_mask.to(DEVICE) + cls, logits = model(inputs, attn_mask) + # labels = y + # labels = [label2id[element] for element in labels] + # labels = torch.tensor(labels).to(DEVICE) + + # loss = F.cross_entropy(logits, labels) + + + # for training less than half the time, train on easy + y = torch.tensor(y).to(DEVICE) + # loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + # for training after half the time, train on hard + loss.backward() + # scheduler.step() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + # del x, y, outputs, cls, loss + # torch.cuda.empty_cache() + + epoch_loss = total_loss/batch_number + print(f'epoch loss: {epoch_loss}') + if (epoch % 10 == 0): + model.eval() + with torch.no_grad(): + # run_evaluation_logit(model=model, tokenizer=tokenizer) + run_evaluation_knn(model=model.bert, tokenizer=tokenizer) + # run evaluation on test data + model.train() + + + # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}") + if (epoch % 100 == 0) and (epoch > 100): + torch.save(model.state_dict(), './checkpoint/pretrained_character_bert.pt') + + + +torch.save(model.state_dict(), './checkpoint/pretrained_character_bert_final.pt') +# %% diff --git a/loss_comparisons_with_augmentations/.gitignore b/loss_comparisons_with_augmentations/.gitignore new file mode 100644 index 0000000..423e37c --- /dev/null +++ b/loss_comparisons_with_augmentations/.gitignore @@ -0,0 +1,4 @@ +__pycache__ +checkpoint +results +top1_curves \ No newline at end of file diff --git a/loss_comparisons_with_augmentations/baseline_infer.py b/loss_comparisons_with_augmentations/baseline_infer.py new file mode 100644 index 0000000..cad2e79 --- /dev/null +++ b/loss_comparisons_with_augmentations/baseline_infer.py @@ -0,0 +1,131 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import re +import gc + +# %% +# Step 2: Load the state dictionary +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' +# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) + +# state_dict = torch.load('./checkpoint/siamese.pt') +# 
state_dict = torch.load('./checkpoint/siamese_simple.pt') +# state_dict = torch.load('./checkpoint/classification.pt') +state_dict = torch.load('./checkpoint/baseline.pt') +# params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name} + +# %% +# Step 3: Apply the state dictionary to the model +model.load_state_dict(state_dict) +model.to(DEVICE) +model.eval() + +# %% +def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + +# %% +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + + + +# %% +with open('../esAppMod/infer.json', 'r') as file: + test = json.load(file) +x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()] +y_test = [d['entity_id'] for _, d in test['data'].items()] +train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys()) +train_entities = [preprocess_text(element) for element in train_entities] + +def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + +batches = batch_list(train_entities, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls = np.concatenate(embedding_list) +# %% +gc.collect() +torch.cuda.empty_cache() + +# %% + +batches = batch_list(x_test, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls_test = np.concatenate(embedding_list) + + +# %% +knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels) +n_neighbors = [1, 3, 5, 10] + + +with open("results/output.txt", "w") as f: + for n in n_neighbors: + distances, indices = knn.kneighbors(cls_test, n_neighbors=n) + num = 0 + for a,b in zip(y_test, indices): + b = [labels[i] for i in b] + if a in b: + num += 1 + print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f) + print(np.min(distances), np.max(distances), file=f) + +# %% +with open("results/predictions.txt", "w") as f: + distances, indices = knn.kneighbors(cls_test, n_neighbors=1) + for a,b in zip(y_test, indices): + b = [labels[i] for i in b] + print(f'{a}, {b[0]}', file=f) + diff --git a/loss_comparisons_with_augmentations/baseline_train.py b/loss_comparisons_with_augmentations/baseline_train.py new file mode 100644 index 0000000..d44e8e0 --- /dev/null +++ b/loss_comparisons_with_augmentations/baseline_train.py @@ -0,0 +1,400 @@ +# %% +import torch +import json +import random 
+import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim + +torch.set_float32_matmul_precision('high') + +def set_seed(seed): + """ + Set the random seed for reproducibility. + """ + random.seed(seed) # Python random module + np.random.seed(seed) # NumPy random + torch.manual_seed(seed) # PyTorch CPU + torch.cuda.manual_seed(seed) # PyTorch GPU + torch.cuda.manual_seed_all(seed) # If using multiple GPUs + torch.backends.cudnn.deterministic = True # Ensure deterministic behavior + torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility + +set_seed(42) + + + + + +# %% +SHUFFLES=1 +AMPLIFY_FACTOR=1 +CORRUPT=0.1 +LEARNING_RATE=1e-5 +DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +# %% +with open("top1_curves/baseline_output.txt", "w") as f: + pass + + + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def 
preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # # add shuffled strings + # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + # for desc in processed_descs: + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # add corrupted strings + desc = corrupt_string(parent_desc, corruption_probability=CORRUPT) + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add example with stripped non-alphanumerics + desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # # short sequence amplifier + # # short sequences are rare, and we must compensate by including more examples + # # also, short sequence don't usually get affected by shuffle + # words = parent_desc.split() + # word_count = len(words) + # if word_count <= 2: + # for _ in range(AMPLIFY_FACTOR): + # output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + +# def sample_from_df(df, sample_size_per_class=5): +# sampled_df = (df.groupby("entity_id")[['entity_id', 'mention', 'entity_name']] # explicit give column names +# .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x)))) +# .reset_index(drop=True)) +# +# return sampled_df 
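+
+# %%
+# Illustrative sketch with toy data (not part of the training run): how
+# generate_train_entity_sets() and batchGenerator() cooperate.  Each yielded
+# batch is a flat list of mentions plus one label per mention, with labels
+# repeated inside a group -- which is what the online triplet mining in
+# loss.py needs in order to find positives within a batch.  Left disabled so
+# it does not consume random state before training.
+RUN_BATCHING_DEMO = False
+if RUN_BATCHING_DEMO:
+    toy_mentions = {1: ['postgres', 'postgresql', 'pg'], 2: ['nginx', 'nginx server']}
+    toy_names = {1: 'PostgreSQL', 2: 'NGINX'}
+    toy_sets = generate_train_entity_sets(toy_mentions, toy_names, group_size=2, anchor=True)
+    for toy_x, toy_y in batchGenerator(toy_sets, batch_size=2):
+        print(toy_x, toy_y)  # e.g. ['PostgreSQL', 'postgres', 'pg', 'NGINX', ...], [1, 1, 1, 2, ...]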
+ + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def make_entity_id_name(df): + entity_id_name = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + +# %% +# evaluation +def run_evaluation(model, tokenizer): + def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + with open('../esAppMod/tca_entities.json', 'r') as file: + eval_entities = json.load(file) + eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()} + + with open('../esAppMod/train.json', 'r') as file: + eval_train = json.load(file) + eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()} + eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()} + + with open('../esAppMod/infer.json', 'r') as file: + eval_test = json.load(file) + x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()] + y_test = [d['entity_id'] for _, d in eval_test['data'].items()] + eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys()) + eval_train_entities = [preprocess_text(element) for element in eval_train_entities] + + def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + + batches = batch_list(eval_train_entities, 64) + + embedding_list = [] + for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + + cls = np.concatenate(embedding_list) + + batches = batch_list(x_test, 64) + + embedding_list = [] + for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + + cls_test = np.concatenate(embedding_list) + + knn = KNeighborsClassifier(n_neighbors=1, metric='euclidean').fit(cls, eval_labels) + + + with open("top1_curves/baseline_output.txt", "a") as f: + # only compute top-1 + distances, indices = knn.kneighbors(cls_test, n_neighbors=1) + num = 0 + for a,b in zip(y_test, indices): + b = [eval_labels[i] for i in b] + if a in b: + num += 1 + print(f'{num / len(y_test)}', file=f) + + + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 64 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) +optimizer = 
torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) +# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.9, cooldown=5, verbose=True) + +model.to(DEVICE) +model.train() + +losses = [] + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + if epoch % 1 == 0: + augmented_df = augment_data(df) + # sampled_df = sample_from_df(augmented_df, sample_size_per_class=num_sample_per_class) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs.to(DEVICE) + outputs = model(**inputs) + cls = outputs.last_hidden_state[:,0,:] + # for training less than half the time, train on easy + y = torch.tensor(y).to(DEVICE) + if epoch < epochs / 2: + loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # for training after half the time, train on hard + else: + loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + # del x, y, outputs, cls, loss + # torch.cuda.empty_cache() + epoch_loss = total_loss/batch_number + # scheduler.step(epoch_loss) + + # run evaluation on test data + model.eval() + with torch.no_grad(): + run_evaluation(model=model, tokenizer=tokenizer) + + model.train() + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {epoch_loss}') + # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}") + if epoch == 125: + torch.save(model.state_dict(), './checkpoint/baseline.pt') + + +# torch.save(model.state_dict(), './checkpoint/baseline.pt') +# %% diff --git a/loss_comparisons_with_augmentations/batch_all_train.py b/loss_comparisons_with_augmentations/batch_all_train.py new file mode 100644 index 0000000..54009e4 --- /dev/null +++ b/loss_comparisons_with_augmentations/batch_all_train.py @@ -0,0 +1,424 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import ( + batch_all_triplet_loss, + batch_hard_triplet_loss, + batch_all_soft_margin_triplet_loss, + batch_hard_soft_margin_triplet_loss) +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F + +torch.set_float32_matmul_precision('high') + +def set_seed(seed): + """ + Set the random seed for reproducibility. 
+ """ + random.seed(seed) # Python random module + np.random.seed(seed) # NumPy random + torch.manual_seed(seed) # PyTorch CPU + torch.cuda.manual_seed(seed) # PyTorch GPU + torch.cuda.manual_seed_all(seed) # If using multiple GPUs + torch.backends.cudnn.deterministic = True # Ensure deterministic behavior + torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility + +set_seed(42) + + +# %% +SHUFFLES=1 +AMPLIFY_FACTOR=1 +LEARNING_RATE=1e-5 +DEVICE = torch.device('cuda:3') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + + +# %% +EVAL_FILE="top1_curves/batch_all_output.txt" +with open(EVAL_FILE, "w") as f: + pass + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # add shuffled strings + processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + for desc in processed_descs: + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add corrupted strings + desc = corrupt_string(parent_desc, corruption_probability=0.01) + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add example with stripped non-alphanumerics + desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # short sequence amplifier + # short sequences are rare, and we must compensate by including more examples + # also, short sequence don't usually get affected by shuffle + words = parent_desc.split() + word_count = len(words) + if word_count <= 2: + for _ in range(AMPLIFY_FACTOR): + output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def make_entity_id_name(df): + entity_id_name = {} + entity_id_list = 
list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + + +# evaluation +def run_evaluation(model, tokenizer): + def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) + all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + + with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) + train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} + train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + + with open('../esAppMod/infer.json', 'r') as file: + test = json.load(file) + x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()] + y_test = [d['entity_id'] for _, d in test['data'].items()] + train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys()) + train_entities = [preprocess_text(element) for element in train_entities] + + def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + + batches = batch_list(train_entities, 64) + + embedding_list = [] + for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + + cls = np.concatenate(embedding_list) + + batches = batch_list(x_test, 64) + + embedding_list = [] + for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + + cls_test = np.concatenate(embedding_list) + + knn = KNeighborsClassifier(n_neighbors=1, metric='euclidean').fit(cls, labels) + + + with open(EVAL_FILE, "a") as f: + # only compute top-1 + distances, indices = knn.kneighbors(cls_test, n_neighbors=1) + num = 0 + for a,b in zip(y_test, indices): + b = [labels[i] for i in b] + if a in b: + num += 1 + print(f'{num / len(y_test)}', file=f) + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +bert_model = AutoModel.from_pretrained(MODEL_NAME) + +class BertForClassificationAndTriplet(nn.Module): + def __init__(self, bert_model, num_classes): + super().__init__() + self.bert = bert_model + self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes) + + def forward(self, input_ids, attention_mask=None): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token + logits = self.classifier(cls_embeddings) + return cls_embeddings, logits + +model = 
BertForClassificationAndTriplet(bert_model, num_classes=len(label2id)) + +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +model.to(DEVICE) +model.train() + +losses = [] + +def linear_decay(epoch, max_epochs, initial_lr, final_lr): + """ Calculate the linearly decayed learning rate. """ + return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr) + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + total_cross = 0.0 + total_triplet = 0.0 + batch_number = 0 + + # lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6) + + # # Update optimizer's learning rate + # for param_group in optimizer.param_groups: + # param_group['lr'] = lr + + augmented_df = augment_data(df) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs.to(DEVICE) + cls, logits = model( + input_ids=inputs['input_ids'], + attention_mask=inputs['attention_mask'] + ) + # for training less than half the time, train on easy + y = torch.tensor(y).to(DEVICE) + + + + # if epoch < epochs / 2: + loss, _ = batch_all_soft_margin_triplet_loss(y, cls, squared=False) + # for training after half the time, train on hard + # else: + # triplet_loss = batch_hard_soft_margin_triplet_loss(y, cls, squared=False) + + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + # total_cross += class_loss.detach().item() + # total_triplet += triplet_loss.detach().item() + batch_number += 1 + + # run evaluation on test data + model.eval() + with torch.no_grad(): + run_evaluation(model=model.bert, tokenizer=tokenizer) + + model.train() + + + # scheduler.step() # Update the learning rate + # print(f'epoch loss: {total_loss/batch_number}, cross loss: {total_cross/batch_number}, triplet loss: {total_triplet/batch_number}') + print(f'epoch loss: {total_loss/batch_number}') + # print(f"Epoch {epoch+1}: lr={lr}") + if epoch % 5 == 0: + # torch.save(model.bert.state_dict(), './checkpoint/classification.pt') + torch.save(model.state_dict(), './checkpoint/batch_all.pt') + + +# torch.save(model.bert.state_dict(), './checkpoint/classification.pt') +torch.save(model.state_dict(), './checkpoint/batch_all.pt') +# %% diff --git a/loss_comparisons_with_augmentations/character_bert_train.py b/loss_comparisons_with_augmentations/character_bert_train.py new file mode 100644 index 0000000..7c6ea10 --- /dev/null +++ b/loss_comparisons_with_augmentations/character_bert_train.py @@ -0,0 +1,561 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import BertTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F +from sklearn.metrics import accuracy_score +from transformers import get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup + +torch.set_float32_matmul_precision('high') + +def 
set_seed(seed): + """ + Set the random seed for reproducibility. + """ + random.seed(seed) # Python random module + np.random.seed(seed) # NumPy random + torch.manual_seed(seed) # PyTorch CPU + torch.cuda.manual_seed(seed) # PyTorch GPU + torch.cuda.manual_seed_all(seed) # If using multiple GPUs + torch.backends.cudnn.deterministic = True # Ensure deterministic behavior + torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility + +set_seed(42) + + + + + +# %% +SHUFFLES=1 +AMPLIFY_FACTOR=1 +CORRUPT=0.1 +LEARNING_RATE=1e-5 +DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +# %% +EVAL_FILE="top1_curves/character_output.txt" +with open(EVAL_FILE, "w") as f: + pass + +EVAL_FILE_KNN="top1_curves/character_knn.txt" +with open(EVAL_FILE_KNN, "w") as f: + pass + + + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # # add shuffled strings + # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + # for desc in processed_descs: + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # add corrupted strings + desc = corrupt_string(parent_desc, corruption_probability=CORRUPT) + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add example with stripped non-alphanumerics + desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # # short sequence amplifier + # # short sequences are rare, and we must compensate by including more examples + # # also, short sequence don't usually get affected by shuffle + # words = parent_desc.split() + # word_count = len(words) + # if word_count <= 2: + # for _ in range(AMPLIFY_FACTOR): + # output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + +# def sample_from_df(df, sample_size_per_class=5): +# sampled_df = (df.groupby("entity_id")[['entity_id', 'mention', 'entity_name']] # explicit give column names +# .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x)))) +# .reset_index(drop=True)) +# +# return sampled_df + + + +# %% +def 
make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def make_entity_id_name(df): + entity_id_name = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + +# %% +# evaluation +def run_evaluation_logit(model, tokenizer): + def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + with open('../esAppMod/tca_entities.json', 'r') as file: + eval_entities = json.load(file) + eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()} + + with open('../esAppMod/train.json', 'r') as file: + eval_train = json.load(file) + eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()} + eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()} + + with open('../esAppMod/infer.json', 'r') as file: + eval_test = json.load(file) + x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()] + y_test = [d['entity_id'] for _, d in eval_test['data'].items()] + eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys()) + eval_train_entities = [preprocess_text(element) for element in eval_train_entities] + + def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + + batches = batch_list(x_test, 64) + + pred_labels = [] + for batch in batches: + # Inference in batches + inputs = tokenizer.encode(batch) + inputs = inputs.to(DEVICE) + with torch.no_grad(): + _, logits = model(inputs) + predicted_class_ids = logits.argmax(dim=1).to("cpu") + pred_labels.extend(predicted_class_ids) + + + pred_labels = [tensor.item() for tensor in pred_labels] + + # %% + labels = [label2id[element] for element in y_test] + with open(EVAL_FILE, "a") as f: + # only compute top-1 + accuracy = accuracy_score(labels, pred_labels) + print(f'{accuracy}', file=f) + + + + + +def run_evaluation_knn(model, tokenizer): + def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + with open('../esAppMod/tca_entities.json', 'r') as file: + eval_entities = json.load(file) + eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()} + + with open('../esAppMod/train.json', 'r') as file: + eval_train = json.load(file) + eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()} + eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()} + + with open('../esAppMod/infer.json', 'r') as file: + eval_test = json.load(file) + x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()] + y_test = [d['entity_id'] for _, d in eval_test['data'].items()] + eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys()) + eval_train_entities = [preprocess_text(element) for element in eval_train_entities] + + def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + + batches = batch_list(eval_train_entities, 64) + + embedding_list = [] + for batch in batches: + inputs = tokenizer.encode(batch) + inputs = inputs.to(DEVICE) + outputs = model(inputs) + output_slice = outputs[:,0,:] + output_slice = output_slice.detach().cpu().numpy() + embedding_list.append(output_slice) + + cls = np.concatenate(embedding_list) + + batches = batch_list(x_test, 64) + + embedding_list = [] + for batch in batches: + inputs = tokenizer.encode(batch) + inputs = inputs.to(DEVICE) + outputs = model(inputs) + output_slice = outputs[:,0,:] + output_slice = output_slice.detach().cpu().numpy() + embedding_list.append(output_slice) + + cls_test = np.concatenate(embedding_list) + + knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, eval_labels) + + + with open(EVAL_FILE_KNN, "a") as f: + # only compute top-1 + distances, indices = knn.kneighbors(cls_test, n_neighbors=1) + num = 0 + for a,b in zip(y_test, indices): + b = [eval_labels[i] for i in b] + if a in b: + num += 1 + print(f'{num / len(y_test)}', file=f) + + +# %% +class CharacterTransformer(nn.Module): + def __init__(self, num_chars, d_model=512, nhead=8, num_encoder_layers=6): + super(CharacterTransformer, self).__init__() + self.char_embedding = nn.Embedding(num_chars, d_model) + encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True) + self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers) + + def forward(self, input): + # input: (batch_size, seq_len) + embeddings = self.char_embedding(input) # (batch_size, seq_len, d_model) + # embeddings = embeddings.permute(1, 0, 2) # (seq_len, batch_size, d_model) + output = self.transformer_encoder(embeddings) + # output = output.permute(1, 0, 2) # (batch_size, seq_len, d_model) + return output + +class ASCIITokenizer: + def __init__(self, pad_token='\0'): + # Initialize the tokenizer with ASCII characters. + # ASCII characters range from 0 to 127. 
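+ # Note: the pad token defaults to '\0' (ASCII 0), so padded positions reuse
+ # embedding index 0, the same index as the NUL character that encode()
+ # prepends to every string as a CLS-like start marker. Characters outside
+ # the 0-127 range are silently dropped by encode(), since they have no
+ # entry in char_to_id.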
+ self.char_to_id = {chr(i): i for i in range(128)} + self.id_to_char = {i: chr(i) for i in range(128)} + self.pad_token = pad_token + + def encode(self, text_list): + """Encode a text string into a list of ASCII IDs.""" + output_list = [] + max_length = 0 + for text in text_list: + text = self.pad_token + text + output = [self.char_to_id.get(char, None) for char in text if char in self.char_to_id] + output_list.append(output) + if len(output) > max_length: + max_length = len(output) + padded_list = [self.pad(output, max_length) for output in output_list] + # Convert the list of lists into a tensor + return torch.tensor(padded_list, dtype=torch.long) + + def decode(self, ids_list): + """Decode a list of ASCII IDs back into a text string.""" + output_list = [] + for ids in ids_list: + output = ''.join(self.id_to_char.get(id, '') for id in ids if id in self.id_to_char) + output_list.append(output) + return output_list + + def pad(self, output, max_length): + """Pad the output list with ASCII ID for space or another padding character to the maximum length.""" + return output + [self.char_to_id.get(self.pad_token)] * (max_length - len(output)) +# %% +tokenizer = ASCIITokenizer() +# # Example text +# text = ["Hello, world! This is cool", "Hello, world!"] +# # Encode the text +# encoded = tokenizer.encode(text) +# print("Encoded:", encoded) +# +# # Decode the encoded IDs +# decoded = tokenizer.decode(encoded.numpy()) +# print("Decoded:", decoded) + +# %% +# Example usage +bert_model = CharacterTransformer(num_chars=128) # Assuming ASCII characters + +class BertForClassificationAndTriplet(nn.Module): + def __init__(self, bert_model, num_classes): + super().__init__() + self.bert = bert_model + self.classifier = nn.Linear(bert_model.char_embedding.embedding_dim, num_classes) + + def forward(self, input_ids, attention_mask=None): + outputs = self.bert(input_ids) + cls_embeddings = outputs[:, 0, :] # CLS token + logits = self.classifier(cls_embeddings) + return cls_embeddings, logits + +model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id)) + + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 64 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 5000 + +# model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True) +# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) +# tokenizer = BertTokenizer.from_pretrained(MODEL_NAME) +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +num_warmup_steps=100 +total_steps = epochs * (1126/64) +scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps, total_steps, lr_end=5e-6) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) +# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.9, cooldown=5, verbose=True) + + +# %% +# state_dict = torch.load('./checkpoint/character_bert.pt') +# state_dict = {key.replace('_orig_mod.', ''): value for key, value in state_dict.items()} +# model.load_state_dict(state_dict) +model.to(DEVICE) +model.train() + +losses = [] + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + if epoch % 1 == 0: + augmented_df = augment_data(df) + # sampled_df = sample_from_df(augmented_df, sample_size_per_class=num_sample_per_class) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + 
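+ # Augmentation is re-rolled here every epoch (the `epoch % 1 == 0` guard is
+ # always true), so the random character corruptions differ between epochs
+ # while the canonical mention is always kept. Each group built below holds
+ # the entity name plus up to num_sample_per_class - 1 mentions, so a batch
+ # of 64 groups can reach roughly 64 * 10 = 640 sequences per forward pass.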
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer.encode(x) + inputs = inputs.to(DEVICE) + cls, logits = model(inputs) + labels = y + labels = [label2id[element] for element in labels] + labels = torch.tensor(labels).to(DEVICE) + + loss = F.cross_entropy(logits, labels) + + + # for training less than half the time, train on easy + # y = torch.tensor(y).to(DEVICE) + # if epoch < epochs / 2: + # loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # # for training after half the time, train on hard + # else: + # loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + loss.backward() + scheduler.step() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + # del x, y, outputs, cls, loss + # torch.cuda.empty_cache() + epoch_loss = total_loss/batch_number + + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {epoch_loss}') + if (epoch % 10 == 0): + model.eval() + with torch.no_grad(): + run_evaluation_logit(model=model, tokenizer=tokenizer) + run_evaluation_knn(model=model.bert, tokenizer=tokenizer) + # run evaluation on test data + model.train() + + + # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}") + if (epoch % 100 == 0) and (epoch > 100): + torch.save(model.state_dict(), './checkpoint/character_bert.pt') + + + +torch.save(model.state_dict(), './checkpoint/character_bert_final.pt') +# %% diff --git a/loss_comparisons_with_augmentations/classify_infer.py b/loss_comparisons_with_augmentations/classify_infer.py new file mode 100644 index 0000000..3109739 --- /dev/null +++ b/loss_comparisons_with_augmentations/classify_infer.py @@ -0,0 +1,124 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import re +import gc + +# %% +# Step 2: Load the state dictionary +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' +# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) + +# state_dict = torch.load('./checkpoint/siamese.pt') +# state_dict = torch.load('./checkpoint/siamese_simple.pt') +state_dict = torch.load('./checkpoint/classification.pt') +params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name} + +# %% +# Step 3: Apply the state dictionary to the model +model.load_state_dict(params_dict) +model.to(DEVICE) +model.eval() + +# %% +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + +# %% +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + + + +# %% +with open('../esAppMod/infer.json', 'r') as file: + test = json.load(file) +x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()] +y_test = [d['entity_id'] for _, d in test['data'].items()] +train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys()) +train_entities = [preprocess_text(element) for element in train_entities] + +def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + +batches = batch_list(train_entities, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls = np.concatenate(embedding_list) +# %% +gc.collect() +torch.cuda.empty_cache() + +# %% + +batches = batch_list(x_test, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls_test = np.concatenate(embedding_list) + + +# %% +knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels) +n_neighbors = [1, 3, 5, 10] + + +with open("results/output.txt", "w") as f: + for n in n_neighbors: + distances, indices = knn.kneighbors(cls_test, n_neighbors=n) + num = 0 + for a,b in zip(y_test, indices): + b = [labels[i] for i in b] + if a in b: + num += 1 + print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f) + print(np.min(distances), np.max(distances), file=f) + +# %% diff --git a/loss_comparisons_with_augmentations/classify_logits.py b/loss_comparisons_with_augmentations/classify_logits.py new file mode 100644 index 0000000..d09558b --- /dev/null +++ b/loss_comparisons_with_augmentations/classify_logits.py @@ -0,0 +1,258 @@ +# %% + +# from datasets import load_from_disk +import os +import glob + +os.environ['NCCL_P2P_DISABLE'] = '1' +os.environ['NCCL_IB_DISABLE'] = '1' +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + +import re +import torch +from torch.utils.data import DataLoader +import torch +import torch.nn as nn + +from transformers import ( + AutoTokenizer, + AutoModelForSequenceClassification, + AutoModel, + DataCollatorWithPadding, +) +import evaluate +import numpy as np +import pandas as pd +# import matplotlib.pyplot as plt +from datasets import Dataset, DatasetDict + +from tqdm import tqdm + +torch.set_float32_matmul_precision('high') + + +BATCH_SIZE = 256 + +# %% 
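+# The label mapping below is re-derived from train.csv via sorted(set(entity_ids)),
+# so it only lines up with the saved classifier head if the same CSV and the same
+# sort order were used when ./checkpoint/classification.pt was trained. A
+# hypothetical sanity check, left commented out, once target_id_list exists:
+#
+# ckpt = torch.load('./checkpoint/classification.pt', map_location='cpu')
+# assert ckpt['classifier.weight'].shape[0] == len(target_id_list)
+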
+# construct the target id list +# data_path = '../../../esAppMod_data_import/train.csv' +data_path = '../esAppMod_data_import/train.csv' +train_df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = train_df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + + +# %% +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + + +# introduce pre-processing functions +def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # Substitute digits with '#' + # text = re.sub(r'\d+', '#', text) + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + + +# outputs a list of dictionaries +# processes dataframe into lists of dictionaries +# each element maps input to output +# input: tag_description +# output: class label +def process_df_to_dict(df): + output_list = [] + for _, row in df.iterrows(): + desc = row['mention'] + desc = preprocess_text(desc) + index = row['entity_id'] + element = { + 'text' : desc, + 'label': label2id[index], # ensure labels starts from 0 + } + output_list.append(element) + + return output_list + + +def create_dataset(): + # train + data_path = '../esAppMod_data_import/test.csv' + test_df = pd.read_csv(data_path, skipinitialspace=True) + + + # combined_data = DatasetDict({ + # 'train': Dataset.from_list(process_df_to_dict(train_df)), + # }) + return Dataset.from_list(process_df_to_dict(test_df)) + + + +# %% + +def test(): + + test_dataset = create_dataset() + + # prepare tokenizer + + # MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + # MODEL_NAME = 'distilbert-base-cased' + MODEL_NAME = 'prajjwal1/bert-small' + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, return_tensors="pt", clean_up_tokenization_spaces=True) + # Define additional special tokens + # additional_special_tokens = ["", "", "", "", "", "", "", "", ""] + # Add the additional special tokens to the tokenizer + # tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) + + # %% + # compute max token length + max_length = 0 + for sample in test_dataset['text']: + # Tokenize the sample and get the length + input_ids = tokenizer(sample, truncation=False, add_special_tokens=True)["input_ids"] + length = len(input_ids) + + # Update max_length if this sample is longer + if length > max_length: + max_length = length + + print(max_length) + + # %% + + max_length = 128 + + # given a dataset entry, run it through the tokenizer + def preprocess_function(example): + input = example['text'] + # text_target sets the corresponding label to inputs + # there is no need to create a separate 'labels' + model_inputs = tokenizer( + input, + # max_length=max_length, + # padding='max_length' + ) + return model_inputs + + # map maps function to each "row" in the dataset + # aka the data in the immediate nesting + datasets = test_dataset.map( + preprocess_function, + batched=True, + num_proc=8, + remove_columns="text", + ) + + + datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label']) + + data_collator = DataCollatorWithPadding(tokenizer=tokenizer) + + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + bert_model = AutoModel.from_pretrained(MODEL_NAME) + + class BertForClassificationAndTriplet(nn.Module): + def __init__(self, bert_model, num_classes): + super().__init__() + self.bert = bert_model + 
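+ # The wrapper returns both the [CLS] embedding and the logits; this
+ # inference script only consumes the logits (argmax over entity classes)
+ # and ignores the embedding output.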
self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes) + + def forward(self, input_ids, attention_mask=None): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token + logits = self.classifier(cls_embeddings) + return cls_embeddings, logits + + model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id)) + state_dict = torch.load('./checkpoint/classification.pt') + model.load_state_dict(state_dict) + + model = model.eval() + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model.to(device) + + pred_labels = [] + actual_labels = [] + + + dataloader = DataLoader(datasets, batch_size=BATCH_SIZE, shuffle=False, collate_fn=data_collator) + for batch in tqdm(dataloader): + # Inference in batches + input_ids = batch['input_ids'] + attention_mask = batch['attention_mask'] + # save labels too + actual_labels.extend(batch['labels']) + + + # Move to GPU if available + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) + + # Perform inference + with torch.no_grad(): + cls, logits = model( + input_ids, + attention_mask) + predicted_class_ids = logits.argmax(dim=1).to("cpu") + pred_labels.extend(predicted_class_ids) + + pred_labels = [tensor.item() for tensor in pred_labels] + + + # %% + from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix + y_true = actual_labels + y_pred = pred_labels + + # Compute metrics + accuracy = accuracy_score(y_true, y_pred) + average_parameter = 'weighted' + zero_division_parameter = 0 + f1 = f1_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter) + precision = precision_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter) + recall = recall_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter) + + with open("results/output.txt", "a") as f: + + print('*' * 80, file=f) + # Print the results + print(f'Accuracy: {accuracy:.5f}', file=f) + print(f'F1 Score: {f1:.5f}', file=f) + print(f'Precision: {precision:.5f}', file=f) + print(f'Recall: {recall:.5f}', file=f) + + # export result + label_list = [id2label[id] for id in pred_labels] + df = pd.DataFrame({ + 'class_prediction': pd.Series(label_list) + }) + + # we can save the t5 generation output here + df.to_csv(f"results/classify.csv", index=False) + + + + + + +# %% +# reset file before writing to it +with open("results/output.txt", "w") as f: + print('', file=f) + test() diff --git a/loss_comparisons_with_augmentations/classify_train.py b/loss_comparisons_with_augmentations/classify_train.py new file mode 100644 index 0000000..de5a63d --- /dev/null +++ b/loss_comparisons_with_augmentations/classify_train.py @@ -0,0 +1,315 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F + +# %% +SHUFFLES=1 +AMPLIFY_FACTOR=1 +LEARNING_RATE=1e-5 + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to 
each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # add shuffled strings + processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + for desc in processed_descs: + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add corrupted strings + desc = corrupt_string(parent_desc, corruption_probability=0.01) + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add example with stripped non-alphanumerics + desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # short sequence amplifier + # short sequences are rare, and we must compensate by including more examples + # also, short sequence don't usually get affected by shuffle + words = parent_desc.split() + word_count = len(words) + if word_count <= 2: + for _ in range(AMPLIFY_FACTOR): + output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def make_entity_id_name(df): + entity_id_name = {} + entity_id_list = 
list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +bert_model = AutoModel.from_pretrained(MODEL_NAME) + +class BertForClassificationAndTriplet(nn.Module): + def __init__(self, bert_model, num_classes): + super().__init__() + self.bert = bert_model + self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes) + + def forward(self, input_ids, attention_mask=None): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token + logits = self.classifier(cls_embeddings) + return cls_embeddings, logits + +model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id)) + +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +model.to(DEVICE) +model.train() + +losses = [] + +def linear_decay(epoch, max_epochs, initial_lr, final_lr): + """ Calculate the linearly decayed learning rate. """ + return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr) + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + # lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6) + + # # Update optimizer's learning rate + # for param_group in optimizer.param_groups: + # param_group['lr'] = lr + + augmented_df = augment_data(df) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs.to(DEVICE) + cls, logits = model( + input_ids=inputs['input_ids'], + attention_mask=inputs['attention_mask'] + ) + # for training less than half the time, train on easy + labels = y + labels = [label2id[element] for element in labels] + labels = torch.tensor(labels).to(DEVICE) + # y = torch.tensor(y).to(DEVICE) + + + class_loss = F.cross_entropy(logits, labels) + + # if epoch < epochs / 2: + # triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # # for training after half the time, train on hard + # else: + # triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + + loss = class_loss # + triplet_loss + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + del x, y, cls, logits, loss + torch.cuda.empty_cache() + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {total_loss/batch_number}') + # print(f"Epoch {epoch+1}: lr={lr}") + if epoch % 5 == 0: + # torch.save(model.bert.state_dict(), 
'./checkpoint/classification.pt') + torch.save(model.state_dict(), './checkpoint/classification.pt') + + +# torch.save(model.bert.state_dict(), './checkpoint/classification.pt') +torch.save(model.state_dict(), './checkpoint/classification.pt') +# %% diff --git a/loss_comparisons_with_augmentations/esAppMod_train.py b/loss_comparisons_with_augmentations/esAppMod_train.py new file mode 100644 index 0000000..535a5c2 --- /dev/null +++ b/loss_comparisons_with_augmentations/esAppMod_train.py @@ -0,0 +1,277 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim + + +# %% +SHUFFLES=0 +AMPLIFY_FACTOR=0 +LEARNING_RATE=1e-5 + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # all augmentations disabled + # # add shuffled strings + # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + # for desc in processed_descs: + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add corrupted strings + # desc = corrupt_string(parent_desc, corruption_probability=0.01) + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add example with stripped non-alphanumerics + # desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # short sequence amplifier + # # short sequences are rare, and we must compensate by including more examples + # # also, short sequence don't usually get affected by shuffle + # words = parent_desc.split() + # word_count = len(words) + # if word_count <= 2: + # for _ in range(AMPLIFY_FACTOR): + # output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + 
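+# Rough usage sketch, kept commented out, of how generate_train_entity_sets and
+# batchGenerator (defined above) feed the triplet loop further below. With the
+# settings in this file, group_size = num_sample_per_class - 1 = 9 and
+# batch_size = 16:
+#
+# data = generate_train_entity_sets(train_entity_id_mentions,
+#                                   train_entity_id_name, 9, anchor=True)
+# x, y = next(batchGenerator(data, 16))
+#
+# x is a flat list of up to 16 groups * 10 strings (entity name + up to 9
+# mentions); y repeats each group's entity_id once per string, which is what
+# batch_all_triplet_loss / batch_hard_triplet_loss need to mine positives and
+# negatives within the batch.
+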
+def make_entity_id_name(df): + entity_id_name = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 +DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +model.to(DEVICE) +model.train() + +losses = [] + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + augmented_df = augment_data(df) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs.to(DEVICE) + outputs = model(**inputs) + cls = outputs.last_hidden_state[:,0,:] + # for training less than half the time, train on easy + y = torch.tensor(y).to(DEVICE) + if epoch < epochs / 2: + loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # for training after half the time, train on hard + else: + loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + del x, y, outputs, cls, loss + torch.cuda.empty_cache() + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {total_loss/batch_number}') + # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}") + if epoch % 5 == 0: + torch.save(model.state_dict(), './checkpoint/baseline.pt') + + +torch.save(model.state_dict(), './checkpoint/baseline.pt') +# %% diff --git a/loss_comparisons_with_augmentations/esAppMod_train_with_classification.py b/loss_comparisons_with_augmentations/esAppMod_train_with_classification.py new file mode 100644 index 0000000..97864ef --- /dev/null +++ b/loss_comparisons_with_augmentations/esAppMod_train_with_classification.py @@ -0,0 +1,315 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F + +# %% +SHUFFLES=0 +AMPLIFY_FACTOR=2 +LEARNING_RATE=1e-5 + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply 
treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # add shuffled strings + processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + for desc in processed_descs: + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add corrupted strings + desc = corrupt_string(parent_desc, corruption_probability=0.01) + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add example with stripped non-alphanumerics + desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # short sequence amplifier + # short sequences are rare, and we must compensate by including more examples + # also, short sequence don't usually get affected by shuffle + words = parent_desc.split() + word_count = len(words) + if word_count <= 2: + for _ in range(AMPLIFY_FACTOR): + output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def make_entity_id_name(df): + entity_id_name = {} + entity_id_list = 
list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +bert_model = AutoModel.from_pretrained(MODEL_NAME) + +class BertForClassificationAndTriplet(nn.Module): + def __init__(self, bert_model, num_classes): + super().__init__() + self.bert = bert_model + self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes) + + def forward(self, input_ids, attention_mask=None): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token + logits = self.classifier(cls_embeddings) + return cls_embeddings, logits + +model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id)) + +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +model.to(DEVICE) +model.train() + +losses = [] + +def linear_decay(epoch, max_epochs, initial_lr, final_lr): + """ Calculate the linearly decayed learning rate. """ + return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr) + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6) + + # Update optimizer's learning rate + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + augmented_df = augment_data(df) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs.to(DEVICE) + cls, logits = model( + input_ids=inputs['input_ids'], + attention_mask=inputs['attention_mask'] + ) + # for training less than half the time, train on easy + labels = y + labels = [label2id[element] for element in labels] + labels = torch.tensor(labels).to(DEVICE) + y = torch.tensor(y).to(DEVICE) + + + class_loss = F.cross_entropy(logits, labels) + + if epoch < epochs / 2: + triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # for training after half the time, train on hard + else: + triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + + loss = class_loss + triplet_loss + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + del x, y, cls, logits, loss + torch.cuda.empty_cache() + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {total_loss/batch_number}') + print(f"Epoch {epoch+1}: lr={lr}") + if epoch % 5 == 0: + # torch.save(model.bert.state_dict(), './checkpoint/classification.pt') + 
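+ # The checkpoint written below stores the full wrapper (bert.* and
+ # classifier.* keys) and overwrites the same file every 5 epochs;
+ # classify_infer.py strips the 'bert.' prefix and drops the classifier keys
+ # when it reloads this file into a bare AutoModel.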
torch.save(model.state_dict(), './checkpoint/classification.pt') + + +# torch.save(model.bert.state_dict(), './checkpoint/classification.pt') +torch.save(model.state_dict(), './checkpoint/classification.pt') +# %% diff --git a/loss_comparisons_with_augmentations/hybrid_infer.py b/loss_comparisons_with_augmentations/hybrid_infer.py new file mode 100644 index 0000000..94d69b5 --- /dev/null +++ b/loss_comparisons_with_augmentations/hybrid_infer.py @@ -0,0 +1,124 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import re +import gc + +# %% +# Step 2: Load the state dictionary +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' +# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) + +# state_dict = torch.load('./checkpoint/siamese.pt') +# state_dict = torch.load('./checkpoint/siamese_simple.pt') +state_dict = torch.load('./checkpoint/hybrid.pt') +params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name} + +# %% +# Step 3: Apply the state dictionary to the model +model.load_state_dict(params_dict) +model.to(DEVICE) +model.eval() + +# %% +def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + +# %% +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + + + +# %% +with open('../esAppMod/infer.json', 'r') as file: + test = json.load(file) +x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()] +y_test = [d['entity_id'] for _, d in test['data'].items()] +train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys()) +train_entities = [preprocess_text(element) for element in train_entities] + +def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + +batches = batch_list(train_entities, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls = np.concatenate(embedding_list) +# %% +gc.collect() +torch.cuda.empty_cache() + +# %% + +batches = batch_list(x_test, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, 
return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls_test = np.concatenate(embedding_list) + + +# %% +knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels) +n_neighbors = [1, 3, 5, 10] + + +with open("results/output.txt", "w") as f: + for n in n_neighbors: + distances, indices = knn.kneighbors(cls_test, n_neighbors=n) + num = 0 + for a,b in zip(y_test, indices): + b = [labels[i] for i in b] + if a in b: + num += 1 + print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f) + print(np.min(distances), np.max(distances), file=f) + +# %% diff --git a/loss_comparisons_with_augmentations/hybrid_train.py b/loss_comparisons_with_augmentations/hybrid_train.py new file mode 100644 index 0000000..9894232 --- /dev/null +++ b/loss_comparisons_with_augmentations/hybrid_train.py @@ -0,0 +1,433 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import ( + batch_all_triplet_loss, + batch_hard_triplet_loss, + batch_all_soft_margin_triplet_loss, + batch_hard_soft_margin_triplet_loss) +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F + +torch.set_float32_matmul_precision('high') + +def set_seed(seed): + """ + Set the random seed for reproducibility. + """ + random.seed(seed) # Python random module + np.random.seed(seed) # NumPy random + torch.manual_seed(seed) # PyTorch CPU + torch.cuda.manual_seed(seed) # PyTorch GPU + torch.cuda.manual_seed_all(seed) # If using multiple GPUs + torch.backends.cudnn.deterministic = True # Ensure deterministic behavior + torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility + +set_seed(42) + + +# %% +SHUFFLES=1 +AMPLIFY_FACTOR=1 +LEARNING_RATE=1e-4 +DEVICE = torch.device('cuda:2') if torch.cuda.is_available() else torch.device('cpu') + + +# %% +EVAL_FILE="top1_curves/hybrid_output.txt" +with open(EVAL_FILE, "w") as f: + pass + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in 
entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', '_', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # # add shuffled strings + # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + # for desc in processed_descs: + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # add corrupted strings + desc = corrupt_string(parent_desc, 
corruption_probability=0.01) + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add example with stripped non-alphanumerics + desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # short sequence amplifier + # short sequences are rare, and we must compensate by including more examples + # also, short sequence don't usually get affected by shuffle + words = parent_desc.split() + word_count = len(words) + if word_count <= 2: + for _ in range(AMPLIFY_FACTOR): + output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def make_entity_id_name(df): + entity_id_name = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + + +# evaluation +def run_evaluation(model, tokenizer): + def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) + all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + + with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) + train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} + train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + + with open('../esAppMod/infer.json', 'r') as file: + test = json.load(file) + x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()] + y_test = [d['entity_id'] for _, d in test['data'].items()] + train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys()) + train_entities = [preprocess_text(element) for element in train_entities] + + def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + + batches = batch_list(train_entities, 64) + + embedding_list = [] + for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + + cls = np.concatenate(embedding_list) + + batches = batch_list(x_test, 64) + + embedding_list = [] + for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + + cls_test = np.concatenate(embedding_list) + + knn = KNeighborsClassifier(n_neighbors=1, metric='euclidean').fit(cls, 
labels) + + + with open(EVAL_FILE, "a") as f: + # only compute top-1 + distances, indices = knn.kneighbors(cls_test, n_neighbors=1) + num = 0 + for a,b in zip(y_test, indices): + b = [labels[i] for i in b] + if a in b: + num += 1 + print(f'{num / len(y_test)}', file=f) + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 64 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +bert_model = AutoModel.from_pretrained(MODEL_NAME) + +class BertForClassificationAndTriplet(nn.Module): + def __init__(self, bert_model, num_classes): + super().__init__() + self.bert = bert_model + self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes) + + def forward(self, input_ids, attention_mask=None): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token + logits = self.classifier(cls_embeddings) + return cls_embeddings, logits + +model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id)) + +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +model.to(DEVICE) +model.train() + +losses = [] + +def linear_decay(epoch, max_epochs, initial_lr, final_lr): + """ Calculate the linearly decayed learning rate. """ + return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr) + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + total_cross = 0.0 + total_triplet = 0.0 + batch_number = 0 + + # lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6) + + # # Update optimizer's learning rate + # for param_group in optimizer.param_groups: + # param_group['lr'] = lr + if epoch % 10 == 0: + augmented_df = augment_data(df) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs.to(DEVICE) + cls, logits = model( + input_ids=inputs['input_ids'], + attention_mask=inputs['attention_mask'] + ) + # for training less than half the time, train on easy + labels = y + labels = [label2id[element] for element in labels] + labels = torch.tensor(labels).to(DEVICE) + y = torch.tensor(y).to(DEVICE) + + + class_loss = F.cross_entropy(logits, labels) + + if epoch < epochs / 2: + # triplet_loss, _ = batch_all_soft_margin_triplet_loss(y, cls, squared=False) + # loss = class_loss + triplet_loss + # loss,_ = batch_all_soft_margin_triplet_loss(y, cls, squared=False) + loss = class_loss + # for training after half the time, train on hard + # else: + # triplet_loss = batch_hard_soft_margin_triplet_loss(y, cls, squared=False) + # loss = triplet_loss + else: + loss = batch_hard_soft_margin_triplet_loss(y, cls, squared=False) + + + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + # total_cross += class_loss.detach().item() + # total_triplet += triplet_loss.detach().item() + batch_number += 1 + + # run evaluation on test data + model.eval() + with torch.no_grad(): + run_evaluation(model=model.bert, tokenizer=tokenizer) + + model.train() + + + # 
scheduler.step() # Update the learning rate + # print(f'epoch loss: {total_loss/batch_number}, cross loss: {total_cross/batch_number}, triplet loss: {total_triplet/batch_number}') + print(f'epoch loss: {total_loss/batch_number}') + # print(f"Epoch {epoch+1}: lr={lr}") + # if epoch % 5 == 0: + # # torch.save(model.bert.state_dict(), './checkpoint/classification.pt') + # torch.save(model.state_dict(), './checkpoint/hybrid.pt') + + +# torch.save(model.bert.state_dict(), './checkpoint/classification.pt') +# torch.save(model.state_dict(), './checkpoint/hybrid.pt') +# %% diff --git a/loss_comparisons_with_augmentations/loss.py b/loss_comparisons_with_augmentations/loss.py new file mode 100644 index 0000000..e226f48 --- /dev/null +++ b/loss_comparisons_with_augmentations/loss.py @@ -0,0 +1,288 @@ +# stardard functionalities for computing triplet loss, borrow code from +# https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py +import torch +import torch.nn.functional as F +def _pairwise_distances(embeddings, squared=False): + """Compute the 2D matrix of distances between all the embeddings. + Args: + embeddings: tensor of shape (batch_size, embed_dim) + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. + Returns: + pairwise_distances: tensor of shape (batch_size, batch_size) + """ + dot_product = torch.matmul(embeddings, embeddings.t()) + + # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`. + # This also provides more numerical stability (the diagonal of the result will be exactly 0). + # shape (batch_size,) + square_norm = torch.diag(dot_product) + + # Compute the pairwise distance matrix as we have: + # ||a - b||^2 = ||a||^2 - 2 + ||b||^2 + # shape (batch_size, batch_size) + distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1) + + # Because of computation errors, some distances might be negative so we put everything >= 0.0 + distances[distances < 0] = 0 + + if not squared: + # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal) + # we need to add a small epsilon where distances == 0.0 + mask = distances.eq(0).float() + distances = distances + mask * 1e-16 + + distances = (1.0 -mask) * torch.sqrt(distances) + + return distances + +# def _pairwise_distances(embeddings, squared=False): +# embeddings = F.normalize(embeddings, p=2, dim=1) +# dot_product = torch.matmul(embeddings, embeddings.t()) +# cosine_distance = 1 - dot_product +# return cosine_distance + + + +def _get_triplet_mask(labels): + """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid. 
+ A triplet (i, j, k) is valid if: + - i, j, k are distinct + - labels[i] == labels[j] and labels[i] != labels[k] + Args: + labels: tf.int32 `Tensor` with shape [batch_size] + """ + # Check that i, j and k are distinct + indices_equal = torch.eye(labels.size(0), device=labels.device).bool() + indices_not_equal = ~indices_equal + i_not_equal_j = indices_not_equal.unsqueeze(2) + i_not_equal_k = indices_not_equal.unsqueeze(1) + j_not_equal_k = indices_not_equal.unsqueeze(0) + + distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k + + + label_equal = labels.unsqueeze(0) == labels.unsqueeze(1) + i_equal_j = label_equal.unsqueeze(2) + i_equal_k = label_equal.unsqueeze(1) + + valid_labels = ~i_equal_k & i_equal_j + + return valid_labels & distinct_indices + + +def _get_anchor_positive_triplet_mask(labels): + """Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label. + Args: + labels: tf.int32 `Tensor` with shape [batch_size] + Returns: + mask: tf.bool `Tensor` with shape [batch_size, batch_size] + """ + # Check that i and j are distinct + indices_equal = torch.eye(labels.size(0), device=labels.device).bool() + indices_not_equal = ~indices_equal + + # Check if labels[i] == labels[j] + # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1) + labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1) + + return labels_equal & indices_not_equal + + +def _get_anchor_negative_triplet_mask(labels): + """Return a 2D mask where mask[a, n] is True iff a and n have distinct labels. + Args: + labels: tf.int32 `Tensor` with shape [batch_size] + Returns: + mask: tf.bool `Tensor` with shape [batch_size, batch_size] + """ + # Check if labels[i] != labels[k] + # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1) + + return ~(labels.unsqueeze(0) == labels.unsqueeze(1)) + + +# Cell +def batch_hard_triplet_loss(labels, embeddings, margin, squared=False): + """Build the triplet loss over a batch of embeddings. + For each anchor, we get the hardest positive and hardest negative to form a triplet. + Args: + labels: labels of the batch, of size (batch_size,) + embeddings: tensor of shape (batch_size, embed_dim) + margin: margin for triplet loss + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. 
+ Returns: + triplet_loss: scalar tensor containing the triplet loss + """ + # Get the pairwise distance matrix + pairwise_dist = _pairwise_distances(embeddings, squared=squared) + + # For each anchor, get the hardest positive + # First, we need to get a mask for every valid positive (they should have same label) + mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float() + + # We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p)) + anchor_positive_dist = mask_anchor_positive * pairwise_dist + + # shape (batch_size, 1) + hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True) + + # For each anchor, get the hardest negative + # First, we need to get a mask for every valid negative (they should have different labels) + mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float() + + # We add the maximum value in each row to the invalid negatives (label(a) == label(n)) + max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True) + anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative) + + # shape (batch_size,) + hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True) + + # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss + tl = hardest_positive_dist - hardest_negative_dist + margin + tl = F.relu(tl) + triplet_loss = tl.mean() + + return triplet_loss + +# Cell +def batch_all_triplet_loss(labels, embeddings, margin, squared=False): + """Build the triplet loss over a batch of embeddings. + We generate all the valid triplets and average the loss over the positive ones. + Args: + labels: labels of the batch, of size (batch_size,) + embeddings: tensor of shape (batch_size, embed_dim) + margin: margin for triplet loss + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. + Returns: + triplet_loss: scalar tensor containing the triplet loss + """ + # Get the pairwise distance matrix + pairwise_dist = _pairwise_distances(embeddings, squared=squared) + + anchor_positive_dist = pairwise_dist.unsqueeze(2) + anchor_negative_dist = pairwise_dist.unsqueeze(1) + + # Compute a 3D tensor of size (batch_size, batch_size, batch_size) + # triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k + # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1) + # and the 2nd (batch_size, 1, batch_size) + triplet_loss = anchor_positive_dist - anchor_negative_dist + margin + + + + # Put to zero the invalid triplets + # (where label(a) != label(p) or label(n) == label(a) or a == p) + mask = _get_triplet_mask(labels) + triplet_loss = mask.float() * triplet_loss + + # Remove negative losses (i.e. the easy triplets) + triplet_loss = F.relu(triplet_loss) + + # Count number of positive triplets (where triplet_loss > 0) + valid_triplets = triplet_loss[triplet_loss > 1e-16] + num_positive_triplets = valid_triplets.size(0) + num_valid_triplets = mask.sum() + + fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16) + + # Get final mean triplet loss over the positive valid triplets + triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16) + + return triplet_loss, fraction_positive_triplets + +def batch_all_soft_margin_triplet_loss(labels, embeddings, squared=False): + """Build the triplet loss over a batch of embeddings. 
+ We generate all the valid triplets and average the loss over the positive ones. + Args: + labels: labels of the batch, of size (batch_size,) + embeddings: tensor of shape (batch_size, embed_dim) + margin: margin for triplet loss + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. + Returns: + triplet_loss: scalar tensor containing the triplet loss + """ + # Get the pairwise distance matrix + pairwise_dist = _pairwise_distances(embeddings, squared=squared) + + anchor_positive_dist = pairwise_dist.unsqueeze(2) + anchor_negative_dist = pairwise_dist.unsqueeze(1) + + # Compute a 3D tensor of size (batch_size, batch_size, batch_size) + # triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k + # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1) + # and the 2nd (batch_size, 1, batch_size) + triplet_loss = anchor_positive_dist - anchor_negative_dist + + # Apply exponential and log + triplet_loss = torch.log(1 + torch.exp(triplet_loss)) + + # Put to zero the invalid triplets + # (where label(a) != label(p) or label(n) == label(a) or a == p) + mask = _get_triplet_mask(labels) + triplet_loss = mask.float() * triplet_loss + + # Remove negative losses (i.e. the easy triplets) + # triplet_loss = F.relu(triplet_loss) + + # Count number of positive triplets (where triplet_loss > 0) + valid_triplets = triplet_loss[triplet_loss > 1e-16] + num_positive_triplets = valid_triplets.size(0) + num_valid_triplets = mask.sum() + + fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16) + + # Get final mean triplet loss over the positive valid triplets + triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16) + + return triplet_loss, fraction_positive_triplets + + + +def batch_hard_soft_margin_triplet_loss(labels, embeddings, squared=False): + """Build the triplet loss over a batch of embeddings. + For each anchor, we get the hardest positive and hardest negative to form a triplet. + Args: + labels: labels of the batch, of size (batch_size,) + embeddings: tensor of shape (batch_size, embed_dim) + margin: margin for triplet loss + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. 
+ Returns: + triplet_loss: scalar tensor containing the triplet loss + """ + # Get the pairwise distance matrix + pairwise_dist = _pairwise_distances(embeddings, squared=squared) + + # For each anchor, get the hardest positive + # First, we need to get a mask for every valid positive (they should have same label) + mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float() + + # We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p)) + anchor_positive_dist = mask_anchor_positive * pairwise_dist + + # shape (batch_size, 1) + hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True) + + # For each anchor, get the hardest negative + # First, we need to get a mask for every valid negative (they should have different labels) + mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float() + + # We add the maximum value in each row to the invalid negatives (label(a) == label(n)) + max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True) + anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative) + + # shape (batch_size,) + hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True) + + # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss + tl = hardest_positive_dist - hardest_negative_dist + # Apply exponential and log + triplet_loss = torch.log(1 + torch.exp(tl)) + + triplet_loss = triplet_loss.mean() + + return triplet_loss diff --git a/loss_comparisons_without_augmentation/.gitignore b/loss_comparisons_without_augmentation/.gitignore new file mode 100644 index 0000000..423e37c --- /dev/null +++ b/loss_comparisons_without_augmentation/.gitignore @@ -0,0 +1,4 @@ +__pycache__ +checkpoint +results +top1_curves \ No newline at end of file diff --git a/loss_comparisons_without_augmentation/baseline_infer.py b/loss_comparisons_without_augmentation/baseline_infer.py new file mode 100644 index 0000000..f122c48 --- /dev/null +++ b/loss_comparisons_without_augmentation/baseline_infer.py @@ -0,0 +1,132 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import re +import gc + +# %% +# Step 2: Load the state dictionary +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' +# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) + +# state_dict = torch.load('./checkpoint/siamese.pt') +# state_dict = torch.load('./checkpoint/siamese_simple.pt') +# state_dict = torch.load('./checkpoint/classification.pt') +state_dict = torch.load('./checkpoint/baseline.pt') +# params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name} + +# %% +# Step 3: Apply the state dictionary to the model +model.load_state_dict(state_dict) +model.to(DEVICE) +model.eval() + +# %% +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + +# %% +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + + + +# %% +with open('../esAppMod/infer.json', 'r') as file: + test = json.load(file) +x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()] +y_test = [d['entity_id'] for _, d in test['data'].items()] +train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys()) +train_entities = [preprocess_text(element) for element in train_entities] + +def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + +batches = batch_list(train_entities, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls = np.concatenate(embedding_list) +# %% +gc.collect() +torch.cuda.empty_cache() + +# %% + +batches = batch_list(x_test, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls_test = np.concatenate(embedding_list) + + +# %% +knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels) +n_neighbors = [1, 3, 5, 10] + + +with open("results/output.txt", "w") as f: + for n in n_neighbors: + distances, indices = knn.kneighbors(cls_test, n_neighbors=n) + num = 0 + for a,b in zip(y_test, indices): + b = [labels[i] for i in b] + if a in b: + num += 1 + print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f) + print(np.min(distances), np.max(distances), file=f) + +with open("results/predictions.txt", "w") as f: + distances, indices = knn.kneighbors(cls_test, n_neighbors=1) + for a,b in zip(y_test, indices): + b = [labels[i] for i in b] + print(f'{a}, {b[0]}', file=f) + + +# %% diff --git a/loss_comparisons_without_augmentation/baseline_train.py b/loss_comparisons_without_augmentation/baseline_train.py new file mode 100644 index 0000000..3c94d00 --- /dev/null +++ b/loss_comparisons_without_augmentation/baseline_train.py @@ -0,0 +1,382 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim + +torch.set_float32_matmul_precision('high') + +def set_seed(seed): + """ + Set the random seed for 
reproducibility. + """ + random.seed(seed) # Python random module + np.random.seed(seed) # NumPy random + torch.manual_seed(seed) # PyTorch CPU + torch.cuda.manual_seed(seed) # PyTorch GPU + torch.cuda.manual_seed_all(seed) # If using multiple GPUs + torch.backends.cudnn.deterministic = True # Ensure deterministic behavior + torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility + +set_seed(42) + + +# %% +SHUFFLES=0 +AMPLIFY_FACTOR=0 +LEARNING_RATE=1e-5 +DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +# %% +EVAL_FILE="top1_curves/baseline_output.txt" +with open(EVAL_FILE, "w") as f: + pass + + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # all augmentations disabled + # # add shuffled strings + # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + # for desc in processed_descs: + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add corrupted strings + # desc = corrupt_string(parent_desc, corruption_probability=0.01) + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add example with stripped non-alphanumerics + # desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # short sequence amplifier + # # short sequences are rare, and we must compensate by including more examples + # # also, short sequence don't usually get affected by shuffle + # words = parent_desc.split() + # word_count = len(words) + # if word_count <= 2: + # for _ in range(AMPLIFY_FACTOR): + # output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + 
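+# For reference, a minimal sketch of what make_entity_id_mentions() above is
+# expected to return, using toy values rather than the real esAppMod data:
+#
+#   toy = pd.DataFrame({'entity_id':   [1, 1, 2],
+#                       'mention':     ['mysql db', 'my sql', 'nginx'],
+#                       'entity_name': ['MySQL', 'MySQL', 'NGINX']})
+#   make_entity_id_mentions(toy)
+#   # -> {1: ['mysql db', 'my sql'], 2: ['nginx']}
+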
+def make_entity_id_name(df): + entity_id_name = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + +# evaluation +def run_evaluation(model, tokenizer): + def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) + all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + + with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) + train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} + train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + + with open('../esAppMod/infer.json', 'r') as file: + test = json.load(file) + x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()] + y_test = [d['entity_id'] for _, d in test['data'].items()] + train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys()) + train_entities = [preprocess_text(element) for element in train_entities] + + def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + + batches = batch_list(train_entities, 64) + + embedding_list = [] + for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + + cls = np.concatenate(embedding_list) + + batches = batch_list(x_test, 64) + + embedding_list = [] + for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + + cls_test = np.concatenate(embedding_list) + + knn = KNeighborsClassifier(n_neighbors=1, metric='euclidean').fit(cls, labels) + + + with open(EVAL_FILE, "a") as f: + # only compute top-1 + distances, indices = knn.kneighbors(cls_test, n_neighbors=1) + num = 0 + for a,b in zip(y_test, indices): + b = [labels[i] for i in b] + if a in b: + num += 1 + print(f'{num / len(y_test)}', file=f) + + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +model.to(DEVICE) +model.train() + +losses = [] + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + augmented_df = augment_data(df) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + 
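+    # Note: in this no-augmentation variant every augmentation branch inside
+    # augment_data() is commented out, so the call above only lower-cases and
+    # re-spaces each mention; the training pool is rebuilt essentially unchanged
+    # each epoch before being regrouped into anchor/positive sets below.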
+ data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs.to(DEVICE) + outputs = model(**inputs) + cls = outputs.last_hidden_state[:,0,:] + # for training less than half the time, train on easy + y = torch.tensor(y).to(DEVICE) + if epoch < epochs / 2: + loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # for training after half the time, train on hard + else: + loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + # run evaluation on test data + model.eval() + with torch.no_grad(): + run_evaluation(model=model, tokenizer=tokenizer) + + model.train() + + + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {total_loss/batch_number}') + # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}") + if epoch == 175: + torch.save(model.state_dict(), './checkpoint/baseline.pt') + + +# torch.save(model.state_dict(), './checkpoint/baseline.pt') +# %% diff --git a/loss_comparisons_without_augmentation/classify_infer.py b/loss_comparisons_without_augmentation/classify_infer.py new file mode 100644 index 0000000..3109739 --- /dev/null +++ b/loss_comparisons_without_augmentation/classify_infer.py @@ -0,0 +1,124 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import re +import gc + +# %% +# Step 2: Load the state dictionary +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' +# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) + +# state_dict = torch.load('./checkpoint/siamese.pt') +# state_dict = torch.load('./checkpoint/siamese_simple.pt') +state_dict = torch.load('./checkpoint/classification.pt') +params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name} + +# %% +# Step 3: Apply the state dictionary to the model +model.load_state_dict(params_dict) +model.to(DEVICE) +model.eval() + +# %% +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + +# %% +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + + + +# %% +with open('../esAppMod/infer.json', 'r') as file: + test = json.load(file) +x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()] +y_test = [d['entity_id'] for _, d in test['data'].items()] +train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys()) +train_entities = [preprocess_text(element) for element in train_entities] + +def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + +batches = batch_list(train_entities, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls = np.concatenate(embedding_list) +# %% +gc.collect() +torch.cuda.empty_cache() + +# %% + +batches = batch_list(x_test, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls_test = np.concatenate(embedding_list) + + +# %% +knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels) +n_neighbors = [1, 3, 5, 10] + + +with open("results/output.txt", "w") as f: + for n in n_neighbors: + distances, indices = knn.kneighbors(cls_test, n_neighbors=n) + num = 0 + for a,b in zip(y_test, indices): + b = [labels[i] for i in b] + if a in b: + num += 1 + print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f) + print(np.min(distances), np.max(distances), file=f) + +# %% diff --git a/loss_comparisons_without_augmentation/classify_logits.py b/loss_comparisons_without_augmentation/classify_logits.py new file mode 100644 index 0000000..d09558b --- /dev/null +++ b/loss_comparisons_without_augmentation/classify_logits.py @@ -0,0 +1,258 @@ +# %% + +# from datasets import load_from_disk +import os +import glob + +os.environ['NCCL_P2P_DISABLE'] = '1' +os.environ['NCCL_IB_DISABLE'] = '1' +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + +import re +import torch +from torch.utils.data import DataLoader +import torch +import torch.nn as nn + +from transformers import ( + AutoTokenizer, + AutoModelForSequenceClassification, + AutoModel, + DataCollatorWithPadding, +) +import evaluate +import numpy as np +import pandas as pd +# import matplotlib.pyplot as plt +from datasets import Dataset, DatasetDict + +from tqdm import tqdm + +torch.set_float32_matmul_precision('high') + + +BATCH_SIZE = 256 + 
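+# The next cell maps the raw entity ids onto a contiguous 0..N-1 label space for
+# the classifier head. A minimal sketch with made-up entity ids (the real ids
+# come from train.csv):
+#
+#   target_id_list = [5, 12, 99]           # sorted unique entity_id values
+#   label2id       = {5: 0, 12: 1, 99: 2}  # entity_id -> training label
+#   id2label       = {0: 5, 1: 12, 2: 99}  # training label -> entity_id
+#
+# Predicted labels are mapped back to entity ids via id2label before being
+# written to results/classify.csv.
+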
+# %% +# construct the target id list +# data_path = '../../../esAppMod_data_import/train.csv' +data_path = '../esAppMod_data_import/train.csv' +train_df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = train_df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + + +# %% +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + + +# introduce pre-processing functions +def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # Substitute digits with '#' + # text = re.sub(r'\d+', '#', text) + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + + +# outputs a list of dictionaries +# processes dataframe into lists of dictionaries +# each element maps input to output +# input: tag_description +# output: class label +def process_df_to_dict(df): + output_list = [] + for _, row in df.iterrows(): + desc = row['mention'] + desc = preprocess_text(desc) + index = row['entity_id'] + element = { + 'text' : desc, + 'label': label2id[index], # ensure labels starts from 0 + } + output_list.append(element) + + return output_list + + +def create_dataset(): + # train + data_path = '../esAppMod_data_import/test.csv' + test_df = pd.read_csv(data_path, skipinitialspace=True) + + + # combined_data = DatasetDict({ + # 'train': Dataset.from_list(process_df_to_dict(train_df)), + # }) + return Dataset.from_list(process_df_to_dict(test_df)) + + + +# %% + +def test(): + + test_dataset = create_dataset() + + # prepare tokenizer + + # MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + # MODEL_NAME = 'distilbert-base-cased' + MODEL_NAME = 'prajjwal1/bert-small' + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, return_tensors="pt", clean_up_tokenization_spaces=True) + # Define additional special tokens + # additional_special_tokens = ["", "", "", "", "", "", "", "", ""] + # Add the additional special tokens to the tokenizer + # tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) + + # %% + # compute max token length + max_length = 0 + for sample in test_dataset['text']: + # Tokenize the sample and get the length + input_ids = tokenizer(sample, truncation=False, add_special_tokens=True)["input_ids"] + length = len(input_ids) + + # Update max_length if this sample is longer + if length > max_length: + max_length = length + + print(max_length) + + # %% + + max_length = 128 + + # given a dataset entry, run it through the tokenizer + def preprocess_function(example): + input = example['text'] + # text_target sets the corresponding label to inputs + # there is no need to create a separate 'labels' + model_inputs = tokenizer( + input, + # max_length=max_length, + # padding='max_length' + ) + return model_inputs + + # map maps function to each "row" in the dataset + # aka the data in the immediate nesting + datasets = test_dataset.map( + preprocess_function, + batched=True, + num_proc=8, + remove_columns="text", + ) + + + datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label']) + + data_collator = DataCollatorWithPadding(tokenizer=tokenizer) + + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + bert_model = AutoModel.from_pretrained(MODEL_NAME) + + class BertForClassificationAndTriplet(nn.Module): + def __init__(self, bert_model, num_classes): + super().__init__() + self.bert = bert_model 
+ self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes) + + def forward(self, input_ids, attention_mask=None): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token + logits = self.classifier(cls_embeddings) + return cls_embeddings, logits + + model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id)) + state_dict = torch.load('./checkpoint/classification.pt') + model.load_state_dict(state_dict) + + model = model.eval() + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model.to(device) + + pred_labels = [] + actual_labels = [] + + + dataloader = DataLoader(datasets, batch_size=BATCH_SIZE, shuffle=False, collate_fn=data_collator) + for batch in tqdm(dataloader): + # Inference in batches + input_ids = batch['input_ids'] + attention_mask = batch['attention_mask'] + # save labels too + actual_labels.extend(batch['labels']) + + + # Move to GPU if available + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) + + # Perform inference + with torch.no_grad(): + cls, logits = model( + input_ids, + attention_mask) + predicted_class_ids = logits.argmax(dim=1).to("cpu") + pred_labels.extend(predicted_class_ids) + + pred_labels = [tensor.item() for tensor in pred_labels] + + + # %% + from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix + y_true = actual_labels + y_pred = pred_labels + + # Compute metrics + accuracy = accuracy_score(y_true, y_pred) + average_parameter = 'weighted' + zero_division_parameter = 0 + f1 = f1_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter) + precision = precision_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter) + recall = recall_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter) + + with open("results/output.txt", "a") as f: + + print('*' * 80, file=f) + # Print the results + print(f'Accuracy: {accuracy:.5f}', file=f) + print(f'F1 Score: {f1:.5f}', file=f) + print(f'Precision: {precision:.5f}', file=f) + print(f'Recall: {recall:.5f}', file=f) + + # export result + label_list = [id2label[id] for id in pred_labels] + df = pd.DataFrame({ + 'class_prediction': pd.Series(label_list) + }) + + # we can save the t5 generation output here + df.to_csv(f"results/classify.csv", index=False) + + + + + + +# %% +# reset file before writing to it +with open("results/output.txt", "w") as f: + print('', file=f) + test() diff --git a/loss_comparisons_without_augmentation/classify_train.py b/loss_comparisons_without_augmentation/classify_train.py new file mode 100644 index 0000000..77ddb70 --- /dev/null +++ b/loss_comparisons_without_augmentation/classify_train.py @@ -0,0 +1,316 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F + +# %% +SHUFFLES=0 +AMPLIFY_FACTOR=0 +LEARNING_RATE=1e-5 + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity 
name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # disable augmentations + # # add shuffled strings + # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + # for desc in processed_descs: + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add corrupted strings + # desc = corrupt_string(parent_desc, corruption_probability=0.01) + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add example with stripped non-alphanumerics + # desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # short sequence amplifier + # # short sequences are rare, and we must compensate by including more examples + # # also, short sequence don't usually get affected by shuffle + # words = parent_desc.split() + # word_count = len(words) + # if word_count <= 2: + # for _ in range(AMPLIFY_FACTOR): + # output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def 
make_entity_id_name(df):
+    entity_id_name = {}
+    entity_id_list = list(set(df['entity_id']))
+    for entity_id in entity_id_list:
+        # every row of a given entity_id carries the same entity_name, so the first value suffices
+        entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
+    return entity_id_name
+
+
+# %%
+num_sample_per_class = 10 # samples in each group
+batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
+margin = 2
+epochs = 200
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
+MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+bert_model = AutoModel.from_pretrained(MODEL_NAME)
+
+class BertForClassificationAndTriplet(nn.Module):
+    def __init__(self, bert_model, num_classes):
+        super().__init__()
+        self.bert = bert_model
+        self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)
+
+    def forward(self, input_ids, attention_mask=None):
+        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+        cls_embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token
+        logits = self.classifier(cls_embeddings)
+        return cls_embeddings, logits
+
+model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
+# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
+
+model.to(DEVICE)
+model.train()
+
+losses = []
+
+def linear_decay(epoch, max_epochs, initial_lr, final_lr):
+    """Calculate the linearly decayed learning rate.
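+    Interpolates linearly from initial_lr at epoch 0 down to final_lr at max_epochs.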
""" + return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr) + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + # lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6) + + # # Update optimizer's learning rate + # for param_group in optimizer.param_groups: + # param_group['lr'] = lr + + augmented_df = augment_data(df) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs.to(DEVICE) + cls, logits = model( + input_ids=inputs['input_ids'], + attention_mask=inputs['attention_mask'] + ) + # for training less than half the time, train on easy + labels = y + labels = [label2id[element] for element in labels] + labels = torch.tensor(labels).to(DEVICE) + # y = torch.tensor(y).to(DEVICE) + + + class_loss = F.cross_entropy(logits, labels) + + # if epoch < epochs / 2: + # triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # # for training after half the time, train on hard + # else: + # triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + + loss = class_loss # + triplet_loss + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + del x, y, cls, logits, loss + torch.cuda.empty_cache() + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {total_loss/batch_number}') + # print(f"Epoch {epoch+1}: lr={lr}") + if epoch % 5 == 0: + # torch.save(model.bert.state_dict(), './checkpoint/classification.pt') + torch.save(model.state_dict(), './checkpoint/classification.pt') + + +# torch.save(model.bert.state_dict(), './checkpoint/classification.pt') +torch.save(model.state_dict(), './checkpoint/classification.pt') +# %% diff --git a/loss_comparisons_without_augmentation/esAppMod_train.py b/loss_comparisons_without_augmentation/esAppMod_train.py new file mode 100644 index 0000000..535a5c2 --- /dev/null +++ b/loss_comparisons_without_augmentation/esAppMod_train.py @@ -0,0 +1,277 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim + + +# %% +SHUFFLES=0 +AMPLIFY_FACTOR=0 +LEARNING_RATE=1e-5 + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), 
group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + 
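+        # normalize the raw mention (lowercase + collapse whitespace) before it is
+        # emitted as the single, un-augmented training example below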
parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # all augmentations disabled + # # add shuffled strings + # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + # for desc in processed_descs: + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add corrupted strings + # desc = corrupt_string(parent_desc, corruption_probability=0.01) + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add example with stripped non-alphanumerics + # desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # short sequence amplifier + # # short sequences are rare, and we must compensate by including more examples + # # also, short sequence don't usually get affected by shuffle + # words = parent_desc.split() + # word_count = len(words) + # if word_count <= 2: + # for _ in range(AMPLIFY_FACTOR): + # output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def make_entity_id_name(df): + entity_id_name = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 +DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +model.to(DEVICE) +model.train() + +losses = [] + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + augmented_df = augment_data(df) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs.to(DEVICE) + outputs = model(**inputs) + cls = outputs.last_hidden_state[:,0,:] + # for training less than half the time, train on easy + y = torch.tensor(y).to(DEVICE) + if epoch < epochs / 2: + loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # for training after half the time, train on hard + else: + loss = batch_hard_triplet_loss(y, cls, margin, 
squared=False) + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + del x, y, outputs, cls, loss + torch.cuda.empty_cache() + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {total_loss/batch_number}') + # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}") + if epoch % 5 == 0: + torch.save(model.state_dict(), './checkpoint/baseline.pt') + + +torch.save(model.state_dict(), './checkpoint/baseline.pt') +# %% diff --git a/loss_comparisons_without_augmentation/esAppMod_train_with_classification.py b/loss_comparisons_without_augmentation/esAppMod_train_with_classification.py new file mode 100644 index 0000000..97864ef --- /dev/null +++ b/loss_comparisons_without_augmentation/esAppMod_train_with_classification.py @@ -0,0 +1,315 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F + +# %% +SHUFFLES=0 +AMPLIFY_FACTOR=2 +LEARNING_RATE=1e-5 + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def 
preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # add shuffled strings + processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + for desc in processed_descs: + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add corrupted strings + desc = corrupt_string(parent_desc, corruption_probability=0.01) + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add example with stripped non-alphanumerics + desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # short sequence amplifier + # short sequences are rare, and we must compensate by including more examples + # also, short sequence don't usually get affected by shuffle + words = parent_desc.split() + word_count = len(words) + if word_count <= 2: + for _ in range(AMPLIFY_FACTOR): + output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def make_entity_id_name(df): + 
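+    # build a {entity_id: entity_name} lookup; each id maps to one canonical name
+    # in the csv, so taking the first matching row is enough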
entity_id_name = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +bert_model = AutoModel.from_pretrained(MODEL_NAME) + +class BertForClassificationAndTriplet(nn.Module): + def __init__(self, bert_model, num_classes): + super().__init__() + self.bert = bert_model + self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes) + + def forward(self, input_ids, attention_mask=None): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token + logits = self.classifier(cls_embeddings) + return cls_embeddings, logits + +model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id)) + +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +model.to(DEVICE) +model.train() + +losses = [] + +def linear_decay(epoch, max_epochs, initial_lr, final_lr): + """ Calculate the linearly decayed learning rate. """ + return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr) + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6) + + # Update optimizer's learning rate + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + augmented_df = augment_data(df) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs.to(DEVICE) + cls, logits = model( + input_ids=inputs['input_ids'], + attention_mask=inputs['attention_mask'] + ) + # for training less than half the time, train on easy + labels = y + labels = [label2id[element] for element in labels] + labels = torch.tensor(labels).to(DEVICE) + y = torch.tensor(y).to(DEVICE) + + + class_loss = F.cross_entropy(logits, labels) + + if epoch < epochs / 2: + triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # for training after half the time, train on hard + else: + triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + + loss = class_loss + triplet_loss + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + del x, y, cls, logits, loss + torch.cuda.empty_cache() + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {total_loss/batch_number}') + print(f"Epoch {epoch+1}: lr={lr}") + if epoch % 5 == 0: + # 
torch.save(model.bert.state_dict(), './checkpoint/classification.pt') + torch.save(model.state_dict(), './checkpoint/classification.pt') + + +# torch.save(model.bert.state_dict(), './checkpoint/classification.pt') +torch.save(model.state_dict(), './checkpoint/classification.pt') +# %% diff --git a/loss_comparisons_without_augmentation/hybrid_infer.py b/loss_comparisons_without_augmentation/hybrid_infer.py new file mode 100644 index 0000000..94d69b5 --- /dev/null +++ b/loss_comparisons_without_augmentation/hybrid_infer.py @@ -0,0 +1,124 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import re +import gc + +# %% +# Step 2: Load the state dictionary +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' +# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) + +# state_dict = torch.load('./checkpoint/siamese.pt') +# state_dict = torch.load('./checkpoint/siamese_simple.pt') +state_dict = torch.load('./checkpoint/hybrid.pt') +params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name} + +# %% +# Step 3: Apply the state dictionary to the model +model.load_state_dict(params_dict) +model.to(DEVICE) +model.eval() + +# %% +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + + +# %% +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + + + +# %% +with open('../esAppMod/infer.json', 'r') as file: + test = json.load(file) +x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()] +y_test = [d['entity_id'] for _, d in test['data'].items()] +train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys()) +train_entities = [preprocess_text(element) for element in train_entities] + +def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + +batches = batch_list(train_entities, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls = np.concatenate(embedding_list) +# %% +gc.collect() +torch.cuda.empty_cache() + +# %% + +batches = batch_list(x_test, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + output = outputs.last_hidden_state[:,0,:] + output = output.detach().cpu().numpy() + embedding_list.append(output) + +cls_test = np.concatenate(embedding_list) + + +# %% +knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels) +n_neighbors = [1, 3, 5, 10] + + +with open("results/output.txt", "w") as f: + for n in n_neighbors: + distances, indices = knn.kneighbors(cls_test, n_neighbors=n) + num = 0 + for a,b in zip(y_test, indices): + b = [labels[i] for i in b] + if a in b: + num += 1 + print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f) + print(np.min(distances), np.max(distances), file=f) + +# %% diff --git a/loss_comparisons_without_augmentation/hybrid_train.py b/loss_comparisons_without_augmentation/hybrid_train.py new file mode 100644 index 0000000..74e8aa2 --- /dev/null +++ b/loss_comparisons_without_augmentation/hybrid_train.py @@ -0,0 +1,315 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F + +# %% +SHUFFLES=0 +AMPLIFY_FACTOR=0 +LEARNING_RATE=1e-5 + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply 
treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # # add shuffled strings + # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + # for desc in processed_descs: + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add corrupted strings + # desc = corrupt_string(parent_desc, corruption_probability=0.01) + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # add example with stripped non-alphanumerics + # desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # # short sequence amplifier + # # short sequences are rare, and we must compensate by including more examples + # # also, short sequence don't usually get affected by shuffle + # words = parent_desc.split() + # word_count = len(words) + # if word_count <= 2: + # for _ in range(AMPLIFY_FACTOR): + # output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + + + +# %% +def make_entity_id_mentions(df): + entity_id_mentions = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() + return entity_id_mentions + +def make_entity_id_name(df): + 
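+    # the returned {entity_id: entity_name} map supplies the anchor string that
+    # generate_train_entity_sets prepends to each positive group when anchor=True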
entity_id_name = {} + entity_id_list = list(set(df['entity_id'])) + for entity_id in entity_id_list: + # entity_id always matches entity_name, so first value would work + entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] + return entity_id_name + + +# %% +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +bert_model = AutoModel.from_pretrained(MODEL_NAME) + +class BertForClassificationAndTriplet(nn.Module): + def __init__(self, bert_model, num_classes): + super().__init__() + self.bert = bert_model + self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes) + + def forward(self, input_ids, attention_mask=None): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token + logits = self.classifier(cls_embeddings) + return cls_embeddings, logits + +model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id)) + +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +model.to(DEVICE) +model.train() + +losses = [] + +def linear_decay(epoch, max_epochs, initial_lr, final_lr): + """ Calculate the linearly decayed learning rate. """ + return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr) + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + # lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6) + + # # Update optimizer's learning rate + # for param_group in optimizer.param_groups: + # param_group['lr'] = lr + + augmented_df = augment_data(df) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs.to(DEVICE) + cls, logits = model( + input_ids=inputs['input_ids'], + attention_mask=inputs['attention_mask'] + ) + # for training less than half the time, train on easy + labels = y + labels = [label2id[element] for element in labels] + labels = torch.tensor(labels).to(DEVICE) + y = torch.tensor(y).to(DEVICE) + + + class_loss = F.cross_entropy(logits, labels) + + if epoch < epochs / 2: + triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # for training after half the time, train on hard + else: + triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + + loss = class_loss + triplet_loss + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + del x, y, cls, logits, loss + torch.cuda.empty_cache() + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {total_loss/batch_number}') + # print(f"Epoch {epoch+1}: lr={lr}") + if epoch % 5 == 0: + # 
torch.save(model.bert.state_dict(), './checkpoint/classification.pt') + torch.save(model.state_dict(), './checkpoint/hybrid.pt') + + +# torch.save(model.bert.state_dict(), './checkpoint/classification.pt') +torch.save(model.state_dict(), './checkpoint/hybrid.pt') +# %% diff --git a/loss_comparisons_without_augmentation/loss.py b/loss_comparisons_without_augmentation/loss.py new file mode 100644 index 0000000..c584e3f --- /dev/null +++ b/loss_comparisons_without_augmentation/loss.py @@ -0,0 +1,193 @@ +# stardard functionalities for computing triplet loss, borrow code from +# https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py +import torch +import torch.nn.functional as F + + +def _pairwise_distances(embeddings, squared=False): + """Compute the 2D matrix of distances between all the embeddings. + Args: + embeddings: tensor of shape (batch_size, embed_dim) + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. + Returns: + pairwise_distances: tensor of shape (batch_size, batch_size) + """ + dot_product = torch.matmul(embeddings, embeddings.t()) + + # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`. + # This also provides more numerical stability (the diagonal of the result will be exactly 0). + # shape (batch_size,) + square_norm = torch.diag(dot_product) + + # Compute the pairwise distance matrix as we have: + # ||a - b||^2 = ||a||^2 - 2 + ||b||^2 + # shape (batch_size, batch_size) + distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1) + + # Apply a lower bound to distances to ensure they are non-negative and avoid tiny negative numbers due to computation errors + distances = torch.clamp(distances, min=0.0) + + if not squared: + # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal) + # we need to add a small epsilon where distances == 0.0 + epsilon = 1e-16 + mask = (distances < epsilon).float() + distances = distances + mask * epsilon + + distances = (1.0 - mask) * torch.sqrt(distances) + + + return distances + +def _get_triplet_mask(labels): + """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid. + A triplet (i, j, k) is valid if: + - i, j, k are distinct + - labels[i] == labels[j] and labels[i] != labels[k] + Args: + labels: tf.int32 `Tensor` with shape [batch_size] + """ + # Check that i, j and k are distinct + indices_equal = torch.eye(labels.size(0), device=labels.device).bool() + indices_not_equal = ~indices_equal + i_not_equal_j = indices_not_equal.unsqueeze(2) + i_not_equal_k = indices_not_equal.unsqueeze(1) + j_not_equal_k = indices_not_equal.unsqueeze(0) + + # ensures that none of the values use diagonal values (where at least 2 values are the same) + distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k + + + label_equal = labels.unsqueeze(0) == labels.unsqueeze(1) + i_equal_j = label_equal.unsqueeze(2) + i_equal_k = label_equal.unsqueeze(1) + + # valid triplets are (i,j) sharing same label and + # (i,k) having different labels + valid_labels = ~i_equal_k & i_equal_j + + return valid_labels & distinct_indices + + +def _get_anchor_positive_triplet_mask(labels): + """Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label. 
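+    Used by batch_hard_triplet_loss to restrict the hardest-positive search to
+    same-label pairs that are not the anchor itself.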
+ Args: + labels: tf.int32 `Tensor` with shape [batch_size] + Returns: + mask: tf.bool `Tensor` with shape [batch_size, batch_size] + """ + # Check that i and j are distinct + indices_equal = torch.eye(labels.size(0), device=labels.device).bool() + indices_not_equal = ~indices_equal + + # Check if labels[i] == labels[j] + # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1) + labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1) + + return labels_equal & indices_not_equal + + +def _get_anchor_negative_triplet_mask(labels): + """Return a 2D mask where mask[a, n] is True iff a and n have distinct labels. + Args: + labels: tf.int32 `Tensor` with shape [batch_size] + Returns: + mask: tf.bool `Tensor` with shape [batch_size, batch_size] + """ + # Check if labels[i] != labels[k] + # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1) + + return ~(labels.unsqueeze(0) == labels.unsqueeze(1)) + + +# Cell +def batch_hard_triplet_loss(labels, embeddings, margin, squared=False): + """Build the triplet loss over a batch of embeddings. + For each anchor, we get the hardest positive and hardest negative to form a triplet. + Args: + labels: labels of the batch, of size (batch_size,) + embeddings: tensor of shape (batch_size, embed_dim) + margin: margin for triplet loss + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. + Returns: + triplet_loss: scalar tensor containing the triplet loss + """ + # Get the pairwise distance matrix + pairwise_dist = _pairwise_distances(embeddings, squared=squared) + + # For each anchor, get the hardest positive + # First, we need to get a mask for every valid positive (they should have same label) + mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float() + + # We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p)) + anchor_positive_dist = mask_anchor_positive * pairwise_dist + + # shape (batch_size, 1) + hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True) + + # For each anchor, get the hardest negative + # First, we need to get a mask for every valid negative (they should have different labels) + mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float() + + # We add the maximum value in each row to the invalid negatives (label(a) == label(n)) + max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True) + anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative) + + # shape (batch_size,) + hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True) + + # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss + tl = hardest_positive_dist - hardest_negative_dist + margin + tl = F.relu(tl) + triplet_loss = tl.mean() + + return triplet_loss + +# Cell +def batch_all_triplet_loss(labels, embeddings, margin, squared=False): + """Build the triplet loss over a batch of embeddings. + We generate all the valid triplets and average the loss over the positive ones. + Args: + labels: labels of the batch, of size (batch_size,) + embeddings: tensor of shape (batch_size, embed_dim) + margin: margin for triplet loss + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. 
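+        Example (illustrative sketch only; the label values and embedding size are
+        arbitrary, and the call mirrors how the training scripts use this function):
+            labels = torch.tensor([0, 0, 1, 1])
+            embeddings = torch.randn(4, 512)
+            loss, frac_pos = batch_all_triplet_loss(labels, embeddings, margin=2.0, squared=False)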
+ Returns: + triplet_loss: scalar tensor containing the triplet loss + """ + # Get the pairwise distance matrix + pairwise_dist = _pairwise_distances(embeddings, squared=squared) + + anchor_positive_dist = pairwise_dist.unsqueeze(2) + anchor_negative_dist = pairwise_dist.unsqueeze(1) + + # Compute a 3D tensor of size (batch_size, batch_size, batch_size) + # triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k + # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1) + # and the 2nd (batch_size, 1, batch_size) + triplet_loss = anchor_positive_dist - anchor_negative_dist + margin + + + + # Put to zero the invalid triplets + # (where label(a) != label(p) or label(n) == label(a) or a == p) + mask = _get_triplet_mask(labels) + triplet_loss = mask.float() * triplet_loss + + # Remove negative losses (i.e. the easy triplets) + triplet_loss = F.relu(triplet_loss) + + # Count number of positive triplets (where triplet_loss > 0) + valid_triplets = triplet_loss[triplet_loss > 1e-16] + num_positive_triplets = valid_triplets.size(0) + num_valid_triplets = mask.sum() + + fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16) + + # Get final mean triplet loss over the positive valid triplets + triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16) + + return triplet_loss, fraction_positive_triplets \ No newline at end of file diff --git a/reference_code/character_bert_train.py b/reference_code/character_bert_train.py new file mode 100644 index 0000000..123393b --- /dev/null +++ b/reference_code/character_bert_train.py @@ -0,0 +1,460 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import BertTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm +import pandas as pd +import re +from torch.utils.data import Dataset, DataLoader +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F + +torch.set_float32_matmul_precision('high') + +def set_seed(seed): + """ + Set the random seed for reproducibility. 
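+    Note: forcing cudnn.deterministic=True and cudnn.benchmark=False trades some
+    GPU throughput for run-to-run reproducibility.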
+ """ + random.seed(seed) # Python random module + np.random.seed(seed) # NumPy random + torch.manual_seed(seed) # PyTorch CPU + torch.cuda.manual_seed(seed) # PyTorch GPU + torch.cuda.manual_seed_all(seed) # If using multiple GPUs + torch.backends.cudnn.deterministic = True # Ensure deterministic behavior + torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility + +set_seed(42) + + + + + +# %% +SHUFFLES=1 +AMPLIFY_FACTOR=1 +CORRUPT=0.1 +LEARNING_RATE=1e-5 +DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' +MODEL_NAME = 'helboukkouri/character-bert' + +# %% +with open("top1_curves/character_output.txt", "w") as f: + pass + + + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + # anchor = False, don't add entity name to each group, simply treat it as a normal mention + entity_sets = [] + if anchor: + for id, mentions in entity_id_mentions.items(): + random.shuffle(mentions) + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) + y.extend([t[1]]*len(t[0])) + yield x, y + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def generate_random_shuffles(text, n): + words = text.split() # Split the input into words + shuffled_variations = [] + + for _ in range(n): + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string + + return shuffled_variations + + +def shuffle_text(text, n_shuffles=SHUFFLES): + all_processed = [] + # add the original text + all_processed.append(text) + + # Generate random shuffles + shuffled_variations = generate_random_shuffles(text, n_shuffles) + all_processed.extend(shuffled_variations) + + return all_processed + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_string(sentence, corruption_probability=0.01): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < corruption_probability else word + for word in words + ] + return " ".join(corrupted_words) + + +def create_example(index, mention, entity_name): + return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} + +# augment whole dataset +def augment_data(df): + output_list = [] + + for idx,row in df.iterrows(): + index = row['entity_id'] + entity_name = row['entity_name'] + parent_desc = row['mention'] + parent_desc = preprocess_text(parent_desc) + + # add basic example + output_list.append(create_example(index, parent_desc, entity_name)) + + # # add shuffled strings + # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) + # for desc in processed_descs: + # if (desc != parent_desc): + # output_list.append(create_example(index, desc, entity_name)) + + # add corrupted strings + desc = corrupt_string(parent_desc, corruption_probability=CORRUPT) + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # add example with stripped non-alphanumerics + desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces + if (desc != parent_desc): + output_list.append(create_example(index, desc, entity_name)) + + # # short sequence amplifier + # # short sequences are rare, and we must compensate by including more examples + # # also, short sequence don't usually get affected by shuffle + # words = parent_desc.split() + # word_count = len(words) + # if word_count <= 2: + # for _ in range(AMPLIFY_FACTOR): + # output_list.append(create_example(index, desc, entity_name)) + + new_df = pd.DataFrame(output_list) + return new_df + +# def sample_from_df(df, sample_size_per_class=5): +# sampled_df = (df.groupby("entity_id")[['entity_id', 'mention', 'entity_name']] # explicit give column names +# .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x)))) +# .reset_index(drop=True)) +# +# return sampled_df + + + +# %% +def 
+# %%
+def make_entity_id_mentions(df):
+    entity_id_mentions = {}
+    entity_id_list = list(set(df['entity_id']))
+    for entity_id in entity_id_list:
+        entity_id_mentions[entity_id] = df[df['entity_id'] == entity_id]['mention'].to_list()
+    return entity_id_mentions
+
+def make_entity_id_name(df):
+    entity_id_name = {}
+    entity_id_list = list(set(df['entity_id']))
+    for entity_id in entity_id_list:
+        # each entity_id maps to a single entity_name, so the first value works
+        entity_id_name[entity_id] = df[df['entity_id'] == entity_id]['entity_name'].to_list()[0]
+    return entity_id_name
+
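+# %%
+# Added sanity check: the two helpers above rebuild the mention/name lookups from an
+# (augmented) DataFrame, mirroring the dictionaries loaded from the JSON files earlier.
+_aug_df = augment_data(df)
+_id2mentions = make_entity_id_mentions(_aug_df)
+_id2name = make_entity_id_name(_aug_df)
+print(f"{len(_aug_df)} augmented rows across {len(_id2mentions)} entity ids; "
+      f"example: {list(_id2name.items())[0]}")
+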
+# %%
+# evaluation
+def run_evaluation(model, tokenizer):
+    def preprocess_text(text):
+        # lowercase the text
+        text = text.lower()
+
+        # standardize spacing
+        text = re.sub(r'\s+', ' ', text).strip()
+
+        return text
+
+    with open('../esAppMod/tca_entities.json', 'r') as file:
+        eval_entities = json.load(file)
+    eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()}
+
+    with open('../esAppMod/train.json', 'r') as file:
+        eval_train = json.load(file)
+    eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()}
+    eval_train_entity_id_name = {data['entity_id']: eval_all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()}
+
+    with open('../esAppMod/infer.json', 'r') as file:
+        eval_test = json.load(file)
+    x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()]
+    y_test = [d['entity_id'] for _, d in eval_test['data'].items()]
+    eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys())
+    eval_train_entities = [preprocess_text(element) for element in eval_train_entities]
+
+    def batch_list(data, batch_size):
+        """Yield successive n-sized chunks from data."""
+        for i in range(0, len(data), batch_size):
+            yield data[i:i + batch_size]
+
+    # embed the canonical entity names
+    batches = batch_list(eval_train_entities, 64)
+
+    embedding_list = []
+    for batch in batches:
+        inputs = tokenizer(batch, padding=True, return_tensors='pt')
+        outputs = model(
+            input_ids=inputs['input_ids'].to(DEVICE),
+            attention_mask=inputs['attention_mask'].to(DEVICE)
+        )
+        output = outputs.last_hidden_state[:, 0, :]
+        output = output.detach().cpu().numpy()
+        embedding_list.append(output)
+
+    cls = np.concatenate(embedding_list)
+
+    # embed the test mentions
+    batches = batch_list(x_test, 64)
+
+    embedding_list = []
+    for batch in batches:
+        inputs = tokenizer(batch, padding=True, return_tensors='pt')
+        outputs = model(
+            input_ids=inputs['input_ids'].to(DEVICE),
+            attention_mask=inputs['attention_mask'].to(DEVICE)
+        )
+        output = outputs.last_hidden_state[:, 0, :]
+        output = output.detach().cpu().numpy()
+        embedding_list.append(output)
+
+    cls_test = np.concatenate(embedding_list)
+
+    knn = KNeighborsClassifier(n_neighbors=1, metric='euclidean').fit(cls, eval_labels)
+
+    # append top-1 accuracy to the curve file cleared at startup
+    with open("top1_curves/character_output.txt", "a") as f:
+        # only compute top-1
+        distances, indices = knn.kneighbors(cls_test, n_neighbors=1)
+        num = 0
+        for a, b in zip(y_test, indices):
+            b = [eval_labels[i] for i in b]
+            if a in b:
+                num += 1
+        print(f'{num / len(y_test)}', file=f)
+
+
+# %%
+class CharacterTransformer(nn.Module):
+    def __init__(self, num_chars, d_model=512, nhead=8, num_encoder_layers=6):
+        super(CharacterTransformer, self).__init__()
+        self.char_embedding = nn.Embedding(num_chars, d_model)
+        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
+        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)
+
+    def forward(self, input):
+        # input: (batch_size, seq_len)
+        embeddings = self.char_embedding(input)  # (batch_size, seq_len, d_model)
+        # embeddings = embeddings.permute(1, 0, 2)  # (seq_len, batch_size, d_model)
+        output = self.transformer_encoder(embeddings)
+        # output = output.permute(1, 0, 2)  # (batch_size, seq_len, d_model)
+        return output
+
+class ASCIITokenizer:
+    def __init__(self):
+        # initialize the tokenizer with ASCII characters (code points 0 to 127)
+        self.char_to_id = {chr(i): i for i in range(128)}
+        self.id_to_char = {i: chr(i) for i in range(128)}
+
+    def encode(self, text_list):
+        """Encode each text string into a list of ASCII IDs, dropping non-ASCII characters."""
+        output_list = []
+        for text in text_list:
+            output = [self.char_to_id[char] for char in text if char in self.char_to_id]
+            output_list.append(output)
+        return output_list
+
+    def decode(self, ids_list):
+        """Decode each list of ASCII IDs back into a text string."""
+        output_list = []
+        for ids in ids_list:
+            output = ''.join(self.id_to_char.get(id, '') for id in ids if id in self.id_to_char)
+            output_list.append(output)
+        return output_list
+
+# %%
+tokenizer = ASCIITokenizer()
+# example text
+text = ["Hello, world!", "Hello, world!"]
+# encode the text
+encoded = tokenizer.encode(text)
+print("Encoded:", encoded)
+
+# decode the encoded IDs
+decoded = tokenizer.decode(encoded)
+print("Decoded:", decoded)
+
+# %%
+# example usage
+model = CharacterTransformer(num_chars=128)  # assuming ASCII characters
+input = torch.randint(0, 128, (10, 50))  # example input tensor: 10 sequences of 50 characters
+output = model(input)
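+
+# %%
+# Added sketch: ASCIITokenizer.encode() returns ragged Python lists, while
+# CharacterTransformer expects a padded LongTensor, so a small collate step is
+# needed to use the two together. Right-padding with id 0 (ASCII NUL) is an
+# assumption made only for this illustration; the training loop below relies on
+# the Hugging Face tokenizer/model pair instead.
+from torch.nn.utils.rnn import pad_sequence
+
+def collate_ascii(text_list, char_tokenizer):
+    ids = [torch.tensor(seq, dtype=torch.long) for seq in char_tokenizer.encode(text_list)]
+    return pad_sequence(ids, batch_first=True, padding_value=0)  # (batch, max_seq_len)
+
+demo_batch = collate_ascii(["hello world", "db2"], tokenizer)
+print(demo_batch.shape, model(demo_batch).shape)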
+
+# %%
+num_sample_per_class = 10  # samples in each group
+batch_size = 64  # number of groups; effective batch size for the triplet loss = batch_size * num_sample_per_class
+margin = 2
+epochs = 200
+
+# the loop below tokenizes with Hugging Face semantics and reads outputs.last_hidden_state,
+# so it needs the pretrained encoder/tokenizer pair (the character classes above are standalone demos)
+model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+# tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
+optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
+# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
+# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.9, cooldown=5, verbose=True)
+
+model.to(DEVICE)
+model.train()
+
+losses = []
+
+for epoch in tqdm(range(epochs)):
+    total_loss = 0.0
+    batch_number = 0
+
+    # re-augment the training data every epoch (raise the modulus to refresh less often)
+    if epoch % 1 == 0:
+        augmented_df = augment_data(df)
+        # sampled_df = sample_from_df(augmented_df, sample_size_per_class=num_sample_per_class)
+        train_entity_id_mentions = make_entity_id_mentions(augmented_df)
+        train_entity_id_name = make_entity_id_name(augmented_df)
+
+    data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
+    random.shuffle(data)
+    for x, y in batchGenerator(data, batch_size):
+
+        # print(len(x), len(y), end='-->')
+        optimizer.zero_grad()
+
+        inputs = tokenizer(x, padding=True, return_tensors='pt')
+        inputs = inputs.to(DEVICE)
+        outputs = model(**inputs)
+        cls = outputs.last_hidden_state[:, 0, :]
+        y = torch.tensor(y).to(DEVICE)
+        # first half of training: batch-all triplets (easier)
+        if epoch < epochs / 2:
+            loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
+        # second half of training: batch-hard triplets
+        else:
+            loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
+        loss.backward()
+        optimizer.step()
+        total_loss += loss.detach().item()
+        batch_number += 1
+
+        # del x, y, outputs, cls, loss
+        # torch.cuda.empty_cache()
+
+    epoch_loss = total_loss / batch_number
+    # scheduler.step(epoch_loss)
+
+    # run evaluation on test data
+    model.eval()
+    with torch.no_grad():
+        run_evaluation(model=model, tokenizer=tokenizer)
+
+    model.train()
+
+    # scheduler.step()  # update the learning rate
+    print(f'epoch loss: {epoch_loss}')
+    # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
+    # if epoch == 125:
+    #     torch.save(model.state_dict(), './checkpoint/baseline.pt')
+
+
+# torch.save(model.state_dict(), './checkpoint/baseline.pt')
+# %%
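+
+# Added smoke test of the two triplet-loss variants used in the loop above, run on
+# random embeddings with random labels; it only checks that both functions run and
+# return scalar losses. The batch size and embedding dimension here are arbitrary.
+with torch.no_grad():
+    _emb = torch.randn(32, 768).to(DEVICE)
+    _lab = torch.randint(0, 4, (32,)).to(DEVICE)
+    _loss_all, _frac = batch_all_triplet_loss(_lab, _emb, margin, squared=False)
+    _loss_hard = batch_hard_triplet_loss(_lab, _emb, margin, squared=False)
+    print(f"batch-all: {float(_loss_all):.3f} (positive fraction: {float(_frac):.3f}), "
+          f"batch-hard: {float(_loss_hard):.3f}")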