# %% import torch import json import random import numpy as np from transformers import AutoTokenizer from transformers import AutoModel from loss import ( batch_all_triplet_loss, batch_hard_triplet_loss, batch_all_soft_margin_triplet_loss, batch_hard_soft_margin_triplet_loss) from sklearn.neighbors import KNeighborsClassifier from tqdm import tqdm import pandas as pd import re from torch.utils.data import Dataset, DataLoader import torch.optim as optim import torch.nn as nn import torch.nn.functional as F torch.set_float32_matmul_precision('high') def set_seed(seed): """ Set the random seed for reproducibility. """ random.seed(seed) # Python random module np.random.seed(seed) # NumPy random torch.manual_seed(seed) # PyTorch CPU torch.cuda.manual_seed(seed) # PyTorch GPU torch.cuda.manual_seed_all(seed) # If using multiple GPUs torch.backends.cudnn.deterministic = True # Ensure deterministic behavior torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility set_seed(42) # %% SHUFFLES=1 AMPLIFY_FACTOR=1 LEARNING_RATE=1e-4 DEVICE = torch.device('cuda:2') if torch.cuda.is_available() else torch.device('cpu') # %% EVAL_FILE="top1_curves/hybrid_output.txt" with open(EVAL_FILE, "w") as f: pass # %% def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True): # split entity mentions into groups # anchor = False, don't add entity name to each group, simply treat it as a normal mention entity_sets = [] if anchor: for id, mentions in entity_id_mentions.items(): random.shuffle(mentions) positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)] anchor_positive = [([entity_id_name[id]]+p, id) for p in positives] entity_sets.extend(anchor_positive) else: for id, mentions in entity_id_mentions.items(): group = list(set([entity_id_name[id]] + mentions)) random.shuffle(group) positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)] entity_sets.extend(positives) return entity_sets def batchGenerator(data, batch_size): for i in range(0, len(data), batch_size): batch = data[i:i+batch_size] x, y = [], [] for t in batch: x.extend(t[0]) y.extend([t[1]]*len(t[0])) yield x, y with open('../esAppMod/tca_entities.json', 'r') as file: entities = json.load(file) all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} with open('../esAppMod/train.json', 'r') as file: train = json.load(file) train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} # %% ############### # alternate data import strategy ################################################### # import code # import training file data_path = '../esAppMod_data_import/train.csv' df = pd.read_csv(data_path, skipinitialspace=True) # rather than use pattern, we use the real thing and property entity_ids = df['entity_id'].to_list() target_id_list = sorted(list(set(entity_ids))) id2label = {} label2id = {} for idx, val in enumerate(target_id_list): id2label[idx] = val label2id[val] = idx df["training_id"] = df["entity_id"].map(label2id) # %% ############################################################## # augmentation code # basic preprocessing def preprocess_text(text): # 1. Make all uppercase text = text.lower() # standardize spacing text = re.sub(r'\s+', '_', text).strip() return text def generate_random_shuffles(text, n): words = text.split() # Split the input into words shuffled_variations = [] for _ in range(n): shuffled = words[:] # Copy the word list to avoid in-place modification random.shuffle(shuffled) # Randomly shuffle the words shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string return shuffled_variations def shuffle_text(text, n_shuffles=SHUFFLES): all_processed = [] # add the original text all_processed.append(text) # Generate random shuffles shuffled_variations = generate_random_shuffles(text, n_shuffles) all_processed.extend(shuffled_variations) return all_processed def corrupt_word(word): """Corrupt a single word using random corruption techniques.""" if len(word) <= 1: # Skip corruption for single-character words return word corruption_type = random.choice(["delete", "swap"]) if corruption_type == "delete": # Randomly delete a character idx = random.randint(0, len(word) - 1) word = word[:idx] + word[idx + 1:] elif corruption_type == "swap": # Swap two adjacent characters if len(word) > 1: idx = random.randint(0, len(word) - 2) word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]) return word def corrupt_string(sentence, corruption_probability=0.01): """Corrupt each word in the string with a given probability.""" words = sentence.split() corrupted_words = [ corrupt_word(word) if random.random() < corruption_probability else word for word in words ] return " ".join(corrupted_words) def create_example(index, mention, entity_name): return {'entity_id': index, 'mention': mention, 'entity_name': entity_name} # augment whole dataset def augment_data(df): output_list = [] for idx,row in df.iterrows(): index = row['entity_id'] entity_name = row['entity_name'] parent_desc = row['mention'] parent_desc = preprocess_text(parent_desc) # add basic example output_list.append(create_example(index, parent_desc, entity_name)) # # add shuffled strings # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES) # for desc in processed_descs: # if (desc != parent_desc): # output_list.append(create_example(index, desc, entity_name)) # add corrupted strings desc = corrupt_string(parent_desc, corruption_probability=0.01) if (desc != parent_desc): output_list.append(create_example(index, desc, entity_name)) # add example with stripped non-alphanumerics desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces if (desc != parent_desc): output_list.append(create_example(index, desc, entity_name)) # short sequence amplifier # short sequences are rare, and we must compensate by including more examples # also, short sequence don't usually get affected by shuffle words = parent_desc.split() word_count = len(words) if word_count <= 2: for _ in range(AMPLIFY_FACTOR): output_list.append(create_example(index, desc, entity_name)) new_df = pd.DataFrame(output_list) return new_df # %% def make_entity_id_mentions(df): entity_id_mentions = {} entity_id_list = list(set(df['entity_id'])) for entity_id in entity_id_list: entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list() return entity_id_mentions def make_entity_id_name(df): entity_id_name = {} entity_id_list = list(set(df['entity_id'])) for entity_id in entity_id_list: # entity_id always matches entity_name, so first value would work entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0] return entity_id_name # evaluation def run_evaluation(model, tokenizer): def preprocess_text(text): # 1. Make all uppercase text = text.lower() # standardize spacing text = re.sub(r'\s+', ' ', text).strip() return text with open('../esAppMod/tca_entities.json', 'r') as file: entities = json.load(file) all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()} with open('../esAppMod/train.json', 'r') as file: train = json.load(file) train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()} train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()} with open('../esAppMod/infer.json', 'r') as file: test = json.load(file) x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()] y_test = [d['entity_id'] for _, d in test['data'].items()] train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys()) train_entities = [preprocess_text(element) for element in train_entities] def batch_list(data, batch_size): """Yield successive n-sized chunks from data.""" for i in range(0, len(data), batch_size): yield data[i:i + batch_size] batches = batch_list(train_entities, 64) embedding_list = [] for batch in batches: inputs = tokenizer(batch, padding=True, return_tensors='pt') outputs = model( input_ids=inputs['input_ids'].to(DEVICE), attention_mask=inputs['attention_mask'].to(DEVICE) ) output = outputs.last_hidden_state[:,0,:] output = output.detach().cpu().numpy() embedding_list.append(output) cls = np.concatenate(embedding_list) batches = batch_list(x_test, 64) embedding_list = [] for batch in batches: inputs = tokenizer(batch, padding=True, return_tensors='pt') outputs = model( input_ids=inputs['input_ids'].to(DEVICE), attention_mask=inputs['attention_mask'].to(DEVICE) ) output = outputs.last_hidden_state[:,0,:] output = output.detach().cpu().numpy() embedding_list.append(output) cls_test = np.concatenate(embedding_list) knn = KNeighborsClassifier(n_neighbors=1, metric='euclidean').fit(cls, labels) with open(EVAL_FILE, "a") as f: # only compute top-1 distances, indices = knn.kneighbors(cls_test, n_neighbors=1) num = 0 for a,b in zip(y_test, indices): b = [labels[i] for i in b] if a in b: num += 1 print(f'{num / len(y_test)}', file=f) # %% num_sample_per_class = 10 # samples in each group batch_size = 64 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class margin = 2 epochs = 200 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) bert_model = AutoModel.from_pretrained(MODEL_NAME) class BertForClassificationAndTriplet(nn.Module): def __init__(self, bert_model, num_classes): super().__init__() self.bert = bert_model self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes) def forward(self, input_ids, attention_mask=None): outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token logits = self.classifier(cls_embeddings) return cls_embeddings, logits model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id)) optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) model.to(DEVICE) model.train() losses = [] def linear_decay(epoch, max_epochs, initial_lr, final_lr): """ Calculate the linearly decayed learning rate. """ return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr) for epoch in tqdm(range(epochs)): total_loss = 0.0 total_cross = 0.0 total_triplet = 0.0 batch_number = 0 # lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6) # # Update optimizer's learning rate # for param_group in optimizer.param_groups: # param_group['lr'] = lr if epoch % 10 == 0: augmented_df = augment_data(df) train_entity_id_mentions = make_entity_id_mentions(augmented_df) train_entity_id_name = make_entity_id_name(augmented_df) data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True) random.shuffle(data) for x,y in batchGenerator(data, batch_size): # print(len(x), len(y), end='-->') optimizer.zero_grad() inputs = tokenizer(x, padding=True, return_tensors='pt') inputs.to(DEVICE) cls, logits = model( input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'] ) # for training less than half the time, train on easy labels = y labels = [label2id[element] for element in labels] labels = torch.tensor(labels).to(DEVICE) y = torch.tensor(y).to(DEVICE) class_loss = F.cross_entropy(logits, labels) if epoch < epochs / 2: # triplet_loss, _ = batch_all_soft_margin_triplet_loss(y, cls, squared=False) # loss = class_loss + triplet_loss # loss,_ = batch_all_soft_margin_triplet_loss(y, cls, squared=False) loss = class_loss # for training after half the time, train on hard # else: # triplet_loss = batch_hard_soft_margin_triplet_loss(y, cls, squared=False) # loss = triplet_loss else: loss = batch_hard_soft_margin_triplet_loss(y, cls, squared=False) loss.backward() optimizer.step() total_loss += loss.detach().item() # total_cross += class_loss.detach().item() # total_triplet += triplet_loss.detach().item() batch_number += 1 # run evaluation on test data model.eval() with torch.no_grad(): run_evaluation(model=model.bert, tokenizer=tokenizer) model.train() # scheduler.step() # Update the learning rate # print(f'epoch loss: {total_loss/batch_number}, cross loss: {total_cross/batch_number}, triplet loss: {total_triplet/batch_number}') print(f'epoch loss: {total_loss/batch_number}') # print(f"Epoch {epoch+1}: lr={lr}") # if epoch % 5 == 0: # # torch.save(model.bert.state_dict(), './checkpoint/classification.pt') # torch.save(model.state_dict(), './checkpoint/hybrid.pt') # torch.save(model.bert.state_dict(), './checkpoint/classification.pt') # torch.save(model.state_dict(), './checkpoint/hybrid.pt') # %%