# %%
import json
import random
import re

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
from sklearn.neighbors import KNeighborsClassifier

from loss import batch_all_triplet_loss, batch_hard_triplet_loss

# %%
SHUFFLES = 0
AMPLIFY_FACTOR = 0
LEARNING_RATE = 1e-5

# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    # split entity mentions into groups
    # anchor=False: don't add the entity name to each group, simply treat it as a normal mention
    entity_sets = []
    if anchor:
        for id, mentions in entity_id_mentions.items():
            random.shuffle(mentions)
            positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
            anchor_positive = [([entity_id_name[id]] + p, id) for p in positives]
            entity_sets.extend(anchor_positive)
    else:
        for id, mentions in entity_id_mentions.items():
            group = list(set([entity_id_name[id]] + mentions))
            random.shuffle(group)
            # slice the shuffled group (name included as an ordinary mention), not the raw mention list
            positives = [(group[i:i + group_size], id) for i in range(0, len(group), group_size)]
            entity_sets.extend(positives)
    return entity_sets


def batchGenerator(data, batch_size):
    # yield flat lists of mentions (x) and their entity ids (y), batch_size groups at a time
    for i in range(0, len(data), batch_size):
        batch = data[i:i + batch_size]
        x, y = [], []
        for t in batch:
            x.extend(t[0])
            y.extend([t[1]] * len(t[0]))
        yield x, y


with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}

# %%
###################################################
# alternate data import strategy
###################################################
# import the training file as a dataframe
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)

# map the raw entity ids to contiguous training labels
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(set(entity_ids))

id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
    id2label[idx] = val
    label2id[val] = idx
df["training_id"] = df["entity_id"].map(label2id)
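# %%
# Illustrative sketch, not part of training: shows how generate_train_entity_sets and
# batchGenerator turn per-entity mention lists into (anchor + positives, id) sets and
# flat batches. The toy ids, names and mention strings below are made up for demonstration.
_toy_mentions = {101: ['mention a', 'mention b', 'mention c'], 202: ['mention d']}
_toy_names = {101: 'Entity One', 202: 'Entity Two'}
_toy_sets = generate_train_entity_sets(_toy_mentions, _toy_names, group_size=2, anchor=True)
for _x, _y in batchGenerator(_toy_sets, batch_size=2):
    # _x is a flat list of strings, _y holds one entity id per string
    print(len(_x), _y)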
# %%
##############################################################
# augmentation code

# basic preprocessing
def preprocess_text(text):
    # 1. make everything lowercase
    text = text.lower()
    # 2. standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def generate_random_shuffles(text, n):
    """Return n variations of the text with the word order randomly shuffled."""
    words = text.split()
    shuffled_variations = []
    for _ in range(n):
        shuffled = words[:]  # copy the word list to avoid in-place modification
        random.shuffle(shuffled)
        shuffled_variations.append(" ".join(shuffled))
    return shuffled_variations


def shuffle_text(text, n_shuffles=SHUFFLES):
    """Return the original text plus n_shuffles shuffled variations."""
    all_processed = [text]
    all_processed.extend(generate_random_shuffles(text, n_shuffles))
    return all_processed


def corrupt_word(word):
    """Corrupt a single word with a randomly chosen corruption technique."""
    if len(word) <= 1:  # skip corruption for single-character words
        return word

    corruption_type = random.choice(["delete", "swap"])

    if corruption_type == "delete":
        # randomly delete a character
        idx = random.randint(0, len(word) - 1)
        word = word[:idx] + word[idx + 1:]
    elif corruption_type == "swap":
        # swap two adjacent characters
        idx = random.randint(0, len(word) - 2)
        word = word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]

    return word


def corrupt_string(sentence, corruption_probability=0.01):
    """Corrupt each word in the string with the given probability."""
    words = sentence.split()
    corrupted_words = [
        corrupt_word(word) if random.random() < corruption_probability else word
        for word in words
    ]
    return " ".join(corrupted_words)


def create_example(index, mention, entity_name):
    return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}


# augment the whole dataset
def augment_data(df):
    output_list = []

    for _, row in df.iterrows():
        index = row['entity_id']
        entity_name = row['entity_name']
        parent_desc = preprocess_text(row['mention'])

        # add the basic example
        output_list.append(create_example(index, parent_desc, entity_name))

        # add shuffled strings
        processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
        for desc in processed_descs:
            if desc != parent_desc:
                output_list.append(create_example(index, desc, entity_name))

        # add corrupted strings
        desc = corrupt_string(parent_desc, corruption_probability=0.01)
        if desc != parent_desc:
            output_list.append(create_example(index, desc, entity_name))

        # add an example with non-alphanumerics stripped (retains only alphanumerics and spaces)
        desc = re.sub(r'[^\w\s]', ' ', parent_desc)
        if desc != parent_desc:
            output_list.append(create_example(index, desc, entity_name))

        # short-sequence amplifier
        # short sequences are rare, so we compensate by repeating them;
        # they are also largely unaffected by word shuffling
        word_count = len(parent_desc.split())
        if word_count <= 2:
            for _ in range(AMPLIFY_FACTOR):
                # amplify the preprocessed mention itself, not the last augmented variant
                output_list.append(create_example(index, parent_desc, entity_name))

    new_df = pd.DataFrame(output_list)
    return new_df

# %%
def make_entity_id_mentions(df):
    entity_id_mentions = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        entity_id_mentions[entity_id] = df[df['entity_id'] == entity_id]['mention'].to_list()
    return entity_id_mentions


def make_entity_id_name(df):
    entity_id_name = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        # every row with the same entity_id carries the same entity_name, so the first value works
        entity_id_name[entity_id] = df[df['entity_id'] == entity_id]['entity_name'].to_list()[0]
    return entity_id_name
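# %%
# Illustrative sketch, not part of training: what the augmentation helpers do to a
# single made-up mention string (the exact shuffled/corrupted outputs are random).
_demo = preprocess_text('  IBM  WebSphere   MQ ')
print(_demo)                                              # 'ibm websphere mq'
print(shuffle_text(_demo, n_shuffles=2))                  # original + 2 word-order shuffles
print(corrupt_string(_demo, corruption_probability=0.5))  # some words randomly corrupted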
# %%
num_sample_per_class = 10  # samples in each group
batch_size = 16  # number of groups; effective batch size for the triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
MODEL_NAME = 'prajjwal1/bert-small'  # alternatives: 'bert-base-cased', 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

model.to(DEVICE)
model.train()

losses = []

for epoch in tqdm(range(epochs)):
    total_loss = 0.0
    batch_number = 0

    # re-augment the data every epoch so the random corruptions/shuffles differ
    augmented_df = augment_data(df)
    train_entity_id_mentions = make_entity_id_mentions(augmented_df)
    train_entity_id_name = make_entity_id_name(augmented_df)

    data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class - 1, anchor=True)
    random.shuffle(data)

    for x, y in batchGenerator(data, batch_size):
        # print(len(x), len(y), end='-->')
        optimizer.zero_grad()

        inputs = tokenizer(x, padding=True, return_tensors='pt')
        inputs = inputs.to(DEVICE)
        outputs = model(**inputs)
        cls = outputs.last_hidden_state[:, 0, :]  # [CLS] embeddings
        y = torch.tensor(y).to(DEVICE)
        if epoch < epochs / 2:
            # first half of training: batch-all triplets (easier)
            loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
        else:
            # second half of training: batch-hard triplets
            loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().item()
        batch_number += 1

        del x, y, outputs, cls, loss
        torch.cuda.empty_cache()

    # scheduler.step()  # update the learning rate
    print(f'epoch loss: {total_loss / batch_number}')
    # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")

    if epoch % 5 == 0:
        torch.save(model.state_dict(), './checkpoint/siamese_simple.pt')

torch.save(model.state_dict(), './checkpoint/siamese_simple.pt')

# %%
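# Illustrative sketch, not part of the training loop above: one possible way to use the
# trained encoder for mention lookup, pairing [CLS] embeddings with the (imported but
# otherwise unused) KNeighborsClassifier. The query strings are made-up examples, and the
# column names are taken from the training dataframe above.
model.eval()

def _embed(texts):
    # encode a list of strings into [CLS] embeddings as a numpy array on the CPU
    # (a large mention list may need to be chunked into smaller batches)
    with torch.no_grad():
        enc = tokenizer(texts, padding=True, return_tensors='pt').to(DEVICE)
        return model(**enc).last_hidden_state[:, 0, :].cpu().numpy()

_train_emb = _embed(df['mention'].to_list())
_knn = KNeighborsClassifier(n_neighbors=1).fit(_train_emb, df['entity_id'].to_list())
print(_knn.predict(_embed(['websphere mq', 'postgres db'])))  # nearest-neighbour entity ids

# %%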