# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import pandas as pd
import re
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

# %%
SHUFFLES = 0
AMPLIFY_FACTOR = 2
LEARNING_RATE = 1e-5

# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    # split entity mentions into groups of at most group_size
    # anchor=False: don't prepend the entity name as an anchor to each group; treat it as just another mention
    entity_sets = []
    if anchor:
        for id, mentions in entity_id_mentions.items():
            random.shuffle(mentions)
            positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
            anchor_positive = [([entity_id_name[id]] + p, id) for p in positives]
            entity_sets.extend(anchor_positive)
    else:
        for id, mentions in entity_id_mentions.items():
            group = list(set([entity_id_name[id]] + mentions))
            random.shuffle(group)
            positives = [(group[i:i + group_size], id) for i in range(0, len(group), group_size)]
            entity_sets.extend(positives)
    return entity_sets


def batchGenerator(data, batch_size):
    # flatten batch_size groups into parallel lists of mentions and entity ids
    for i in range(0, len(data), batch_size):
        batch = data[i:i + batch_size]
        x, y = [], []
        for t in batch:
            x.extend(t[0])
            y.extend([t[1]] * len(t[0]))
        yield x, y


with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}

# %%
###############
# alternate data import strategy
###################################################
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# rather than a pattern match, use the actual entity_id values as labels
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))

id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
    id2label[idx] = val
    label2id[val] = idx

df["training_id"] = df["entity_id"].map(label2id)
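# %%
# Illustration only (not part of the training flow): a toy run of the grouping
# helpers defined above. The entity ids, names and mentions here are made up.
toy_mentions = {101: ['postgres', 'postgresql', 'pg database'], 102: ['redis', 'redis cache']}
toy_names = {101: 'PostgreSQL', 102: 'Redis'}
toy_sets = generate_train_entity_sets(toy_mentions, toy_names, group_size=2, anchor=True)
# each element of toy_sets is ([entity_name, mention, ...], entity_id)
for toy_x, toy_y in batchGenerator(toy_sets, batch_size=2):
    # toy_x is a flat list of strings; toy_y repeats the entity id once per string
    print(toy_y, toy_x)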
# %%
##############################################################
# augmentation code

# basic preprocessing
def preprocess_text(text):
    # 1. lowercase everything
    text = text.lower()
    # 2. standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def generate_random_shuffles(text, n):
    words = text.split()  # split the input into words
    shuffled_variations = []
    for _ in range(n):
        shuffled = words[:]  # copy the word list to avoid in-place modification
        random.shuffle(shuffled)  # randomly shuffle the words
        shuffled_variations.append(" ".join(shuffled))  # join the words back into a string
    return shuffled_variations


def shuffle_text(text, n_shuffles=SHUFFLES):
    all_processed = []
    # keep the original text
    all_processed.append(text)
    # generate random shuffles
    shuffled_variations = generate_random_shuffles(text, n_shuffles)
    all_processed.extend(shuffled_variations)
    return all_processed


def corrupt_word(word):
    """Corrupt a single word using a randomly chosen corruption technique."""
    if len(word) <= 1:  # skip corruption for single-character words
        return word

    corruption_type = random.choice(["delete", "swap"])

    if corruption_type == "delete":
        # randomly delete a character
        idx = random.randint(0, len(word) - 1)
        word = word[:idx] + word[idx + 1:]
    elif corruption_type == "swap":
        # swap two adjacent characters
        if len(word) > 1:
            idx = random.randint(0, len(word) - 2)
            word = word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]

    return word


def corrupt_string(sentence, corruption_probability=0.01):
    """Corrupt each word in the string with the given probability."""
    words = sentence.split()
    corrupted_words = [
        corrupt_word(word) if random.random() < corruption_probability else word
        for word in words
    ]
    return " ".join(corrupted_words)


def create_example(index, mention, entity_name):
    return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}


# augment the whole dataset
def augment_data(df):
    output_list = []

    for idx, row in df.iterrows():
        index = row['entity_id']
        entity_name = row['entity_name']
        parent_desc = preprocess_text(row['mention'])

        # add the basic example
        output_list.append(create_example(index, parent_desc, entity_name))

        # add shuffled strings
        processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
        for desc in processed_descs:
            if desc != parent_desc:
                output_list.append(create_example(index, desc, entity_name))

        # add corrupted strings
        desc = corrupt_string(parent_desc, corruption_probability=0.01)
        if desc != parent_desc:
            output_list.append(create_example(index, desc, entity_name))

        # add example with stripped non-alphanumerics
        desc = re.sub(r'[^\w\s]', ' ', parent_desc)  # retain only alphanumerics and spaces
        if desc != parent_desc:
            output_list.append(create_example(index, desc, entity_name))

        # short sequence amplifier
        # short sequences are rare, so we compensate by adding extra copies;
        # they are also largely unaffected by word shuffling
        word_count = len(parent_desc.split())
        if word_count <= 2:
            for _ in range(AMPLIFY_FACTOR):
                output_list.append(create_example(index, parent_desc, entity_name))

    new_df = pd.DataFrame(output_list)
    return new_df

# %%
def make_entity_id_mentions(df):
    entity_id_mentions = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        entity_id_mentions[entity_id] = df[df['entity_id'] == entity_id]['mention'].to_list()
    return entity_id_mentions


def make_entity_id_name(df):
    entity_id_name = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        # each entity_id maps to a single entity_name, so the first value works
        entity_id_name[entity_id] = df[df['entity_id'] == entity_id]['entity_name'].to_list()[0]
    return entity_id_name
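# %%
# Illustration only (not part of the training flow): augment_data() expects a
# dataframe with 'entity_id', 'entity_name' and 'mention' columns, like the
# imported train.csv. The two rows below are made up for demonstration.
toy_df = pd.DataFrame([
    {'entity_id': 101, 'entity_name': 'PostgreSQL', 'mention': 'Postgres  9.6'},
    {'entity_id': 102, 'entity_name': 'Redis', 'mention': 'redis'},
])
# expected rows: each preprocessed mention, a punctuation-stripped variant where
# it differs, occasionally a randomly corrupted variant, and AMPLIFY_FACTOR
# extra copies of mentions with two or fewer words
print(augment_data(toy_df))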
# %%
num_sample_per_class = 10  # samples in each group
batch_size = 16  # number of groups; effective batch size for the triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small'  # alternatives: 'bert-base-cased', 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
bert_model = AutoModel.from_pretrained(MODEL_NAME)


class BertForClassificationAndTriplet(nn.Module):
    def __init__(self, bert_model, num_classes):
        super().__init__()
        self.bert = bert_model
        self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token
        logits = self.classifier(cls_embeddings)
        return cls_embeddings, logits


model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

model.to(DEVICE)
model.train()

losses = []

for epoch in tqdm(range(epochs)):
    total_loss = 0.0
    batch_number = 0

    # re-augment every epoch so the random augmentations differ between epochs
    augmented_df = augment_data(df)
    train_entity_id_mentions = make_entity_id_mentions(augmented_df)
    train_entity_id_name = make_entity_id_name(augmented_df)

    # group_size is one less than num_sample_per_class because the entity name anchor is prepended to each group
    data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class - 1, anchor=True)
    random.shuffle(data)

    for x, y in batchGenerator(data, batch_size):
        # print(len(x), len(y), end='-->')
        optimizer.zero_grad()

        inputs = tokenizer(x, padding=True, return_tensors='pt')
        inputs = inputs.to(DEVICE)
        cls, logits = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask']
        )

        labels = [label2id[element] for element in y]
        labels = torch.tensor(labels).to(DEVICE)
        y = torch.tensor(y).to(DEVICE)

        class_loss = F.cross_entropy(logits, labels)

        # first half of training: batch-all triplet loss (easier triplets)
        if epoch < epochs / 2:
            triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
        # second half of training: batch-hard triplet loss (hardest triplets)
        else:
            triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False)

        loss = class_loss + triplet_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().item()
        batch_number += 1

        del x, y, cls, logits, loss
        torch.cuda.empty_cache()

    # scheduler.step()  # update the learning rate
    print(f'epoch loss: {total_loss/batch_number}')
    # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")

    if epoch % 5 == 0:
        # torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
        torch.save(model.state_dict(), './checkpoint/classification.pt')

# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
torch.save(model.state_dict(), './checkpoint/classification.pt')

# %%
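# Sketch only: one way the trained encoder could be used for entity linking at
# inference time, via nearest-neighbour search over embedded training mentions.
# The test CSV path below is an assumption and is not produced by this script.
def embed(texts, batch_size=64):
    model.eval()
    embs = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            enc = tokenizer(texts[i:i + batch_size], padding=True, return_tensors='pt').to(DEVICE)
            cls_emb, _ = model(input_ids=enc['input_ids'], attention_mask=enc['attention_mask'])
            embs.append(cls_emb.cpu().numpy())
    return np.concatenate(embs, axis=0)

train_embs = embed(augmented_df['mention'].to_list())
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine')
knn.fit(train_embs, augmented_df['entity_id'].to_list())

# test_df = pd.read_csv('../esAppMod_data_import/test.csv', skipinitialspace=True)  # assumed path
# test_mentions = test_df['mention'].apply(preprocess_text).to_list()
# predictions = knn.predict(embed(test_mentions))
# print('accuracy:', (predictions == test_df['entity_id'].to_numpy()).mean())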