# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
# from loss import batch_all_triplet_loss, batch_hard_triplet_loss
import loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import pandas as pd
import re
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F

# %%
SHUFFLES = 0
AMPLIFY_FACTOR = 0
LEARNING_RATE = 1e-5
DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')

# MODEL_NAME = 'distilbert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small'  # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'

# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    # split entity mentions into groups
    # anchor = False: don't add the entity name to each group, simply treat it as a normal mention
    entity_sets = []
    if anchor:
        for id, mentions in entity_id_mentions.items():
            random.shuffle(mentions)
            positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
            anchor_positive = [([entity_id_name[id]] + p, id) for p in positives]
            entity_sets.extend(anchor_positive)
    else:
        for id, mentions in entity_id_mentions.items():
            group = list(set([entity_id_name[id]] + mentions))
            random.shuffle(group)
            positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
            entity_sets.extend(positives)
    return entity_sets


def batchGenerator(data, batch_size):
    for i in range(0, len(data), batch_size):
        batch = data[i:i + batch_size]
        x, y = [], []
        for t in batch:
            x.extend(t[0])
            y.extend([t[1]] * len(t[0]))
        yield x, y


with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}

# %%
###################################################
# alternate data import strategy
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use pattern, we use the real thing and property
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))

id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
    id2label[idx] = val
    label2id[val] = idx

df["training_id"] = df["entity_id"].map(label2id)
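# %%
# Optional sanity check of the grouping/batching helpers above on toy data.
# The entity ids, names, and mentions below are made up purely for illustration;
# they are not part of the real esAppMod data.
_toy_mentions = {1: ['postgres', 'postgresql db', 'pg'], 2: ['redis', 'redis cache']}
_toy_names = {1: 'PostgreSQL', 2: 'Redis'}
_toy_sets = generate_train_entity_sets(_toy_mentions, _toy_names, group_size=2, anchor=True)
# each element is ([entity_name, mention, ...], entity_id)
for _group, _gid in _toy_sets:
    print(_gid, _group)
# batchGenerator flattens the groups into parallel lists of mentions and labels
_toy_x, _toy_y = next(iter(batchGenerator(_toy_sets, batch_size=2)))
print(_toy_x)
print(_toy_y)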
# %%
##############################################################
# augmentation code

# basic preprocessing
def preprocess_text(text):
    # 1. make all lowercase
    text = text.lower()
    # standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def generate_random_shuffles(text, n):
    words = text.split()  # split the input into words
    shuffled_variations = []
    for _ in range(n):
        shuffled = words[:]  # copy the word list to avoid in-place modification
        random.shuffle(shuffled)  # randomly shuffle the words
        shuffled_variations.append(" ".join(shuffled))  # join the words back into a string
    return shuffled_variations


def shuffle_text(text, n_shuffles=SHUFFLES):
    all_processed = []
    # add the original text
    all_processed.append(text)
    # generate random shuffles
    shuffled_variations = generate_random_shuffles(text, n_shuffles)
    all_processed.extend(shuffled_variations)
    return all_processed


def corrupt_word(word):
    """Corrupt a single word using random corruption techniques."""
    if len(word) <= 1:  # skip corruption for single-character words
        return word

    corruption_type = random.choice(["delete", "swap"])

    if corruption_type == "delete":
        # randomly delete a character
        idx = random.randint(0, len(word) - 1)
        word = word[:idx] + word[idx + 1:]
    elif corruption_type == "swap":
        # swap two adjacent characters
        if len(word) > 1:
            idx = random.randint(0, len(word) - 2)
            word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])

    return word


def corrupt_string(sentence, corruption_probability=0.01):
    """Corrupt each word in the string with a given probability."""
    words = sentence.split()
    corrupted_words = [
        corrupt_word(word) if random.random() < corruption_probability else word
        for word in words
    ]
    return " ".join(corrupted_words)


def create_example(index, mention, entity_name):
    return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}


# augment the whole dataset
def augment_data(df):
    output_list = []

    for idx, row in df.iterrows():
        index = row['entity_id']
        entity_name = row['entity_name']
        parent_desc = row['mention']
        parent_desc = preprocess_text(parent_desc)

        # add basic example
        output_list.append(create_example(index, parent_desc, entity_name))

        # add shuffled strings
        processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
        for desc in processed_descs:
            if desc != parent_desc:
                output_list.append(create_example(index, desc, entity_name))

        # add corrupted strings
        desc = corrupt_string(parent_desc, corruption_probability=0.01)
        if desc != parent_desc:
            output_list.append(create_example(index, desc, entity_name))

        # add example with stripped non-alphanumerics
        desc = re.sub(r'[^\w\s]', ' ', parent_desc)  # retains only alphanumerics and spaces
        if desc != parent_desc:
            output_list.append(create_example(index, desc, entity_name))

        # short sequence amplifier
        # short sequences are rare, and we must compensate by including more examples
        # also, short sequences don't usually get affected by shuffling
        words = parent_desc.split()
        word_count = len(words)
        if word_count <= 2:
            for _ in range(AMPLIFY_FACTOR):
                output_list.append(create_example(index, desc, entity_name))

    new_df = pd.DataFrame(output_list)
    return new_df


# %%
def make_entity_id_mentions(df):
    entity_id_mentions = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        entity_id_mentions[entity_id] = df[df['entity_id'] == entity_id]['mention'].to_list()
    return entity_id_mentions


def make_entity_id_name(df):
    entity_id_name = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        # each entity_id always maps to the same entity_name, so the first value works
        entity_id_name[entity_id] = df[df['entity_id'] == entity_id]['entity_name'].to_list()[0]
    return entity_id_name
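# %%
# Optional sanity check of the augmentation pipeline on a tiny, made-up frame.
# The two rows below are illustrative only and do not come from train.csv.
_toy_df = pd.DataFrame([
    {'entity_id': 1, 'entity_name': 'PostgreSQL', 'mention': 'Postgre-SQL  Server'},
    {'entity_id': 2, 'entity_name': 'Redis', 'mention': 'REDIS'},
])
_toy_aug = augment_data(_toy_df)
# expect the lowercased/whitespace-normalized mention for every row, plus a
# punctuation-stripped variant (and possibly a corrupted one) where it differs
print(_toy_aug)
print(make_entity_id_mentions(_toy_aug))
print(make_entity_id_name(_toy_aug))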
# %%
num_sample_per_class = 10  # samples in each group
batch_size = 16  # number of groups; effective batch size for computing triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
model.to(DEVICE)
model.train()

losses = []

# %%
augmented_df = augment_data(df)
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
train_entity_id_name = make_entity_id_name(augmented_df)
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class - 1, anchor=True)
random.shuffle(data)

# %%
x, y = next(iter(batchGenerator(data, batch_size)))

# %%
inputs = tokenizer(x, padding=True, return_tensors='pt')
inputs.to(DEVICE)
outputs = model(**inputs)
cls = outputs.last_hidden_state[:, 0, :]
# for training less than half the time, train on easy
y = torch.tensor(y).to(DEVICE)

# %%
def _pairwise_distances(embeddings, squared=False):
    """Compute the 2D matrix of distances between all the embeddings.

    Args:
        embeddings: tensor of shape (batch_size, embed_dim)
        squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
                 If false, output is the pairwise euclidean distance matrix.

    Returns:
        pairwise_distances: tensor of shape (batch_size, batch_size)
    """
    dot_product = torch.matmul(embeddings, embeddings.t())

    # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
    # This also provides more numerical stability (the diagonal of the result will be exactly 0).
    # shape (batch_size,)
    square_norm = torch.diag(dot_product)

    # Compute the pairwise distance matrix as we have:
    # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
    # shape (batch_size, batch_size)
    distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)

    # Apply a lower bound to distances to ensure they are non-negative
    # and avoid tiny negative numbers due to computation errors
    distances = torch.clamp(distances, min=0.0)

    if not squared:
        # Because the gradient of sqrt is infinite when distances == 0.0 (e.g. on the diagonal)
        # we need to add a small epsilon where distances == 0.0
        epsilon = 1e-16
        mask = (distances < epsilon).float()
        distances = distances + mask * epsilon
        distances = (1.0 - mask) * torch.sqrt(distances)

    return distances

# %%
embeddings = cls
squared = False

# %%
# Get the pairwise distance matrix
pairwise_dist = loss._pairwise_distances(embeddings, squared=squared)  # 96x96

anchor_positive_dist = pairwise_dist.unsqueeze(2)  # 96x96x1
anchor_negative_dist = pairwise_dist.unsqueeze(1)  # 96x1x96

# Compute a 3D tensor of size (batch_size, batch_size, batch_size):
# triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k,
# i.e. every (i,j) pairwise distance minus every (i,k) pairwise distance.
# Fixing i, we get (i,j) - (i,k) for every j and k, which is 96x96.
# Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
# and the 2nd has shape (batch_size, 1, batch_size);
# remember that broadcasting repeats the size-1 axis as many times as needed.
# This broadcasting trick is how we get every possible triplet combination.
triplet_loss = anchor_positive_dist - anchor_negative_dist + margin  # triplet_loss 96x96x96

# %%
labels = y
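# %%
# Small numeric sanity check (toy values, made up for illustration):
# the unsqueeze/broadcast trick above turns a pairwise distance matrix d
# into the 3D tensor t with t[i, j, k] = d[i, j] - d[i, k].
_d = torch.tensor([[0.0, 1.0, 4.0],
                   [1.0, 0.0, 2.0],
                   [4.0, 2.0, 0.0]])
_t = _d.unsqueeze(2) - _d.unsqueeze(1)       # shape (3, 3, 3)
assert _t[0, 1, 2] == _d[0, 1] - _d[0, 2]    # 1.0 - 4.0 = -3.0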
# %%
# Put to zero the invalid triplets
# (where label(a) != label(p) or label(n) == label(a) or a == p)
mask = loss._get_triplet_mask(labels)
triplet_loss = mask.float() * triplet_loss

# Remove negative losses (i.e. the easy triplets)
triplet_loss = F.relu(triplet_loss)

# Count the number of positive triplets (where triplet_loss > 0)
valid_triplets = triplet_loss[triplet_loss > 1e-16]
num_positive_triplets = valid_triplets.size(0)
num_valid_triplets = mask.sum()
fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)

# Get the final mean triplet loss over the positive valid triplets
triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)

# %%
# Call the library versions through the `loss` module (the direct import is commented
# out above); avoid assigning the result to `loss`, which would shadow the module.
batch_all_loss, _ = loss.batch_all_triplet_loss(y, cls, margin, squared=False)

# %%
batch_hard_loss = loss.batch_hard_triplet_loss(y, cls, margin, squared=False)

# %%
# Check that i, j and k are distinct:
# create an identity matrix of size 96
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()

# %%
indices_not_equal = ~indices_equal
i_not_equal_j = indices_not_equal.unsqueeze(2)  # [96,96,1]
i_not_equal_k = indices_not_equal.unsqueeze(1)  # [96,1,96]
j_not_equal_k = indices_not_equal.unsqueeze(0)  # [1,96,96]

# %%
# eliminate any combination that uses the diagonal values (i.e. shares the same index)
distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k

# %%
label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
# label_equal is a 96x96 matrix showing where two labels are equal.
# Perform the same unsqueeze on axes 1 and 2 and broadcast to get all possible combinations;
# note that we have 96 elements, but we want all (i,j,k) combinations from these 96 elements.
i_equal_j = label_equal.unsqueeze(2)
i_equal_k = label_equal.unsqueeze(1)
# ~i_equal_k checks for non-equality between i and k;
# i_equal_j checks for equality between i and j.
# We want (i,j) to share the same label and (i,k) to have different labels.
valid_labels = ~i_equal_k & i_equal_j

# %%
final_mask = distinct_indices & valid_labels
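# %%
# The step-by-step mask construction above can be folded into a single helper.
# This is only a sketch of what loss._get_triplet_mask is assumed to compute;
# if the module's implementation differs (e.g. in dtype), trust the module.
def _get_triplet_mask_sketch(labels):
    # valid triplet (i, j, k): i, j, k distinct, label(i) == label(j), label(i) != label(k)
    indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
    indices_not_equal = ~indices_equal
    i_not_equal_j = indices_not_equal.unsqueeze(2)
    i_not_equal_k = indices_not_equal.unsqueeze(1)
    j_not_equal_k = indices_not_equal.unsqueeze(0)
    distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k

    label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
    i_equal_j = label_equal.unsqueeze(2)
    i_equal_k = label_equal.unsqueeze(1)
    valid_labels = ~i_equal_k & i_equal_j

    return distinct_indices & valid_labels

# the dissected version above and this consolidated sketch should agree exactly
assert torch.equal(final_mask, _get_triplet_mask_sketch(labels))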