def preprocess_text(text):
    """Normalize a mention string before it is tokenized.

    The text is lowercased, every run of whitespace is collapsed into a
    single space, and leading/trailing whitespace is removed.

    Args:
        text: raw input string.

    Returns:
        The normalized string.
    """
    lowered = text.lower()
    # Collapse tabs, newlines and repeated spaces into one space each.
    single_spaced = re.sub(r'\s+', ' ', lowered)
    return single_spaced.strip()
tokenizer + def preprocess_function(example): + input = example['text'] + # text_target sets the corresponding label to inputs + # there is no need to create a separate 'labels' + model_inputs = tokenizer( + input, + # max_length=max_length, + # padding='max_length' + ) + return model_inputs + + # map maps function to each "row" in the dataset + # aka the data in the immediate nesting + datasets = test_dataset.map( + preprocess_function, + batched=True, + num_proc=8, + remove_columns="text", + ) + + + datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label']) + + data_collator = DataCollatorWithPadding(tokenizer=tokenizer) + + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + bert_model = AutoModel.from_pretrained(MODEL_NAME) + + class BertForClassificationAndTriplet(nn.Module): + def __init__(self, bert_model, num_classes): + super().__init__() + self.bert = bert_model + self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes) + + def forward(self, input_ids, attention_mask=None): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token + logits = self.classifier(cls_embeddings) + return cls_embeddings, logits + + model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id)) + state_dict = torch.load('./checkpoint/classification.pt') + model.load_state_dict(state_dict) + + model = model.eval() + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model.to(device) + + pred_labels = [] + actual_labels = [] + + + dataloader = DataLoader(datasets, batch_size=BATCH_SIZE, shuffle=False, collate_fn=data_collator) + for batch in tqdm(dataloader): + # Inference in batches + input_ids = batch['input_ids'] + attention_mask = batch['attention_mask'] + # save labels too + actual_labels.extend(batch['labels']) + + + # Move to GPU if available + input_ids = input_ids.to(device) + attention_mask = 
attention_mask.to(device) + + # Perform inference + with torch.no_grad(): + cls, logits = model( + input_ids, + attention_mask) + predicted_class_ids = logits.argmax(dim=1).to("cpu") + pred_labels.extend(predicted_class_ids) + + pred_labels = [tensor.item() for tensor in pred_labels] + + + # %% + from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix + y_true = actual_labels + y_pred = pred_labels + + # Compute metrics + accuracy = accuracy_score(y_true, y_pred) + average_parameter = 'weighted' + zero_division_parameter = 0 + f1 = f1_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter) + precision = precision_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter) + recall = recall_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter) + + with open("results/output.txt", "a") as f: + + print('*' * 80, file=f) + # Print the results + print(f'Accuracy: {accuracy:.5f}', file=f) + print(f'F1 Score: {f1:.5f}', file=f) + print(f'Precision: {precision:.5f}', file=f) + print(f'Recall: {recall:.5f}', file=f) + + # export result + label_list = [id2label[id] for id in pred_labels] + df = pd.DataFrame({ + 'class_prediction': pd.Series(label_list) + }) + + # we can save the t5 generation output here + df.to_csv(f"results/classify.csv", index=False) + + + + + + +# %% +# reset file before writing to it +with open("results/output.txt", "w") as f: + print('', file=f) + test() diff --git a/cosines_with_augmentations/esAppMod_infer.py b/cosines_with_augmentations/esAppMod_infer.py new file mode 100644 index 0000000..0d8d60e --- /dev/null +++ b/cosines_with_augmentations/esAppMod_infer.py @@ -0,0 +1,124 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors 
import KNeighborsClassifier +from tqdm import tqdm +import re +import gc + +# %% +# Step 2: Load the state dictionary +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) + +# state_dict = torch.load('./checkpoint/siamese.pt') +# state_dict = torch.load('./checkpoint/siamese_simple.pt') +state_dict = torch.load('./checkpoint/classification.pt') +params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name} + +# %% +# Step 3: Apply the state dictionary to the model +model.load_state_dict(params_dict) +model.to(DEVICE) +model.eval() + +# %% +def preprocess_text(text): + # 1. 
def batch_list(data, batch_size):
    """Lazily split ``data`` into consecutive slices.

    Args:
        data: any sliceable sequence supporting ``len()``.
        batch_size: maximum number of items per slice.

    Yields:
        Slices of ``data`` in order; only the last one may be shorter than
        ``batch_size``.
    """
    starts = range(0, len(data), batch_size)
    yield from (data[lo:lo + batch_size] for lo in starts)
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    """Split every entity's mentions into shuffled groups for triplet training.

    Args:
        entity_id_mentions: mapping of entity id -> list of mention strings.
        entity_id_name: mapping of entity id -> canonical entity name.
        group_size: maximum number of pooled items per group. When ``anchor``
            is True the prepended entity name is not counted in this size.
        anchor: if True, each group is prefixed with the entity name. If
            False, the entity name is pooled with the mentions (duplicates
            removed) and treated as an ordinary mention.

    Returns:
        A list of ``(texts, entity_id)`` tuples.
    """
    entity_sets = []
    for entity_id, mentions in entity_id_mentions.items():
        name = entity_id_name[entity_id]
        if anchor:
            # Work on a copy so the caller's list is not reordered in place.
            pool = list(mentions)
        else:
            # Pool the name with the mentions; dict.fromkeys de-duplicates
            # while keeping a deterministic order ahead of the shuffle.
            pool = list(dict.fromkeys([name, *mentions]))
        random.shuffle(pool)

        # Chunk the shuffled pool itself (the original sliced the unshuffled
        # ``mentions`` in the non-anchor branch, dropping the entity name).
        for start in range(0, len(pool), group_size):
            chunk = pool[start:start + group_size]
            texts = [name] + chunk if anchor else chunk
            entity_sets.append((texts, entity_id))
    return entity_sets
def corrupt_word(word):
    """Return ``word`` with one random character-level typo applied.

    One of two corruptions is chosen at random: deleting a single character,
    or swapping two adjacent characters.

    Args:
        word: the token to corrupt.

    Returns:
        The corrupted token. Words of length 0 or 1 are returned unchanged.
    """
    if len(word) <= 1:  # nothing meaningful to delete or swap
        return word

    corruption_type = random.choice(["delete", "swap"])

    if corruption_type == "delete":
        # Drop the character at a random position.
        idx = random.randint(0, len(word) - 1)
        return word[:idx] + word[idx + 1:]

    # "swap": exchange the characters at idx and idx + 1. The guard above
    # guarantees at least two characters, so no further length check is needed.
    idx = random.randint(0, len(word) - 2)
    return word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]
def make_entity_id_mentions(df):
    """Group mention strings by entity id.

    Args:
        df: DataFrame with ``entity_id`` and ``mention`` columns.

    Returns:
        A dict mapping each entity id to the list of its mentions, in row
        order. An empty frame yields an empty dict.
    """
    entity_id_mentions = {}
    # Single pass over the rows instead of one boolean-mask scan of the whole
    # frame per distinct entity id.
    pairs = zip(df['entity_id'].to_list(), df['mention'].to_list())
    for entity_id, mention in pairs:
        entity_id_mentions.setdefault(entity_id, []).append(mention)
    return entity_id_mentions
