# domain_mapping/experimental/character_bert_train.py
# %%
import torch
import json
import random
import numpy as np
from transformers import BertTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import pandas as pd
import re
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from transformers import get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup
torch.set_float32_matmul_precision('high')
def set_seed(seed):
    """
    Set the random seed for reproducibility.
    """
    random.seed(seed)                 # Python random module
    np.random.seed(seed)              # NumPy random
    torch.manual_seed(seed)           # PyTorch CPU
    torch.cuda.manual_seed(seed)      # PyTorch GPU
    torch.cuda.manual_seed_all(seed)  # If using multiple GPUs
    torch.backends.cudnn.deterministic = True  # Ensure deterministic behavior
    torch.backends.cudnn.benchmark = False     # Disable optimization for reproducibility

set_seed(42)
# %%
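# Data-augmentation and optimisation hyperparameters.
# SHUFFLES and AMPLIFY_FACTOR are only used by the commented-out augmentation paths below;
# CORRUPT is the per-word corruption probability passed to augment_data().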
SHUFFLES=1
AMPLIFY_FACTOR=1
CORRUPT=0.00
LEARNING_RATE=1e-6
DEVICE = torch.device('cuda:2') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
# %%
EVAL_FILE="top1_curves/batch_output.txt"
with open(EVAL_FILE, "w") as f:
pass
EVAL_FILE_KNN="top1_curves/batch_knn.txt"
with open(EVAL_FILE_KNN, "w") as f:
pass
# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    # Split entity mentions into groups of positives.
    # anchor=False: don't anchor each group on the entity name; treat the name as a normal mention.
    entity_sets = []
    if anchor:
        for id, mentions in entity_id_mentions.items():
            random.shuffle(mentions)
            positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
            anchor_positive = [([entity_id_name[id]] + p, id) for p in positives]
            entity_sets.extend(anchor_positive)
    else:
        for id, mentions in entity_id_mentions.items():
            # the entity name is mixed into the group like any other mention
            group = list(set([entity_id_name[id]] + mentions))
            random.shuffle(group)
            positives = [(group[i:i + group_size], id) for i in range(0, len(group), group_size)]
            entity_sets.extend(positives)
    return entity_sets

def batchGenerator(data, batch_size):
    for i in range(0, len(data), batch_size):
        batch = data[i:i + batch_size]
        x, y = [], []
        for t in batch:
            x.extend(t[0])
            y.extend([t[1]] * len(t[0]))
        yield x, y
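# Load the canonical entity names and the training mentions from the esAppMod JSON files.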
with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
# %%
###############
# alternate data import strategy
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use pattern, we use the real thing and property
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))
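# Map raw entity_id values to contiguous class indices for the classification head.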
id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
    id2label[idx] = val
    label2id[val] = idx
df["training_id"] = df["entity_id"].map(label2id)
# %%
##############################################################
# augmentation code
# basic preprocessing
def preprocess_text(text):
    # 1. lowercase everything
    text = text.lower()
    # 2. standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()
    return text
def generate_random_shuffles(text, n):
    words = text.split()  # Split the input into words
    shuffled_variations = []
    for _ in range(n):
        shuffled = words[:]  # Copy the word list to avoid in-place modification
        random.shuffle(shuffled)  # Randomly shuffle the words
        shuffled_variations.append(" ".join(shuffled))  # Join the words back into a string
    return shuffled_variations

def shuffle_text(text, n_shuffles=SHUFFLES):
    all_processed = []
    # add the original text
    all_processed.append(text)
    # generate random shuffles
    shuffled_variations = generate_random_shuffles(text, n_shuffles)
    all_processed.extend(shuffled_variations)
    return all_processed
def corrupt_word(word):
    """Corrupt a single word using random corruption techniques."""
    if len(word) <= 1:  # Skip corruption for single-character words
        return word
    corruption_type = random.choice(["delete", "swap"])
    if corruption_type == "delete":
        # Randomly delete a character
        idx = random.randint(0, len(word) - 1)
        word = word[:idx] + word[idx + 1:]
    elif corruption_type == "swap":
        # Swap two adjacent characters
        if len(word) > 1:
            idx = random.randint(0, len(word) - 2)
            word = word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:]
    return word

def corrupt_string(sentence, corruption_probability=0.01):
    """Corrupt each word in the string with a given probability."""
    words = sentence.split()
    corrupted_words = [
        corrupt_word(word) if random.random() < corruption_probability else word
        for word in words
    ]
    return " ".join(corrupted_words)
def create_example(index, mention, entity_name):
    return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}

# augment whole dataset
def augment_data(df):
    output_list = []
    for idx, row in df.iterrows():
        index = row['entity_id']
        entity_name = row['entity_name']
        parent_desc = row['mention']
        parent_desc = preprocess_text(parent_desc)

        # add basic example
        output_list.append(create_example(index, parent_desc, entity_name))

        # # add shuffled strings
        # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
        # for desc in processed_descs:
        #     if desc != parent_desc:
        #         output_list.append(create_example(index, desc, entity_name))

        # add corrupted strings
        desc = corrupt_string(parent_desc, corruption_probability=CORRUPT)
        if desc != parent_desc:
            output_list.append(create_example(index, desc, entity_name))

        # add example with punctuation replaced by spaces
        desc = re.sub(r'[^\w\s]', ' ', parent_desc)  # keep only word characters and spaces
        if desc != parent_desc:
            output_list.append(create_example(index, desc, entity_name))

        # # short sequence amplifier
        # # short sequences are rare, and we must compensate by including more examples
        # # also, short sequences don't usually get affected by shuffle
        # words = parent_desc.split()
        # word_count = len(words)
        # if word_count <= 2:
        #     for _ in range(AMPLIFY_FACTOR):
        #         output_list.append(create_example(index, desc, entity_name))
    new_df = pd.DataFrame(output_list)
    return new_df

# def sample_from_df(df, sample_size_per_class=5):
#     sampled_df = (df.groupby("entity_id")[['entity_id', 'mention', 'entity_name']]  # explicitly give column names
#                     .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x))))
#                     .reset_index(drop=True))
#     return sampled_df
# %%
def make_entity_id_mentions(df):
    entity_id_mentions = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        entity_id_mentions[entity_id] = df[df['entity_id'] == entity_id]['mention'].to_list()
    return entity_id_mentions

def make_entity_id_name(df):
    entity_id_name = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        # entity_id always matches entity_name, so the first value works
        entity_id_name[entity_id] = df[df['entity_id'] == entity_id]['entity_name'].to_list()[0]
    return entity_id_name
# %%
# evaluation
def run_evaluation_logit(model, tokenizer):
    def preprocess_text(text):
        # 1. lowercase everything
        text = text.lower()
        # 2. standardize spacing
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    with open('../esAppMod/tca_entities.json', 'r') as file:
        eval_entities = json.load(file)
    eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()}

    with open('../esAppMod/train.json', 'r') as file:
        eval_train = json.load(file)
    eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()}
    eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()}

    with open('../esAppMod/infer.json', 'r') as file:
        eval_test = json.load(file)
    x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()]
    y_test = [d['entity_id'] for _, d in eval_test['data'].items()]

    eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys())
    eval_train_entities = [preprocess_text(element) for element in eval_train_entities]

    def batch_list(data, batch_size):
        """Yield successive n-sized chunks from data."""
        for i in range(0, len(data), batch_size):
            yield data[i:i + batch_size]

    # inference in batches over the test mentions
    batches = batch_list(x_test, 64)
    pred_labels = []
    for batch in batches:
        inputs, attn_mask = tokenizer.encode(batch)
        inputs = inputs.to(DEVICE)
        attn_mask = attn_mask.to(DEVICE)
        with torch.no_grad():
            _, logits = model(inputs, attn_mask)
            predicted_class_ids = logits.argmax(dim=1).to("cpu")
        pred_labels.extend(predicted_class_ids)
    pred_labels = [tensor.item() for tensor in pred_labels]

    labels = [label2id[element] for element in y_test]
    with open(EVAL_FILE, "a") as f:
        # only compute top-1
        accuracy = accuracy_score(labels, pred_labels)
        print(f'{accuracy}', file=f)
def run_evaluation_knn(model, tokenizer):
    def preprocess_text(text):
        # 1. lowercase everything
        text = text.lower()
        # 2. standardize spacing
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    with open('../esAppMod/tca_entities.json', 'r') as file:
        eval_entities = json.load(file)
    eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()}

    with open('../esAppMod/train.json', 'r') as file:
        eval_train = json.load(file)
    eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()}
    eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()}

    with open('../esAppMod/infer.json', 'r') as file:
        eval_test = json.load(file)
    x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()]
    y_test = [d['entity_id'] for _, d in eval_test['data'].items()]

    eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys())
    eval_train_entities = [preprocess_text(element) for element in eval_train_entities]

    def batch_list(data, batch_size):
        """Yield successive n-sized chunks from data."""
        for i in range(0, len(data), batch_size):
            yield data[i:i + batch_size]

    # embed the canonical entity names (first-token embedding as sequence representation)
    batches = batch_list(eval_train_entities, 64)
    embedding_list = []
    for batch in batches:
        inputs, attn_mask = tokenizer.encode(batch)
        inputs = inputs.to(DEVICE)
        attn_mask = attn_mask.to(DEVICE)
        outputs = model(inputs, attn_mask)
        output_slice = outputs[:, 0, :]
        output_slice = output_slice.detach().cpu().numpy()
        embedding_list.append(output_slice)
    cls = np.concatenate(embedding_list)

    # embed the test mentions
    batches = batch_list(x_test, 64)
    embedding_list = []
    for batch in batches:
        inputs, attn_mask = tokenizer.encode(batch)
        inputs = inputs.to(DEVICE)
        attn_mask = attn_mask.to(DEVICE)
        outputs = model(inputs, attn_mask)
        output_slice = outputs[:, 0, :]
        output_slice = output_slice.detach().cpu().numpy()
        embedding_list.append(output_slice)
    cls_test = np.concatenate(embedding_list)

    # 1-NN retrieval with cosine distance against the entity-name embeddings
    knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, eval_labels)
    with open(EVAL_FILE_KNN, "a") as f:
        # only compute top-1
        distances, indices = knn.kneighbors(cls_test, n_neighbors=1)
        num = 0
        for a, b in zip(y_test, indices):
            b = [eval_labels[i] for i in b]
            if a in b:
                num += 1
        print(f'{num / len(y_test)}', file=f)
# %%
class CharacterTransformer(nn.Module):
    def __init__(self, num_chars, d_model=256, nhead=4, num_encoder_layers=4):
        super(CharacterTransformer, self).__init__()
        self.char_embedding = nn.Embedding(num_chars, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)

    def forward(self, input, attention_mask):
        # input: (batch_size, seq_len)
        embeddings = self.char_embedding(input)  # (batch_size, seq_len, d_model)
        # embeddings = embeddings.permute(1, 0, 2)  # not needed with batch_first=True
        # src_key_padding_mask expects True at padding positions, which the tokenizer provides
        output = self.transformer_encoder(embeddings, src_key_padding_mask=attention_mask)
        # output = output.permute(1, 0, 2)
        return output  # (batch_size, seq_len, d_model)
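# The ASCII tokenizer below prepends its pad token ('\0') to every string, so position 0
# acts as a pseudo-CLS slot whose embedding is used as the sequence representation.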
class ASCIITokenizer:
    def __init__(self, pad_token='\0'):
        # Initialize the tokenizer with ASCII characters.
        # ASCII characters range from 0 to 127.
        self.char_to_id = {chr(i): i for i in range(128)}
        self.id_to_char = {i: chr(i) for i in range(128)}
        self.pad_token = pad_token

    def encode(self, text_list):
        """Encode a batch of strings into ASCII IDs and build key-padding masks."""
        output_list = []
        max_length = 0
        # First pass: encode the texts and find the maximum length
        for text in text_list:
            text = self.pad_token + text  # Prepend pad_token as a pseudo-CLS slot
            # Map each character to its ASCII ID; non-ASCII characters fall back to the pad ID
            output = [self.char_to_id.get(char, self.char_to_id[self.pad_token]) for char in text]
            output_list.append(output)
            if len(output) > max_length:
                max_length = len(output)
        # Second pass: pad the sequences to the maximum length and create masks
        padded_list = []
        attention_masks = []
        for output in output_list:
            # False (0) for real tokens, including the leading pseudo-CLS, which must never be masked;
            # True (1) for padding -- this matches the semantics of src_key_padding_mask
            attention_mask = [0] * len(output) + [1] * (max_length - len(output))
            output = self.pad(output, max_length)
            padded_list.append(output)
            attention_masks.append(attention_mask)
        return torch.tensor(padded_list, dtype=torch.long), torch.tensor(attention_masks, dtype=torch.bool)

    def decode(self, ids_list):
        """Decode lists of ASCII IDs back into text strings."""
        output_list = []
        for ids in ids_list:
            output = ''.join(self.id_to_char.get(id, '') for id in ids if id in self.id_to_char)
            output_list.append(output)
        return output_list

    def pad(self, output, max_length):
        """Pad the ID list with the pad token's ID up to max_length."""
        return output + [self.char_to_id.get(self.pad_token)] * (max_length - len(output))
# %%
tokenizer = ASCIITokenizer()
# # Example text
# text = ["Hello, world! This is cool", "Hello, world!"]
# # Encode the text
# encoded = tokenizer.encode(text)
# print("Encoded:", encoded)
#
# # Decode the encoded IDs
# decoded = tokenizer.decode(encoded.numpy())
# print("Decoded:", decoded)
# %%
# Example usage
bert_model = CharacterTransformer(num_chars=128) # Assuming ASCII characters
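# A minimal shape-check sketch (assumes the ASCIITokenizer instantiated above):
#   ids, mask = tokenizer.encode(["hello world"])
#   out = bert_model(ids, mask)   # -> (1, seq_len, 256); position 0 is the pseudo-CLS slot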
class BertForClassificationAndTriplet(nn.Module):
    def __init__(self, bert_model, num_classes):
        super().__init__()
        self.bert = bert_model
        self.classifier = nn.Linear(bert_model.char_embedding.embedding_dim, num_classes)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids, attention_mask)
        cls_embeddings = outputs[:, 0, :]  # first-token (pseudo-CLS) embedding
        logits = self.classifier(cls_embeddings)
        return cls_embeddings, logits
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
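# num_classes = number of distinct entity_ids in train.csv (see label2id built above).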
# %%
num_sample_per_class = 10 # samples in each group
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
# model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
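# Plain AdamW with a fixed learning rate; the scheduler variants below are left commented out.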
# num_warmup_steps=100
# total_steps = epochs * (1126/64)
# scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps, total_steps, lr_end=5e-6)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.9, cooldown=5, verbose=True)
# %%
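# Load the pretrained character transformer. The '_orig_mod.' prefix is added by torch.compile,
# so it must be stripped for the keys to match this (uncompiled) model's state_dict.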
state_dict = torch.load('./checkpoint/pretrained_character_bert.pt')
state_dict = {key.replace('_orig_mod.', ''): value for key, value in state_dict.items()}
model.load_state_dict(state_dict)
model.to(DEVICE)
model.train()
losses = []
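# Training schedule: re-augment the data every 10 epochs, train with batch-all triplet loss for
# the first half of the epochs and batch-hard triplet loss for the second half, and evaluate
# (logit top-1 and 1-NN top-1) after every epoch.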
for epoch in tqdm(range(epochs)):
    total_loss = 0.0
    batch_number = 0

    # refresh the augmented training data every 10 epochs
    if epoch % 10 == 0:
        augmented_df = augment_data(df)
        # sampled_df = sample_from_df(augmented_df, sample_size_per_class=num_sample_per_class)
        train_entity_id_mentions = make_entity_id_mentions(augmented_df)
        train_entity_id_name = make_entity_id_name(augmented_df)

    # group_size is num_sample_per_class - 1 because the anchor (entity name) is added to each group
    data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class - 1, anchor=True)
    random.shuffle(data)

    for x, y in batchGenerator(data, batch_size):
        # print(len(x), len(y), end='-->')
        optimizer.zero_grad()
        inputs, attn_mask = tokenizer.encode(x)
        inputs = inputs.to(DEVICE)
        attn_mask = attn_mask.to(DEVICE)
        cls, logits = model(inputs, attn_mask)
        # labels = y
        # labels = [label2id[element] for element in labels]
        # labels = torch.tensor(labels).to(DEVICE)
        # loss = F.cross_entropy(logits, labels)
        y = torch.tensor(y).to(DEVICE)
        if epoch < epochs / 2:
            # first half of training: batch-all triplet loss (easier)
            loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
        else:
            # second half of training: batch-hard triplet loss
            loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
        loss.backward()
        # scheduler.step()
        optimizer.step()
        total_loss += loss.detach().item()
        batch_number += 1
        # del x, y, outputs, cls, loss
        # torch.cuda.empty_cache()

    epoch_loss = total_loss / batch_number
    # scheduler.step()  # Update the learning rate
    print(f'epoch loss: {epoch_loss}')

    if epoch % 1 == 0:  # run evaluation on test data every epoch
        model.eval()
        with torch.no_grad():
            run_evaluation_logit(model=model, tokenizer=tokenizer)
            run_evaluation_knn(model=model.bert, tokenizer=tokenizer)
        model.train()
    # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")

    if (epoch % 100 == 0) and (epoch > 100):
        torch.save(model.state_dict(), './checkpoint/character_bert.pt')

torch.save(model.state_dict(), './checkpoint/character_bert_final.pt')
# %%