+epochs = 200 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) +optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) +# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +model.to(DEVICE) +model.train() + +losses = [] + + + +for epoch in tqdm(range(epochs)): + total_loss = 0.0 + batch_number = 0 + + augmented_df = augment_data(df) + train_entity_id_mentions = make_entity_id_mentions(augmented_df) + train_entity_id_name = make_entity_id_name(augmented_df) + + data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + random.shuffle(data) + for x,y in batchGenerator(data, batch_size): + + # print(len(x), len(y), end='-->') + optimizer.zero_grad() + + inputs = tokenizer(x, padding=True, return_tensors='pt') + inputs.to(DEVICE) + outputs = model(**inputs) + cls = outputs.last_hidden_state[:,0,:] + # for training less than half the time, train on easy + y = torch.tensor(y).to(DEVICE) + if epoch < epochs / 2: + loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # for training after half the time, train on hard + else: + loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + del x, y, outputs, cls, loss + torch.cuda.empty_cache() + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {total_loss/batch_number}') + # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}") + if epoch % 5 == 0: + torch.save(model.state_dict(), './checkpoint/siamese_simple.pt') + + +torch.save(model.state_dict(), 
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    """Split every entity's mentions into shuffled groups for triplet training.

    Args:
        entity_id_mentions: mapping of entity id -> list of mention strings.
        entity_id_name: mapping of entity id -> canonical entity name.
        group_size: maximum number of pooled items per group. When ``anchor``
            is True the prepended entity name is not counted in this size.
        anchor: if True, each group is prefixed with the entity name. If
            False, the entity name is pooled with the mentions (duplicates
            removed) and treated as an ordinary mention.

    Returns:
        A list of ``(texts, entity_id)`` tuples.
    """
    entity_sets = []
    for entity_id, mentions in entity_id_mentions.items():
        name = entity_id_name[entity_id]
        if anchor:
            # Work on a copy so the caller's list is not reordered in place.
            pool = list(mentions)
        else:
            # Pool the name with the mentions; dict.fromkeys de-duplicates
            # while keeping a deterministic order ahead of the shuffle.
            pool = list(dict.fromkeys([name, *mentions]))
        random.shuffle(pool)

        # Chunk the shuffled pool itself (the original sliced the unshuffled
        # ``mentions`` in the non-anchor branch, dropping the entity name).
        for start in range(0, len(pool), group_size):
            chunk = pool[start:start + group_size]
            texts = [name] + chunk if anchor else chunk
            entity_sets.append((texts, entity_id))
    return entity_sets
{entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% +############### +# alternate data import strategy +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
def augment_data(df):
    """Build an augmented copy of the training frame.

    For every row the mention is normalized with ``preprocess_text`` and the
    following examples are emitted, each carrying the row's ``entity_id`` and
    ``entity_name``:

    * the normalized mention itself;
    * every variation from ``shuffle_text`` that differs from it;
    * one ``corrupt_string`` variant, if it differs;
    * a variant with non-word, non-space characters replaced by spaces, if it
      differs;
    * for mentions of at most two words, ``AMPLIFY_FACTOR`` extra copies of
      the normalized mention.

    Args:
        df: DataFrame with ``entity_id``, ``entity_name`` and ``mention``
            columns.

    Returns:
        A new DataFrame built from the generated example dicts.
    """
    output_list = []

    for _, row in df.iterrows():
        index = row['entity_id']
        entity_name = row['entity_name']
        parent_desc = preprocess_text(row['mention'])

        # Always keep the normalized original.
        output_list.append(create_example(index, parent_desc, entity_name))

        # Word-order shuffles; shuffle_text also returns the original text,
        # which the inequality check filters out.
        for shuffled in shuffle_text(parent_desc, n_shuffles=SHUFFLES):
            if shuffled != parent_desc:
                output_list.append(create_example(index, shuffled, entity_name))

        # Character-level typo variant.
        corrupted = corrupt_string(parent_desc, corruption_probability=0.01)
        if corrupted != parent_desc:
            output_list.append(create_example(index, corrupted, entity_name))

        # Variant with everything except word characters and whitespace
        # replaced by spaces.
        stripped = re.sub(r'[^\w\s]', ' ', parent_desc)
        if stripped != parent_desc:
            output_list.append(create_example(index, stripped, entity_name))

        # Short-sequence amplifier: short mentions are rare and barely change
        # under shuffling, so repeat the normalized mention itself. The
        # original repeated whatever `desc` last held, i.e. the
        # punctuation-stripped variant.
        if len(parent_desc.split()) <= 2:
            output_list.extend(
                create_example(index, parent_desc, entity_name)
                for _ in range(AMPLIFY_FACTOR)
            )

    return pd.DataFrame(output_list)
class BertForClassificationAndTriplet(nn.Module):
    """Encoder wrapper with a linear classification head.

    The wrapped encoder's hidden state at sequence position 0 (the [CLS]
    token, per the original comment) is used as the sentence embedding and is
    also fed to the classifier, so callers get both the embedding and the
    class logits from one forward pass.
    """

    def __init__(self, bert_model, num_classes):
        """Store the encoder and create the classification layer.

        Args:
            bert_model: encoder module exposing ``config.hidden_size`` and
                returning an object with ``last_hidden_state``.
            num_classes: number of output classes.
        """
        super().__init__()
        hidden_size = bert_model.config.hidden_size
        self.bert = bert_model
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch and classify it.

        Returns:
            A tuple ``(embeddings, logits)``: the position-0 hidden states and
            the classifier output computed from them.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        first_token = encoded.last_hidden_state[:, 0, :]
        return first_token, self.classifier(first_token)
torch.tensor(labels).to(DEVICE) + y = torch.tensor(y).to(DEVICE) + + + class_loss = F.cross_entropy(logits, labels) + + if epoch < epochs / 2: + triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False) + # for training after half the time, train on hard + else: + triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False) + + loss = class_loss + triplet_loss + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + batch_number += 1 + + del x, y, cls, logits, loss + torch.cuda.empty_cache() + + # scheduler.step() # Update the learning rate + print(f'epoch loss: {total_loss/batch_number}') + # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}") + if epoch % 5 == 0: + # torch.save(model.bert.state_dict(), './checkpoint/classification.pt') + torch.save(model.state_dict(), './checkpoint/classification.pt') + + +# torch.save(model.bert.state_dict(), './checkpoint/classification.pt') +torch.save(model.state_dict(), './checkpoint/classification.pt') +# %% diff --git a/cosines_with_augmentations/loss.py b/cosines_with_augmentations/loss.py new file mode 100644 index 0000000..581a1bb --- /dev/null +++ b/cosines_with_augmentations/loss.py @@ -0,0 +1,186 @@ +# stardard functionalities for computing triplet loss, borrow code from +# https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py +import torch +import torch.nn.functional as F +def _pairwise_distances(embeddings, squared=False): + """Compute the 2D matrix of distances between all the embeddings. + Args: + embeddings: tensor of shape (batch_size, embed_dim) + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. + Returns: + pairwise_distances: tensor of shape (batch_size, batch_size) + """ + dot_product = torch.matmul(embeddings, embeddings.t()) + + # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`. 
+ # This also provides more numerical stability (the diagonal of the result will be exactly 0). + # shape (batch_size,) + square_norm = torch.diag(dot_product) + + # Compute the pairwise distance matrix as we have: + # ||a - b||^2 = ||a||^2 - 2 + ||b||^2 + # shape (batch_size, batch_size) + distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1) + + # Because of computation errors, some distances might be negative so we put everything >= 0.0 + distances[distances < 0] = 0 + + if not squared: + # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal) + # we need to add a small epsilon where distances == 0.0 + mask = distances.eq(0).float() + distances = distances + mask * 1e-16 + + distances = (1.0 -mask) * torch.sqrt(distances) + + return distances + +def _get_triplet_mask(labels): + """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid. + A triplet (i, j, k) is valid if: + - i, j, k are distinct + - labels[i] == labels[j] and labels[i] != labels[k] + Args: + labels: tf.int32 `Tensor` with shape [batch_size] + """ + # Check that i, j and k are distinct + indices_equal = torch.eye(labels.size(0), device=labels.device).bool() + indices_not_equal = ~indices_equal + i_not_equal_j = indices_not_equal.unsqueeze(2) + i_not_equal_k = indices_not_equal.unsqueeze(1) + j_not_equal_k = indices_not_equal.unsqueeze(0) + + distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k + + + label_equal = labels.unsqueeze(0) == labels.unsqueeze(1) + i_equal_j = label_equal.unsqueeze(2) + i_equal_k = label_equal.unsqueeze(1) + + valid_labels = ~i_equal_k & i_equal_j + + return valid_labels & distinct_indices + + +def _get_anchor_positive_triplet_mask(labels): + """Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label. 
+ Args: + labels: tf.int32 `Tensor` with shape [batch_size] + Returns: + mask: tf.bool `Tensor` with shape [batch_size, batch_size] + """ + # Check that i and j are distinct + indices_equal = torch.eye(labels.size(0), device=labels.device).bool() + indices_not_equal = ~indices_equal + + # Check if labels[i] == labels[j] + # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1) + labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1) + + return labels_equal & indices_not_equal + + +def _get_anchor_negative_triplet_mask(labels): + """Return a 2D mask where mask[a, n] is True iff a and n have distinct labels. + Args: + labels: tf.int32 `Tensor` with shape [batch_size] + Returns: + mask: tf.bool `Tensor` with shape [batch_size, batch_size] + """ + # Check if labels[i] != labels[k] + # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1) + + return ~(labels.unsqueeze(0) == labels.unsqueeze(1)) + + +# Cell +def batch_hard_triplet_loss(labels, embeddings, margin, squared=False): + """Build the triplet loss over a batch of embeddings. + For each anchor, we get the hardest positive and hardest negative to form a triplet. + Args: + labels: labels of the batch, of size (batch_size,) + embeddings: tensor of shape (batch_size, embed_dim) + margin: margin for triplet loss + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. 
+ Returns: + triplet_loss: scalar tensor containing the triplet loss + """ + # Get the pairwise distance matrix + pairwise_dist = _pairwise_distances(embeddings, squared=squared) + + # For each anchor, get the hardest positive + # First, we need to get a mask for every valid positive (they should have same label) + mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float() + + # We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p)) + anchor_positive_dist = mask_anchor_positive * pairwise_dist + + # shape (batch_size, 1) + hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True) + + # For each anchor, get the hardest negative + # First, we need to get a mask for every valid negative (they should have different labels) + mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float() + + # We add the maximum value in each row to the invalid negatives (label(a) == label(n)) + max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True) + anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative) + + # shape (batch_size,) + hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True) + + # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss + tl = hardest_positive_dist - hardest_negative_dist + margin + tl = F.relu(tl) + triplet_loss = tl.mean() + + return triplet_loss + +# Cell +def batch_all_triplet_loss(labels, embeddings, margin, squared=False): + """Build the triplet loss over a batch of embeddings. + We generate all the valid triplets and average the loss over the positive ones. + Args: + labels: labels of the batch, of size (batch_size,) + embeddings: tensor of shape (batch_size, embed_dim) + margin: margin for triplet loss + squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. + If false, output is the pairwise euclidean distance matrix. 
+ Returns: + triplet_loss: scalar tensor containing the triplet loss + """ + # Get the pairwise distance matrix + pairwise_dist = _pairwise_distances(embeddings, squared=squared) + + anchor_positive_dist = pairwise_dist.unsqueeze(2) + anchor_negative_dist = pairwise_dist.unsqueeze(1) + + # Compute a 3D tensor of size (batch_size, batch_size, batch_size) + # triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k + # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1) + # and the 2nd (batch_size, 1, batch_size) + triplet_loss = anchor_positive_dist - anchor_negative_dist + margin + + + + # Put to zero the invalid triplets + # (where label(a) != label(p) or label(n) == label(a) or a == p) + mask = _get_triplet_mask(labels) + triplet_loss = mask.float() * triplet_loss + + # Remove negative losses (i.e. the easy triplets) + triplet_loss = F.relu(triplet_loss) + + # Count number of positive triplets (where triplet_loss > 0) + valid_triplets = triplet_loss[triplet_loss > 1e-16] + num_positive_triplets = valid_triplets.size(0) + num_valid_triplets = mask.sum() + + fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16) + + # Get final mean triplet loss over the positive valid triplets + triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16) + + return triplet_loss, fraction_positive_triplets \ No newline at end of file diff --git a/cosines_with_augmentations/results/classify.csv b/cosines_with_augmentations/results/classify.csv new file mode 100644 index 0000000..ff97a0c --- /dev/null +++ b/cosines_with_augmentations/results/classify.csv @@ -0,0 +1,2440 @@ +class_prediction +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +497 +394 +394 +301 +355 +485 +300 +486 +383 +299 +299 +498 +498 +592 +592 +438 +592 +592 +592 +592 +592 
+592 +592 +592 +592 +497 +4 +4 +4 +418 +418 +5 +498 +81 +557 +626 +418 +7 +8 +259 +259 +259 +259 +259 +259 +259 +259 +46 +259 +259 +259 +259 +259 +259 +576 +259 +259 +9 +375 +516 +11 +12 +12 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +260 +376 +296 +498 +498 +557 +261 +14 +299 +299 +306 +301 +320 +600 +600 +600 +600 +600 +600 +600 +600 +600 +600 +600 +600 +600 +600 +600 +585 +600 +600 +600 +600 +600 +600 +600 +536 +600 +600 +600 +600 +600 +600 +320 +487 +600 +600 +600 +600 +600 +600 +17 +377 +517 +20 +20 +20 +589 +302 +504 +428 +21 +307 +582 +582 +583 +306 +306 +306 +306 +306 +306 +22 +23 +23 +111 +24 +25 +307 +307 +307 +307 +99 +111 +420 +29 +30 +30 +30 +552 +296 +30 +30 +522 +563 +563 +30 +273 +563 +563 +500 +563 +111 +563 +563 +563 +32 +32 +580 +30 +166 +309 +594 +594 +36 +36 +36 +445 +37 +37 +37 +311 +37 +93 +99 +93 +99 +104 +93 +93 +99 +99 +296 +41 +42 +55 +312 +312 +312 +312 +312 +520 +43 +43 +43 +43 +43 +43 +43 +520 +43 +43 +43 +43 +593 +43 +43 +43 +43 +43 +43 +43 +43 +43 +43 +43 +43 +43 +43 +383 +43 +43 +43 +43 +43 +503 +157 +44 +157 +157 +45 +383 +383 +456 +585 +626 +517 +48 +383 +49 +49 +49 +315 +316 +316 +596 +596 +593 +319 +320 +320 +320 +320 +99 +422 +51 +111 +52 +322 +593 +263 +263 +55 +55 +583 +449 +449 +449 +449 +449 +59 +293 +449 +449 +174 +449 +57 +522 +522 +327 +327 +327 +327 +62 +62 +457 +593 +445 +64 +102 +64 +445 +536 +520 +300 +99 +51 +51 +328 +328 +265 +425 +265 +265 +285 +285 +522 +522 +522 +285 +43 +285 +285 +522 +445 +103 +438 +424 +592 +330 +330 +67 +67 +68 +68 +68 +368 +68 +585 +572 +604 +70 +70 +445 +102 +617 +609 +73 +609 +445 +525 +103 +424 +351 +424 +351 +103 +93 +605 +605 +605 +93 +605 +605 +301 +605 +604 +604 +121 +43 +76 +443 +43 +593 +657 +520 +576 +520 +581 +593 +580 +443 +442 +604 +604 +604 +43 +520 +604 +43 +520 +370 +562 +604 +504 +463 +463 +300 +251 +285 +443 +593 +81 +81 +248 +81 +81 +248 +81 +248 +81 +248 +248 +81 
+81 +81 +81 +248 +81 +316 +285 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +497 +609 +609 +107 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +609 +107 +609 +609 +489 +489 +99 +617 +355 +84 +85 +85 +106 +106 +530 +87 +593 +593 +530 +593 +530 +593 +593 +530 +88 +296 +296 +296 +296 +296 +84 +566 +84 +584 +584 +600 +584 +584 +584 +390 +589 +590 +584 +584 +584 +584 +584 +333 +333 +482 +333 +482 +482 +482 +482 +378 +378 +383 +327 +327 +381 +382 +383 +383 +383 +383 +323 +55 +593 +388 +388 +388 +388 +388 +388 +388 +388 +584 +388 +388 +388 +1 +333 +333 +334 +390 +334 +334 +390 +390 +391 +333 +335 +335 +335 +335 +589 +583 +394 +333 +396 +396 +397 +397 +397 +397 +397 +398 +398 +398 +402 +402 +403 +390 +507 +589 +589 +589 +406 +406 +443 +408 +409 +111 +409 +411 +411 +411 +413 +413 +413 +589 +520 +268 +268 +268 +536 +268 +268 +268 +268 +268 +617 +268 +268 +268 +268 +268 +268 +268 +268 +492 +338 +334 +334 +4 +4 +111 +92 +576 +576 +576 +576 +576 +576 +576 +576 +576 +576 +576 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +427 +428 +576 +428 +428 +576 +429 +429 +429 +429 +429 +430 +430 +430 +431 +431 +576 +432 +432 +453 +593 +593 +576 +593 +593 +432 +432 +432 +432 +593 +432 +134 +437 +296 +432 +432 +593 +432 +432 +593 +576 +432 +434 +576 +432 +433 +248 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +268 +434 +434 +268 +434 +434 +434 +434 +434 +434 +434 +434 +434 +576 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +268 +434 +434 +434 +434 +434 +434 +268 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +268 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +268 +434 +434 +434 +434 +434 +434 +434 +434 +268 +434 +434 +268 +434 +434 +434 +434 +434 +434 +434 
+434 +434 +434 +268 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +268 +434 +434 +576 +434 +434 +268 +434 +434 +434 +576 +434 +434 +434 +434 +434 +268 +434 +434 +434 +434 +434 +576 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +268 +434 +434 +434 +434 +434 +434 +268 +434 +434 +434 +434 +434 +268 +268 +268 +434 +434 +434 +434 +434 +268 +434 +434 +268 +434 +434 +434 +434 +434 +268 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +434 +268 +434 +434 +434 +434 +268 +434 +268 +434 +434 +435 +435 +431 +435 +435 +437 +435 +576 +437 +435 +431 +435 +435 +437 +435 +435 +296 +435 +431 +431 +431 +435 +435 +435 +435 +435 +435 +435 +576 +431 +437 +435 +437 +435 +435 +443 +435 +576 +435 +436 +436 +436 +436 +436 +436 +436 +436 +436 +436 +436 +93 +93 +93 +375 +438 +584 +95 +95 +383 +97 +98 +99 +99 +99 +99 +100 +101 +102 +102 +39 +102 +102 +102 +103 +103 +103 +103 +104 +105 +121 +106 +107 +107 +107 +108 +110 +603 +603 +603 +576 +603 +603 +603 +589 +110 +111 +111 +111 +112 +390 +107 +497 +576 +497 +117 +596 +114 +114 +522 +593 +406 +300 +115 +116 +621 +116 +621 +117 +117 +117 +117 +117 +118 +119 +296 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +43 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +296 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +525 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +581 +593 +121 +121 +466 +104 +581 +581 +581 +581 +581 +577 +581 +581 +581 +556 +520 +520 +581 +581 +470 
+470 +581 +470 +581 +470 +470 +111 +111 +471 +111 +111 +581 +581 +473 +581 +581 +581 +581 +581 +383 +383 +581 +496 +441 +593 +443 +441 +441 +441 +593 +441 +122 +122 +122 +122 +122 +122 +122 +122 +122 +123 +123 +497 +568 +263 +597 +273 +111 +561 +576 +576 +576 +274 +274 +388 +587 +125 +507 +507 +507 +507 +507 +507 +507 +507 +507 +507 +507 +507 +507 +507 +507 +507 +507 +343 +343 +344 +344 +344 +128 +449 +442 +442 +442 +442 +84 +368 +131 +134 +107 +134 +134 +333 +134 +134 +134 +535 +541 +134 +134 +594 +134 +600 +134 +343 +88 +343 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +445 +593 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +594 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +580 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +134 +593 +134 +134 +134 +134 +134 +335 +475 +516 +383 +383 +383 +296 +134 +134 +593 +139 +134 +134 +141 +443 +443 +443 +143 +143 +602 +143 +574 +134 +134 +134 +134 +134 +300 +503 +276 +495 +149 +416 +579 +536 +151 +134 +152 +585 +348 +348 +417 +356 +153 +586 +154 +55 +155 +351 +351 +351 +352 +134 +352 +134 +352 +352 +352 +352 +352 +352 +504 +157 +157 +157 +157 +158 +158 +158 +158 +158 +159 +353 +353 +594 +161 +161 +161 +161 +589 +161 +161 +163 +587 +587 +587 +248 +165 +562 +167 +168 +169 +278 +356 +171 +122 +121 +174 +496 +431 +174 +174 +477 +437 +445 +593 +279 +174 +174 +174 +178 +178 +178 +179 +360 +361 +589 +180 +603 +306 +603 +603 +603 +581 +603 +296 +182 +182 +623 +184 +368 +186 +301 +581 +572 +593 +448 +281 +281 +281 +448 +190 +190 +190 +190 +190 +190 +178 +579 +190 +191 +418 +443 +117 +593 +593 +199 +200 +593 +202 +203 +107 +205 +206 +593 +208 +209 +210 +211 +212 +214 +214 +215 +215 +215 +217 +217 +111 +84 +481 +218 +219 +99 +222 +366 +572 +366 +366 +84 +223 +223 +224 +224 +390 +390 +260 +445 +445 +445 +445 +445 +445 +445 
+445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +445 +447 +447 +447 +448 +448 +448 +448 +134 +448 +438 +448 +593 +425 +448 +593 +448 +448 +448 +134 +593 +448 +448 +448 +438 +438 +448 +448 +448 +448 +134 +448 +449 +449 +449 +449 +449 +449 +449 +449 +449 +449 +449 +593 +449 +449 +449 +449 +449 +449 +449 +449 +449 +449 +449 +368 +369 +369 +574 +593 +425 +228 +370 +370 +370 +370 +370 +370 +370 +370 +370 +370 +111 +568 +568 +568 +568 +568 +568 +568 +568 +568 +568 +568 +568 +568 +425 +111 +104 +568 +231 +316 +321 +665 +600 +579 +596 +390 +589 +284 +285 +285 +285 +285 +285 +285 +285 +285 +285 +285 +285 +603 +284 +285 +285 +285 +285 +285 +285 +285 +285 +285 +285 +285 +285 +285 +285 +285 +284 +285 +285 +285 +285 +285 +285 +285 +285 +316 +284 +431 +285 +590 +445 +285 +445 +285 +285 +285 +431 +285 +486 +285 +285 +285 +285 +285 +285 +285 +285 +285 +237 +580 +580 +580 +580 +580 +580 +580 +580 +580 +580 +609 +268 +239 +580 +452 +580 +580 +580 +580 +580 +580 +580 +580 +580 +580 +560 +580 +580 +580 +580 +580 +580 +580 +580 +580 +580 +443 +580 +580 +580 +580 +580 +580 +580 +580 +580 +580 +580 +580 +580 +242 +580 +452 +580 +580 +452 +452 +452 +580 +452 +580 +452 +452 +452 +452 +99 +452 +452 +452 +452 +355 +452 +306 +452 +452 +452 +452 +580 +452 +452 +580 +580 +452 +452 +452 +580 +580 +580 +452 +452 +580 +452 +580 +452 +452 +452 +580 +452 +452 +452 +452 +452 +452 +452 +452 +334 +580 +580 +452 +580 +452 +452 +452 +452 +452 +452 +452 +580 +452 +580 +580 +452 +452 +580 +452 +452 +580 +452 +580 +580 +452 +452 +580 +452 +580 +452 +452 +452 +580 +452 +452 +452 +580 +580 +452 +580 +452 +452 +452 +452 +452 +99 +580 +99 +452 +452 +452 +452 +452 +593 +452 +452 +452 +580 +580 +452 +452 +580 +452 +452 +452 +452 +452 +452 +452 +452 +452 +452 +580 +452 +452 +452 +43 +453 +452 +452 +452 +580 +306 +452 +452 
+452 +452 +452 +580 +583 +452 +580 +43 +452 +452 +452 +452 +242 +452 +580 +580 +452 +452 +580 +452 +452 +242 +452 +580 +452 +580 +452 +580 +580 +452 +306 +452 +452 +580 +580 +452 +452 +580 +452 +580 +452 +596 +316 +103 +452 +452 +452 +452 +580 +452 +580 +452 +452 +580 +452 +580 +522 +580 +306 +242 +452 +580 +452 +355 +452 +580 +452 +522 +580 +452 +296 +580 +580 +580 +580 +452 +452 +452 +452 +452 +452 +452 +452 +452 +452 +452 +452 +580 +107 +452 +452 +452 +452 +452 +580 +452 +452 +452 +580 +580 +580 +452 +452 +580 +580 +580 +452 +452 +580 +452 +452 +580 +452 +452 +452 +452 +452 +452 +452 +452 +583 +452 +452 +452 +242 +452 +452 +452 +355 +580 +452 +452 +452 +580 +452 +580 +452 +580 +296 +452 +452 +452 +580 +452 +580 +452 +452 +580 +452 +580 +580 +580 +452 +242 +452 +580 +580 +580 +580 +580 +452 +452 +452 +452 +522 +580 +452 +452 +452 +452 +452 +580 +452 +452 +580 +580 +580 +580 +452 +580 +452 +452 +580 +452 +580 +452 +452 +452 +452 +452 +452 +580 +452 +580 +580 +580 +452 +580 +452 +452 +452 +580 +452 +580 +580 +452 +452 +452 +242 +242 +242 +452 +242 +580 +452 +452 +452 +452 +452 +452 +452 +452 +580 +452 +452 +452 +452 +452 +452 +12 +241 +242 +243 +167 +111 +247 +247 +247 +248 +425 +134 +289 +250 +250 +250 +597 +251 +252 +253 +254 +255 +256 +601 +511 +445 +512 +442 +659 +515 +303 +516 +661 +517 +517 +517 +517 +520 +355 +520 +520 +520 +521 +596 +522 +522 +596 +523 +498 +524 +525 +525 +526 +84 +529 +512 +530 +530 +531 +547 +572 +572 +574 +538 +43 +276 +383 +540 +547 +545 +547 +547 +553 +609 +557 +557 +558 +558 +562 +591 +591 +597 +593 +593 +593 +370 +111 +306 +442 +597 +598 +572 +579 +579 +579 +579 +22 +504 +661 +516 +43 +520 +520 +667 +520 +675 +675 +675 +134 +134 +675 +301 +43 +84 +694 +694 +512 diff --git a/cosines_with_augmentations/results/output.txt b/cosines_with_augmentations/results/output.txt new file mode 100644 index 0000000..fdc60dc --- /dev/null +++ b/cosines_with_augmentations/results/output.txt @@ -0,0 +1,5 @@ +Top-1 accuracy: 0.8257482574825749 +Top-3 
accuracy: 0.9106191061910619 +Top-5 accuracy: 0.9261992619926199 +Top-10 accuracy: 0.942189421894219 +0.0 0.82121104 diff --git a/reference_code/dataload_with_augmentation.py b/reference_code/dataload_with_augmentation.py new file mode 100644 index 0000000..0eeadc5 --- /dev/null +++ b/reference_code/dataload_with_augmentation.py @@ -0,0 +1,227 @@ +# this code performs dataloading with text augmentation + +# %% +from torch.utils.data import Dataset, DataLoader +import pandas as pd +import torch +from transformers import ( + AutoTokenizer, +) +from functools import partial +import re +import random + +# %% +# PARAMETERS +SAMPLES=5 + +# %% +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def shuffle_text(text, prob=0.2): + if random.random() < prob: + words = text.split() # Split the input into words + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_text = " ".join(shuffled) # Join the words back into a string + else: + shuffled_text = text + + return shuffled_text + + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_text(sentence, prob=0.05): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < prob else word + for word in words + ] + return " ".join(corrupted_words) + +def strip_nonalphanumerics(desc, prob=0.5): + desc = re.sub(r'[^\w\s]', ' ', desc) # Retains only alphanumeric and spaces + return desc + + +# %% +def augment(row): + """ + function to augment "mention" string input + returns the string input with slight variation + """ + desc = row['mention'] + # we always apply preprocess + desc = preprocess_text(desc) + + desc = shuffle_text(desc, prob=0.5) + desc = corrupt_text(desc, prob=0.5) + desc = strip_nonalphanumerics(desc, prob=0.5) + + return desc + + + + +# %% +# custom dataset +# we want to sample n samples from each class +# sample_size refers to the number of samples 
per class +def sample_from_df(df, sample_size_per_class=5): + sampled_df = (df.groupby( "training_id")[['training_id', 'mention']] # explicit give column names + .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x)))) + .reset_index(drop=True)) + + return sampled_df + + +# %% +class DynamicDataset(Dataset): + def __init__(self, df, sample_size_per_class): + """ + Args: + df (pd.DataFrame): Original DataFrame with class (id) and data columns. + sample_size_per_class (int): Number of samples to draw per class for each epoch. + """ + self.df = df + self.sample_size_per_class = sample_size_per_class + self.current_data = None + self.regenerate_data() # Generate the initial dataset + + def regenerate_data(self): + """ + Generate a new sampled dataset for the current epoch. + + dynamic callback function to regenerate data each time we call this + method, it updates the current_data we can: + + - re-sample the dataframe for a new set of n_samples + - generate fresh augmentations this effectively + + This allows us to re-sample and re-augment at the start of each epoch + """ + # Sample `sample_size_per_class` rows per class + sampled_df = sample_from_df(self.df, self.sample_size_per_class) + + # Store the tokenized data with labels + self.current_data = sampled_df + + def __len__(self): + return len(self.current_data) + + def __getitem__(self, idx): + # do the transform here + row = self.current_data.iloc[idx].to_dict() + + # perform text augmentation here + # independent function calls might introduce changes + mention_0 = augment(row) + mention_1 = augment(row) + return { + 'training_id': row['training_id'], + 'mention_0': mention_0, + 'mention_1': mention_1, + } + + +# %% +dataset = DynamicDataset(df, sample_size_per_class=SAMPLES) +dataset[0] + + +# %% +def custom_collate_fn(batch, tokenizer): + # batch is just a list of dictionaries + label_list = [item['training_id'] for item in batch] + mention_0_list = [item['mention_0'] for item in batch] + mention_1_list 
= [item['mention_1'] for item in batch] + + # we can do the tokenization here + tokenized_batch_0 = tokenizer( + mention_0_list, + truncation=True, + padding=True, + return_tensors='pt' + ) + + tokenized_batch_1 = tokenizer( + mention_1_list, + truncation=True, + padding=True, + return_tensors='pt' + ) + + + label_list = torch.tensor(label_list) + + return { + 'input_ids_0': tokenized_batch_0['input_ids'], + 'attention_mask_0': tokenized_batch_0['attention_mask'], + 'input_ids_1': tokenized_batch_1['input_ids'], + 'attention_mask_1': tokenized_batch_1['attention_mask'], + 'labels': label_list, + } + +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", clean_up_tokenization_spaces=False) +custom_collate_fn_with_tokenizer = partial(custom_collate_fn, tokenizer=tokenizer) +dataloader = DataLoader( + dataset, + batch_size=8, + collate_fn=custom_collate_fn_with_tokenizer +) + + +# %% +next(iter(dataloader)) +# %% diff --git a/tackle_container/biomedical_results/output.txt b/tackle_container/biomedical_results/output.txt new file mode 100644 index 0000000..ccd8363 --- /dev/null +++ b/tackle_container/biomedical_results/output.txt @@ -0,0 +1,12 @@ +Top-1 accuracy: 0.7107843137254902 +Top-3 accuracy: 0.7990196078431373 +Top-5 accuracy: 0.8137254901960784 +Top-10 accuracy: 0.8529411764705882 +Top-1 accuracy: 0.6862745098039216 +Top-3 accuracy: 0.7696078431372549 +Top-5 accuracy: 0.7990196078431373 +Top-10 accuracy: 0.8480392156862745 +Top-1 accuracy: 0.696078431372549 +Top-3 accuracy: 0.7892156862745098 +Top-5 accuracy: 0.8088235294117647 +Top-10 accuracy: 0.8382352941176471 diff --git a/tackle_container/biomedical_train.py b/tackle_container/biomedical_train.py index 9893625..5f04e31 100644 --- a/tackle_container/biomedical_train.py +++ b/tackle_container/biomedical_train.py @@ -14,11 +14,10 @@ from transformers import AutoTokenizer, AutoModel from data import generate_train_entity_sets -from tqdm import tqdm ### need to use ipywidgets==7.7.1 the newest version 
doesn't work +from tqdm import tqdm from loss import batch_all_triplet_loss, batch_hard_triplet_loss from sklearn.neighbors import KNeighborsClassifier import numpy as np -import logging def setup(rank, world_size): @@ -73,13 +72,10 @@ def train(rank, epoch, epochs, train_dataloader, model, optimizer, tokenizer, ma loss.backward() optimizer.step() - # logging.info(f'{epoch} {len(x)} {loss.item()}') epoch_loss.append(loss.item()) epoch_len.append(len(x)) # del inputs, cls, loss # torch.cuda.empty_cache() - logging.info(f'{DEVICE}{epoch_len}') - logging.info(f'{DEVICE}{epoch_loss}') def check_label(predicted_cui: str, golden_cui: str) -> int: """ @@ -127,8 +123,7 @@ def eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, m # print(np.min(distances), np.max(distances)) def save_checkpoint(model, res, epoch, dataName): - logging.info(f'Saving model {epoch} {res} ') - torch.save(model.state_dict(), './checkpoints/'+dataName+'.pt') + torch.save(model.state_dict(), './checkpoint/' + dataName + '.pt') class Model(nn.Module): def __init__(self,MODEL_NAME): @@ -146,12 +141,12 @@ def main(rank, world_size, config): setup(rank, world_size) dataName = config['DEFAULT']['dataName'] - logging.basicConfig(format='%(asctime)s %(message)s', filename=config['train']['ckt_path']+dataName+'.log', filemode='a', level=logging.INFO) vocab = defaultdict(set) - with open('./data/biomedical/'+dataName+'/'+config['train']['dictionary']) as f: + with open('../biomedical/' + dataName + '/' + config['train']['dictionary']) as f: for line in f: - vocab[line.strip().split('||')[0]].add(line.strip().split('||')[1].lower()) + line_list = line.strip().split('||') + vocab[line_list[0]].add(line_list[1].lower()) cui_to_id, id_to_cui = {}, {} vocab_entity_id_mentions = {} @@ -167,7 +162,7 @@ def main(rank, world_size, config): vocab_ids.extend([id]*len(mentions)) test_mentions, test_cuis = [], [] - with 
open('./data/biomedical/'+dataName+'/'+config['train']['test_set']+'/0.concept') as f: + with open('../biomedical/'+dataName+'/'+config['train']['test_set']+'/0.concept') as f: for line in f: test_cuis.append(line.strip().split('||')[-1]) test_mentions.append(line.strip().split('||')[-2].lower()) @@ -188,8 +183,7 @@ def main(rank, world_size, config): ddp_model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True) best = 0 - if rank == 0: - logging.info(f'epochs:{epochs} group_size:{num_sample_per_class} batch_size:{batch_size} %num:1 device:{torch.cuda.get_device_name()} count:{torch.cuda.device_count()} base:{MODEL_NAME}' ) + best_res = [] for epoch in tqdm(range(epochs)): trainDataLoader.sampler.set_epoch(epoch) @@ -197,11 +191,18 @@ def main(rank, world_size, config): # if rank == 0 and epoch % 2 == 0: if rank == 0: res = eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, ddp_model.module, tokenizer) - logging.info(f'{epoch} {res}') if res[0] > best: best = res[0] + best_res = res save_checkpoint(ddp_model.module, res, epoch, dataName) + with open("biomedical_results/output.txt", "a") as f: + print('new best ----', file=f) + for idx,n in enumerate([1,3,5,10]): + print(f'Top-{n:<3} accuracy: {best_res[idx]}', file=f) + dist.barrier() + + cleanup() if __name__ == '__main__': diff --git a/tackle_container/config.ini b/tackle_container/config.ini new file mode 100644 index 0000000..118deb3 --- /dev/null +++ b/tackle_container/config.ini @@ -0,0 +1,27 @@ +[DEFAULT] +dataName = ncbi +# dataName = bc5cdr-disease +# dataName = bc5cdr-chemical +# dataName = bc2gm + +[train] +epochs = 2 +batch_size = 40 +lr = 1e-5 + +dictionary = test_dictionary.txt +# or processed_dev_oneFile Note: bc2gm has no dev split +test_set = processed_test_refined +ckt_path = ./checkpoint/ + + +[model] +margin = 2 +model_name = dmis-lab/biobert-v1.1 + +[eval] +batch_size = 200 + +[data] +group_size = 4 +shuffle = True \ No newline at end of file 
diff --git a/tackle_container/understanding_batch_generation.py b/tackle_container/understanding_batch_generation.py new file mode 100644 index 0000000..78bc1c7 --- /dev/null +++ b/tackle_container/understanding_batch_generation.py @@ -0,0 +1,84 @@ +# %% +import torch +import json +import random +import numpy as np +from transformers import AutoTokenizer +from transformers import AutoModel +from loss import batch_all_triplet_loss, batch_hard_triplet_loss +from sklearn.neighbors import KNeighborsClassifier +from tqdm import tqdm + + +# %% +def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): + # split entity mentions into groups + entity_sets = [] + # anchor means to use the entity_name as an anchor + if anchor: + # entity_id_mentions is a list of dicts, each dict is an id and a list of mentions + for id, mentions in entity_id_mentions.items(): + # shuffle the mentions + random.shuffle(mentions) + # make batches of at most group size - 10 + positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] + # the first element is always the entity_name + # this is why we use (group_size - 1) + anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] + entity_sets.extend(anchor_positive) + else: + # in this case, there is no entity_name in each group + for id, mentions in entity_id_mentions.items(): + group = list(set([entity_id_name[id]] + mentions)) + random.shuffle(group) + positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] + entity_sets.extend(positives) + return entity_sets + +# the batch generator selects batch_size entries from the "data" +# but actually the "data" is a list of 10 items or less +def batchGenerator(data, batch_size): + for i in range(0, len(data), batch_size): + batch = data[i:i+batch_size] + x, y = [], [] + for t in batch: + x.extend(t[0]) # this is the list of mentions + y.extend([t[1]]*len(t[0])) # this multiplies a single label by 
the number of mentions + yield x, y + +# %% + + +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + +# %% + +num_sample_per_class = 10 # samples in each group +batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class +margin = 2 +epochs = 200 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModel.from_pretrained(MODEL_NAME) +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5) + +model.to(DEVICE) +model.train() + +losses = [] + +data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) + +# %% +random.shuffle(data) + \ No newline at end of file diff --git a/vicreg/dataload.py b/vicreg/dataload.py new file mode 100644 index 0000000..e5c285f --- /dev/null +++ b/vicreg/dataload.py @@ -0,0 +1,227 @@ +# this code performs dataloading with text augmentation + +# %% +from torch.utils.data import Dataset, DataLoader +import pandas as pd +import torch +from transformers import ( + AutoTokenizer, +) +from functools import partial +import re +import random + +# %% +# PARAMETERS +SAMPLES=5 + +# %% +################################################### +# import code +# import training file +data_path = 
'../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + +# %% +############################################################## +# augmentation code + +# basic preprocessing +def preprocess_text(text): + # 1. Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def shuffle_text(text, prob=0.2): + if random.random() < prob: + words = text.split() # Split the input into words + shuffled = words[:] # Copy the word list to avoid in-place modification + random.shuffle(shuffled) # Randomly shuffle the words + shuffled_text = " ".join(shuffled) # Join the words back into a string + else: + shuffled_text = text + + return shuffled_text + + +def corrupt_word(word): + """Corrupt a single word using random corruption techniques.""" + if len(word) <= 1: # Skip corruption for single-character words + return word + + corruption_type = random.choice(["delete", "swap"]) + + if corruption_type == "delete": + # Randomly delete a character + idx = random.randint(0, len(word) - 1) + word = word[:idx] + word[idx + 1:] + + elif corruption_type == "swap": + # Swap two adjacent characters + if len(word) > 1: + idx = random.randint(0, len(word) - 2) + word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) + + + return word + +def corrupt_text(sentence, prob=0.05): + """Corrupt each word in the string with a given probability.""" + words = sentence.split() + corrupted_words = [ + corrupt_word(word) if random.random() < prob else word + for word in words + ] + return " ".join(corrupted_words) + +def strip_nonalphanumerics(desc, prob=0.5): + desc = 
re.sub(r'[^\w\s]', ' ', desc) # Retains only alphanumeric and spaces + return desc + + +# %% +def augment(row): + """ + function to augment "mention" string input + returns the string input with slight variation + """ + desc = row['mention'] + # we always apply preprocess + desc = preprocess_text(desc) + + desc = shuffle_text(desc, prob=1.0) + desc = corrupt_text(desc, prob=1.0) + desc = strip_nonalphanumerics(desc, prob=0.5) + + return desc + + + + +# %% +# custom dataset +# we want to sample n samples from each class +# sample_size refers to the number of samples per class +def sample_from_df(df, sample_size_per_class=5): + sampled_df = (df.groupby( "training_id")[['training_id', 'mention']] # explicit give column names + .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x)))) + .reset_index(drop=True)) + + return sampled_df + + +# %% +class DynamicDataset(Dataset): + def __init__(self, df, sample_size_per_class): + """ + Args: + df (pd.DataFrame): Original DataFrame with class (id) and data columns. + sample_size_per_class (int): Number of samples to draw per class for each epoch. + """ + self.df = df + self.sample_size_per_class = sample_size_per_class + self.current_data = None + self.regenerate_data() # Generate the initial dataset + + def regenerate_data(self): + """ + Generate a new sampled dataset for the current epoch. 
+ + dynamic callback function to regenerate data each time we call this + method, it updates the current_data we can: + + - re-sample the dataframe for a new set of n_samples + - generate fresh augmentations this effectively + + This allows us to re-sample and re-augment at the start of each epoch + """ + # Sample `sample_size_per_class` rows per class + sampled_df = sample_from_df(self.df, self.sample_size_per_class) + + # Store the tokenized data with labels + self.current_data = sampled_df + + def __len__(self): + return len(self.current_data) + + def __getitem__(self, idx): + # do the transform here + row = self.current_data.iloc[idx].to_dict() + + # perform text augmentation here + # independent function calls might introduce changes + mention_0 = augment(row) + mention_1 = augment(row) + return { + 'training_id': row['training_id'], + 'mention_0': mention_0, + 'mention_1': mention_1, + } + + +# %% +dataset = DynamicDataset(df, sample_size_per_class=SAMPLES) +dataset[0] + + +# %% +def custom_collate_fn(batch, tokenizer): + # batch is just a list of dictionaries + label_list = [item['training_id'] for item in batch] + mention_0_list = [item['mention_0'] for item in batch] + mention_1_list = [item['mention_1'] for item in batch] + + # we can do the tokenization here + tokenized_batch_0 = tokenizer( + mention_0_list, + truncation=True, + padding=True, + return_tensors='pt' + ) + + tokenized_batch_1 = tokenizer( + mention_1_list, + truncation=True, + padding=True, + return_tensors='pt' + ) + + + label_list = torch.tensor(label_list) + + return { + 'input_ids_0': tokenized_batch_0['input_ids'], + 'attention_mask_0': tokenized_batch_0['attention_mask'], + 'input_ids_1': tokenized_batch_1['input_ids'], + 'attention_mask_1': tokenized_batch_1['attention_mask'], + 'labels': label_list, + } + +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", clean_up_tokenization_spaces=False) +custom_collate_fn_with_tokenizer = partial(custom_collate_fn, 
tokenizer=tokenizer) +dataloader = DataLoader( + dataset, + batch_size=8, + collate_fn=custom_collate_fn_with_tokenizer +) + + +# %% +next(iter(dataloader)) +# %% diff --git a/vicreg/results/output.txt b/vicreg/results/output.txt new file mode 100644 index 0000000..904774f --- /dev/null +++ b/vicreg/results/output.txt @@ -0,0 +1,5 @@ +Top-1 accuracy: 0.01845018450184502 +Top-3 accuracy: 0.02870028700287003 +Top-5 accuracy: 0.03936039360393604 +Top-10 accuracy: 0.08200082000820008 +0.0 0.7957323 diff --git a/vicreg/train.py b/vicreg/train.py new file mode 100644 index 0000000..84949ec --- /dev/null +++ b/vicreg/train.py @@ -0,0 +1,387 @@ +# %% +# %% +from torch.utils.data import Dataset, DataLoader + +# from datasets import load_from_disk +import os +import json + +os.environ['NCCL_P2P_DISABLE'] = '1' +os.environ['NCCL_IB_DISABLE'] = '1' +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + +from dataclasses import dataclass +import re +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import ( + AutoTokenizer, + AutoModel, + DataCollatorWithPadding, + Trainer, + EarlyStoppingCallback, + TrainingArguments, + TrainerCallback +) +import evaluate +import numpy as np +import pandas as pd +from functools import partial +import warnings + +from tqdm import tqdm + +from dataload import DynamicDataset, custom_collate_fn + +from sklearn.neighbors import KNeighborsClassifier + +torch.set_float32_matmul_precision('high') + +def set_seed(seed): + """ + Set the random seed for reproducibility. 
+ """ + random.seed(seed) # Python random module + np.random.seed(seed) # NumPy random + torch.manual_seed(seed) # PyTorch CPU + torch.cuda.manual_seed(seed) # PyTorch GPU + torch.cuda.manual_seed_all(seed) # If using multiple GPUs + torch.backends.cudnn.deterministic = True # Ensure deterministic behavior + torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility + +set_seed(42) + +SAMPLES=40 +BATCH_SIZE=256 + +# %% +################################################### +# import code +# import training file +data_path = '../esAppMod_data_import/train.csv' +df = pd.read_csv(data_path, skipinitialspace=True) +# rather than use pattern, we use the real thing and property +entity_ids = df['entity_id'].to_list() +target_id_list = sorted(list(set(entity_ids))) + +id2label = {} +label2id = {} +for idx, val in enumerate(target_id_list): + id2label[idx] = val + label2id[val] = idx + +df["training_id"] = df["entity_id"].map(label2id) + + +# %% +# make our dataset and dataloader +# MODEL_NAME = "distilbert/distilbert-base-uncased" +MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, clean_up_tokenization_spaces=False) +dataset = DynamicDataset(df, sample_size_per_class=SAMPLES) +custom_collate_fn_with_tokenizer = partial(custom_collate_fn, tokenizer=tokenizer) +dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + collate_fn=custom_collate_fn_with_tokenizer +) + +# %% +# enable BERT with projection layer + +class VICRegProjection(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim): + super(VICRegProjection, self).__init__() + self.projection = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, output_dim) + ) + + def forward(self, x): + return self.projection(x) + + +class BertWithVICReg(nn.Module): + def __init__(self, bert_model, projection_dim=256): + 
super(BertWithVICReg, self).__init__() + self.bert = bert_model + hidden_size = bert_model.config.hidden_size + self.projection = VICRegProjection(input_dim=hidden_size, hidden_dim=hidden_size, output_dim=projection_dim) + + def forward(self, input_ids, attention_mask=None): + outputs = self.bert(input_ids, attention_mask=attention_mask) + pooled_output = outputs.last_hidden_state[:,0,:] + projected_embeddings = self.projection(pooled_output) + return projected_embeddings + +####################################################### +# %% +@dataclass +class Hyperparameters: + loss_constant_factor: float = 1 + invariance_loss_weight: float = 25.0 + variance_loss_weight: float = 25.0 + covariance_loss_weight: float = 1.0 + variance_loss_epsilon: float = 1e-5 + + +# compute vicreg loss +def get_vicreg_loss(z_a, z_b, hparams): + assert z_a.shape == z_b.shape and len(z_a.shape) == 2 + + # invariance loss + loss_inv = F.mse_loss(z_a, z_b) + + # variance loss + std_z_a = torch.sqrt(z_a.var(dim=0) + hparams.variance_loss_epsilon) + std_z_b = torch.sqrt(z_b.var(dim=0) + hparams.variance_loss_epsilon) + loss_v_a = torch.mean(F.relu(1 - std_z_a)) # differentiable max + loss_v_b = torch.mean(F.relu(1 - std_z_b)) + loss_var = loss_v_a + loss_v_b + + # covariance loss + N, D = z_a.shape + z_a = z_a - z_a.mean(dim=0) + z_b = z_b - z_b.mean(dim=0) + cov_z_a = ((z_a.T @ z_a) / (N - 1)).square() # DxD + cov_z_b = ((z_b.T @ z_b) / (N - 1)).square() # DxD + loss_c_a = (cov_z_a.sum() - cov_z_a.diagonal().sum()) / D + loss_c_b = (cov_z_b.sum() - cov_z_b.diagonal().sum()) / D + loss_cov = loss_c_a + loss_c_b + + weighted_inv = loss_inv * hparams.invariance_loss_weight + weighted_var = loss_var * hparams.variance_loss_weight + weighted_cov = loss_cov * hparams.covariance_loss_weight + + loss = weighted_inv + weighted_var + weighted_cov + + + return loss + + + +# %% +# +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +DEVICE = torch.device('cuda') if 
torch.cuda.is_available() else torch.device('cpu') +bert_model = AutoModel.from_pretrained(MODEL_NAME) +# bert_hidden_size = bert_model.config.hidden_size +# projection_model = VICRegProjection( +# input_dim=bert_hidden_size, +# hidden_dim=bert_hidden_size, +# output_dim=256 +# ) +# need to allocate individual component of the model +# bert_model.to(DEVICE) +# projection_model.to(DEVICE) +model = BertWithVICReg(bert_model, projection_dim=128) +model.to(DEVICE) + +# params = list(bert_model.parameters()) + list(projection_model.parameters()) +params = model.parameters() +optimizer = torch.optim.AdamW(params, lr=5e-6) +hparams = Hyperparameters() + +losses = [] + +# # %% +# batch = next(iter(dataloader)) +# input_ids_0 = batch['input_ids_0'].to(DEVICE) +# attention_mask_0 = batch['attention_mask_0'].to(DEVICE) +# +# # %% +# # outputs from reprojection layer +# bert_output = model( +# input_ids=input_ids_0, +# attention_mask=attention_mask_0 +# ) + + + +# %% +# parameters +epochs = 80 + +for epoch in tqdm(range(epochs)): + dataset.regenerate_data() + for batch in dataloader: + optimizer.zero_grad() + + # compute cls 0 + input_ids_0 = batch['input_ids_0'].to(DEVICE) + attention_mask_0 = batch['attention_mask_0'].to(DEVICE) + # outputs from reprojection layer + outputs_0 = model( + input_ids=input_ids_0, + attention_mask=attention_mask_0 + ) + + # compute cls 1 + input_ids_1 = batch['input_ids_1'].to(DEVICE) + attention_mask_1 = batch['attention_mask_1'].to(DEVICE) + # outputs from reprojection layer + outputs_1 = model( + input_ids=input_ids_1, + attention_mask=attention_mask_1 + ) + + loss = get_vicreg_loss(outputs_0, outputs_1, hparams=hparams) + + + loss.backward() + optimizer.step() + # print(epoch, loss) + losses.append(loss) + torch.cuda.empty_cache() + + print(loss.detach().item()) + + +# %% +torch.save(model.state_dict(), './checkpoint/simple.pt') + +#################################################### +# %% +DEVICE = torch.device('cuda') if 
torch.cuda.is_available() else torch.device('cpu') +state_dict = torch.load('./checkpoint/simple.pt') +model = BertWithVICReg(bert_model, projection_dim=256) +model.load_state_dict(state_dict) + +# %% +# Step 2: Load the state dictionary +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased' +# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased' + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +# state_dict = torch.load('./checkpoint/siamese.pt') +# model = model.bert +# %% + +# Step 3: Apply the state dictionary to the model +model.to(DEVICE) +model.eval() + +# %% +with open('../esAppMod/tca_entities.json', 'r') as file: + entities = json.load(file) +all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} + +with open('../esAppMod/train.json', 'r') as file: + train = json.load(file) +train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} +train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} + + +# %% +def preprocess_text(text): + # 1. 
Make all uppercase + text = text.lower() + + # standardize spacing + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +with open('../esAppMod/infer.json', 'r') as file: + test = json.load(file) +x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()] +y_test = [d['entity_id'] for _, d in test['data'].items()] +train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys()) +train_entities = [preprocess_text(element) for element in train_entities] + +def batch_list(data, batch_size): + """Yield successive n-sized chunks from data.""" + for i in range(0, len(data), batch_size): + yield data[i:i + batch_size] + +batches = batch_list(train_entities, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + # output = outputs.last_hidden_state[:,0,:] + outputs = outputs.detach().cpu().numpy() + embedding_list.append(outputs) + +cls = np.concatenate(embedding_list) +# %% +torch.cuda.empty_cache() + +# %% + +batches = batch_list(x_test, 64) + +embedding_list = [] +for batch in batches: + inputs = tokenizer(batch, padding=True, return_tensors='pt') + outputs = model( + input_ids=inputs['input_ids'].to(DEVICE), + attention_mask=inputs['attention_mask'].to(DEVICE) + ) + # output = outputs.last_hidden_state[:,0,:] + outputs = outputs.detach().cpu().numpy() + embedding_list.append(outputs) + +cls_test = np.concatenate(embedding_list) + + +# %% +knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels) +n_neighbors = [1, 3, 5, 10] + +for n in n_neighbors: + distances, indices = knn.kneighbors(cls_test, n_neighbors=n) + num = 0 + for a,b in zip(y_test, indices): + b = [labels[i] for i in b] + if a in b: + num += 1 + print(f'Top-{n:<3} accuracy: {num / len(y_test)}') +print(np.min(distances), np.max(distances)) + + + +# with 
# open("results/output.txt", "w") as f:
#     for n in n_neighbors:
#         distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
#         num = 0
#         for a,b in zip(y_test, indices):
#             b = [labels[i] for i in b]
#             if a in b:
#                 num += 1
#         print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
#     print(np.min(distances), np.max(distances), file=f)

# %%

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
# %%

# Project the training-entity embeddings (``cls``) to 2-D with t-SNE
tsne = TSNE(n_components=2, random_state=42)
embeddings = cls
embeddings_reduced = tsne.fit_transform(embeddings)

# Scatter plot of the two t-SNE components, coloured by entity label
first_component = embeddings_reduced[:, 0]
second_component = embeddings_reduced[:, 1]

plt.figure(figsize=(10, 8))
scatter = plt.scatter(
    first_component,
    second_component,
    c=labels,
    cmap='viridis',
    alpha=0.6,
)
plt.colorbar(scatter)
plt.xlabel('Component 1')
plt.ylabel('Component 2')
plt.title('Visualization of Embeddings')
plt.show()

# %%