# domain_mapping/reference_code/character_bert_train.py
# (exported from a repository web viewer: 461 lines, 16 KiB, Python)

# %%
import json
import os
import random
import re

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.neighbors import KNeighborsClassifier
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoModel
from transformers import BertTokenizer

from loss import batch_all_triplet_loss, batch_hard_triplet_loss
torch.set_float32_matmul_precision('high')


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU and
    CUDA generators (current device and all devices). Also puts cuDNN into
    deterministic, non-benchmarking mode so repeated runs match.
    """
    seeders = (
        random.seed,                 # Python stdlib RNG
        np.random.seed,              # NumPy global RNG
        torch.manual_seed,           # PyTorch CPU generator
        torch.cuda.manual_seed,      # PyTorch generator of the current CUDA device
        torch.cuda.manual_seed_all,  # PyTorch generators of every visible CUDA device
    )
    for seeder in seeders:
        seeder(seed)
    # Give up cuDNN autotuning speed in exchange for run-to-run determinism.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(42)
# %%
SHUFFLES = 1  # default number of word-order shuffles produced by shuffle_text
AMPLIFY_FACTOR = 1  # extra copies per short mention (only referenced by the disabled amplifier in augment_data)
CORRUPT = 0.1  # per-word probability passed to corrupt_string by augment_data
LEARNING_RATE = 1e-5
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
# Only consulted by the (commented-out) Hugging Face model/tokenizer loading in the training cell.
MODEL_NAME = 'helboukkouri/character-bert'
# %%
# Truncate this run's top-1 accuracy log so each run starts a fresh curve.
# Create the directory first: on a clean checkout open() would otherwise raise
# FileNotFoundError.
os.makedirs("top1_curves", exist_ok=True)
with open("top1_curves/character_output.txt", "w", encoding="utf-8"):
    pass
# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    """Split each entity's mentions into shuffled groups for triplet-loss training.

    Args:
        entity_id_mentions: dict mapping entity id -> list of mention strings.
            The lists are not modified; shuffling happens on a copy.
        entity_id_name: dict mapping entity id -> canonical entity name.
        group_size: number of mentions per group (the last group may be shorter).
        anchor: if True, prepend the entity name to every group, so each group
            has up to ``group_size + 1`` items. If False, add the entity name once
            to the (de-duplicated) mention pool and treat it like any other mention.

    Returns:
        List of ``(group, entity_id)`` tuples, where ``group`` is a list of strings.

    Raises:
        ValueError: if ``group_size`` is smaller than 1.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be >= 1, got {group_size}")
    entity_sets = []
    for entity_id, mentions in entity_id_mentions.items():
        if anchor:
            pool = list(mentions)
        else:
            # dict.fromkeys de-duplicates while keeping a deterministic order. set()
            # would not: its iteration order for strings changes with PYTHONHASHSEED,
            # so the shuffle below would differ between runs even with a fixed seed.
            pool = list(dict.fromkeys([entity_id_name[entity_id], *mentions]))
        random.shuffle(pool)
        for start in range(0, len(pool), group_size):
            chunk = pool[start:start + group_size]
            group = [entity_id_name[entity_id]] + chunk if anchor else chunk
            entity_sets.append((group, entity_id))
    return entity_sets
def batchGenerator(data, batch_size):
    """Yield flattened ``(mentions, labels)`` batches from ``(group, label)`` pairs.

    Each yielded batch covers up to ``batch_size`` consecutive items of ``data``.
    Every mention in an item's group is paired with that item's label.
    """
    for start in range(0, len(data), batch_size):
        mentions, labels = [], []
        for item in data[start:start + batch_size]:
            group, label = item[0], item[1]
            mentions.extend(group)
            labels.extend([label] * len(group))
        yield mentions, labels
# Load the entity catalogue and the training split. JSON is UTF-8 by specification,
# so decode it explicitly rather than with the platform's locale encoding.
with open('../esAppMod/tca_entities.json', 'r', encoding='utf-8') as file:
    entities = json.load(file)
# entity id -> canonical entity name, for every entity in the catalogue
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for entity in entities['data'].values()}
with open('../esAppMod/train.json', 'r', encoding='utf-8') as file:
    train = json.load(file)
# Initial lookups for the training entities; the training loop rebuilds both from the
# augmented CSV data each time it regenerates its training groups.
train_entity_id_mentions = {data['entity_id']: data['mentions'] for data in train['data'].values()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for data in train['data'].values()}
# %%
###############
# alternate data import strategy
###################################################
# Load the training CSV and give each distinct entity_id a dense integer label
# (0..K-1 in sorted entity_id order), stored in the new "training_id" column.
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(set(entity_ids))
id2label = dict(enumerate(target_id_list))
label2id = {label: idx for idx, label in id2label.items()}
df["training_id"] = df["entity_id"].map(label2id)
# %%
##############################################################
# augmentation code
# basic preprocessing
def preprocess_text(text):
    """Lower-case ``text``, collapse each whitespace run to one space, and trim the ends."""
    collapsed = re.sub(r'\s+', ' ', text.lower())
    return collapsed.strip()
def generate_random_shuffles(text, n):
    """Return ``n`` strings, each a random reordering of the words in ``text``.

    Words are split on whitespace and re-joined with single spaces. The shuffles
    use the global ``random`` generator.
    """
    tokens = text.split()

    def _one_permutation():
        # Shuffle a fresh copy so every variation starts from the original order.
        perm = list(tokens)
        random.shuffle(perm)
        return " ".join(perm)

    return [_one_permutation() for _ in range(n)]
def shuffle_text(text, n_shuffles=SHUFFLES):
    """Return ``text`` followed by ``n_shuffles`` random word-order permutations of it."""
    return [text, *generate_random_shuffles(text, n_shuffles)]
def corrupt_word(word):
    """Return ``word`` with one random typo: a deleted character or two swapped neighbours.

    Words of length 0 or 1 are returned unchanged. Uses the global ``random`` generator.
    """
    if len(word) <= 1:
        return word
    mode = random.choice(["delete", "swap"])
    if mode == "delete":
        cut = random.randint(0, len(word) - 1)
        return word[:cut] + word[cut + 1:]
    # mode == "swap": exchange the characters at positions pos and pos + 1
    pos = random.randint(0, len(word) - 2)
    chars = list(word)
    chars[pos], chars[pos + 1] = chars[pos + 1], chars[pos]
    return "".join(chars)
def corrupt_string(sentence, corruption_probability=0.01):
    """Apply ``corrupt_word`` to each word of ``sentence`` with the given probability.

    The sentence is split on whitespace. Each word draws one ``random.random()``
    value and is corrupted when that value is below ``corruption_probability``.
    The result is re-joined with single spaces.
    """
    pieces = []
    for token in sentence.split():
        if random.random() < corruption_probability:
            pieces.append(corrupt_word(token))
        else:
            pieces.append(token)
    return " ".join(pieces)
def create_example(index, mention, entity_name):
    """Build one augmented-data record with keys ``entity_id``, ``mention``, ``entity_name``."""
    return dict(entity_id=index, mention=mention, entity_name=entity_name)
# augment whole dataset
def augment_data(df):
    """Build an augmented mention table from ``df``.

    For every row of ``df`` (columns ``entity_id``, ``entity_name``, ``mention``) this emits:

    * the mention after ``preprocess_text`` (lower-cased, whitespace-normalised);
    * a copy passed through ``corrupt_string`` with per-word probability ``CORRUPT``,
      if that changed anything;
    * a copy where every character that is neither a word character nor whitespace is
      replaced by a space, then re-normalised with ``preprocess_text``, if that
      changed anything.

    Returns:
        A new DataFrame with columns ``entity_id``, ``mention``, ``entity_name``. The
        columns are present even when ``df`` is empty, so downstream column lookups
        do not raise KeyError.
    """
    output_list = []
    # Iterate the three columns directly; iterrows() builds a Series per row.
    for entity_id, entity_name, mention in zip(df['entity_id'], df['entity_name'], df['mention']):
        parent_desc = preprocess_text(mention)
        # add basic example
        output_list.append(create_example(entity_id, parent_desc, entity_name))
        # # add shuffled strings (disabled)
        # for desc in shuffle_text(parent_desc, n_shuffles=SHUFFLES):
        #     if desc != parent_desc:
        #         output_list.append(create_example(entity_id, desc, entity_name))
        # add corrupted strings
        desc = corrupt_string(parent_desc, corruption_probability=CORRUPT)
        if desc != parent_desc:
            output_list.append(create_example(entity_id, desc, entity_name))
        # Replacing punctuation with spaces leaves runs of spaces ("a - b" -> "a   b").
        # Re-normalise so this variant is formatted like the preprocess_text'd
        # mentions used at evaluation time.
        desc = preprocess_text(re.sub(r'[^\w\s]', ' ', parent_desc))
        if desc != parent_desc:
            output_list.append(create_example(entity_id, desc, entity_name))
        # # short-sequence amplifier (disabled): short mentions are rare and barely change
        # # under shuffling, so add extra copies of them
        # if len(parent_desc.split()) <= 2:
        #     for _ in range(AMPLIFY_FACTOR):
        #         output_list.append(create_example(entity_id, parent_desc, entity_name))
    return pd.DataFrame(output_list, columns=['entity_id', 'mention', 'entity_name'])
# def sample_from_df(df, sample_size_per_class=5):
# sampled_df = (df.groupby("entity_id")[['entity_id', 'mention', 'entity_name']] # explicit give column names
# .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x))))
# .reset_index(drop=True))
#
# return sampled_df
# %%
def make_entity_id_mentions(df):
    """Group the ``mention`` column of ``df`` by ``entity_id``.

    Returns:
        dict mapping each entity id to the list of its mentions, in row order.
    """
    # One pass over the rows instead of a full-frame boolean filter per entity,
    # which cost O(rows * entities).
    entity_id_mentions = {}
    for entity_id, mention in zip(df['entity_id'], df['mention']):
        entity_id_mentions.setdefault(entity_id, []).append(mention)
    return entity_id_mentions
def make_entity_id_name(df):
    """Map each ``entity_id`` in ``df`` to its ``entity_name``.

    When an id appears on several rows, the name from its first row is kept.
    Each entity_id is expected to always carry the same entity_name.
    """
    # Single pass; setdefault keeps the first-seen name, as the original first-row lookup did.
    entity_id_name = {}
    for entity_id, entity_name in zip(df['entity_id'], df['entity_name']):
        entity_id_name.setdefault(entity_id, entity_name)
    return entity_id_name
# %%
# evaluation
def run_evaluation(model, tokenizer, output_path="top1_curves/character_output.txt"):
    """Measure top-1 nearest-neighbour accuracy on the inference split and append it to a file.

    The canonical name of every entity in ``train.json`` (looked up in
    ``tca_entities.json``) and every mention in ``infer.json`` are passed through
    ``preprocess_text``. Each is then embedded as the first-position vector of
    ``model``'s ``last_hidden_state``. Each test mention is assigned the entity
    whose name embedding is closest in Euclidean distance (1-NN).

    This function does not disable gradients or switch modes itself. Call it with
    ``model.eval()`` under ``torch.no_grad()``.

    Args:
        model: callable as ``model(input_ids=..., attention_mask=...)`` and returning
            an object with a ``last_hidden_state`` tensor.
        tokenizer: callable as ``tokenizer(texts, padding=True, return_tensors='pt')``
            and returning a mapping with ``'input_ids'`` and ``'attention_mask'`` tensors.
        output_path: file the accuracy is appended to, one value per line. It defaults
            to the log this script truncates at start-up.

    Returns:
        The top-1 accuracy as a float.

    Raises:
        ValueError: if ``infer.json`` contains no mentions.
    """
    with open('../esAppMod/tca_entities.json', 'r', encoding='utf-8') as file:
        eval_entities = json.load(file)
    eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for entity in eval_entities['data'].values()}
    with open('../esAppMod/train.json', 'r', encoding='utf-8') as file:
        eval_train = json.load(file)
    # Use the catalogue loaded just above, not a module-level global.
    eval_train_entity_id_name = {data['entity_id']: eval_all_entity_id_name[data['entity_id']] for data in eval_train['data'].values()}
    with open('../esAppMod/infer.json', 'r', encoding='utf-8') as file:
        eval_test = json.load(file)
    x_test = [preprocess_text(d['mention']) for d in eval_test['data'].values()]
    y_test = [d['entity_id'] for d in eval_test['data'].values()]
    if not y_test:
        raise ValueError("infer.json contains no mentions to evaluate")

    eval_labels = list(eval_train_entity_id_name.keys())
    eval_train_entities = [preprocess_text(name) for name in eval_train_entity_id_name.values()]

    def _embed(texts, batch_size=64):
        # Embed texts in chunks; returns a (len(texts), hidden) NumPy array.
        chunks = []
        for start in range(0, len(texts), batch_size):
            inputs = tokenizer(texts[start:start + batch_size], padding=True, return_tensors='pt')
            outputs = model(
                input_ids=inputs['input_ids'].to(DEVICE),
                attention_mask=inputs['attention_mask'].to(DEVICE),
            )
            chunks.append(outputs.last_hidden_state[:, 0, :].detach().cpu().numpy())
        return np.concatenate(chunks)

    cls = _embed(eval_train_entities)
    cls_test = _embed(x_test)
    knn = KNeighborsClassifier(n_neighbors=1, metric='euclidean').fit(cls, eval_labels)
    # only compute top-1
    _, indices = knn.kneighbors(cls_test, n_neighbors=1)
    num = sum(1 for true_id, neighbours in zip(y_test, indices) if true_id == eval_labels[neighbours[0]])
    accuracy = num / len(y_test)
    with open(output_path, "a", encoding="utf-8") as f:
        print(f'{accuracy}', file=f)
    return accuracy
# %%
class CharacterTransformerOutput:
    """Minimal stand-in for a Hugging Face model output.

    Only exposes ``last_hidden_state``. This lets code written for
    ``transformers.AutoModel`` (``outputs.last_hidden_state[:, 0, :]``) also run
    with ``CharacterTransformer``.
    """

    def __init__(self, last_hidden_state):
        self.last_hidden_state = last_hidden_state


class CharacterTransformer(nn.Module):
    """Character-level Transformer encoder over integer character ids.

    Args:
        num_chars: vocabulary size of the character embedding table.
        d_model: embedding / hidden size.
        nhead: number of attention heads (PyTorch requires it to divide ``d_model``).
        num_encoder_layers: number of stacked ``nn.TransformerEncoderLayer`` modules.
    """

    def __init__(self, num_chars, d_model=512, nhead=8, num_encoder_layers=6):
        super().__init__()
        self.char_embedding = nn.Embedding(num_chars, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)

    @staticmethod
    def _positional_encoding(seq_len, d_model, device, dtype):
        """Return the (seq_len, d_model) sinusoidal position table (no learned parameters)."""
        position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        div_term = 10000.0 ** (-torch.arange(0, d_model, 2, device=device, dtype=torch.float32) / d_model)
        table = torch.zeros(seq_len, d_model, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        table[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return table.to(dtype)

    def forward(self, input=None, attention_mask=None, input_ids=None):
        """Encode a batch of character-id sequences.

        Two calling conventions are supported:

        * ``model(input)`` or ``model(input, attention_mask)`` returns the encoder
          output tensor of shape (batch_size, seq_len, d_model).
        * ``model(input_ids=..., attention_mask=...)`` (Hugging Face style) returns a
          ``CharacterTransformerOutput`` whose ``last_hidden_state`` is that tensor.

        Args:
            input / input_ids: LongTensor of shape (batch_size, seq_len). Pass exactly one.
            attention_mask: optional tensor of the same shape; 1 marks real characters
                and 0 marks padding. Padded positions are excluded as attention keys.
        """
        if (input is None) == (input_ids is None):
            raise ValueError("pass exactly one of `input` or `input_ids`")
        hf_style = input_ids is not None
        ids = input_ids if hf_style else input
        embeddings = self.char_embedding(ids)  # (batch_size, seq_len, d_model)
        # Without position information self-attention is permutation-equivariant: the
        # output at position 0 would depend only on the first character and the
        # multiset of the others, so "abc" and "acb" would embed identically.
        embeddings = embeddings + self._positional_encoding(
            ids.size(1), embeddings.size(-1), embeddings.device, embeddings.dtype
        )
        padding_mask = None if attention_mask is None else attention_mask == 0
        output = self.transformer_encoder(embeddings, src_key_padding_mask=padding_mask)
        return CharacterTransformerOutput(output) if hf_style else output
class ASCIIBatchEncoding(dict):
    """``dict`` of tensors with an in-place ``to(device)``, mirroring Hugging Face ``BatchEncoding``."""

    def to(self, device):
        """Move every tensor to ``device`` in place and return ``self``."""
        for key in list(self):
            self[key] = self[key].to(device)
        return self


class ASCIITokenizer:
    """Character tokenizer whose ids are ASCII code points (0-127).

    ``pad_token_id`` (0, the NUL character) pads batches built by ``__call__``.
    """

    pad_token_id = 0

    def __init__(self):
        # ASCII characters range from 0 to 127; a character's id is its code point.
        self.char_to_id = {chr(i): i for i in range(128)}
        self.id_to_char = {i: chr(i) for i in range(128)}

    def encode(self, text_list):
        """Encode each string in ``text_list`` into a list of ASCII ids.

        Characters outside ASCII are silently dropped.
        """
        return [[self.char_to_id[char] for char in text if char in self.char_to_id] for text in text_list]

    def decode(self, ids_list):
        """Decode each list of ids back into a string, skipping ids outside 0-127."""
        return [
            ''.join(self.id_to_char[token_id] for token_id in ids if token_id in self.id_to_char)
            for ids in ids_list
        ]

    def __call__(self, text_list, padding=True, return_tensors='pt'):
        """Tokenize a batch using the Hugging Face tokenizer call convention.

        Args:
            text_list: a string or a list of strings.
            padding: if True, right-pad every sequence with ``pad_token_id`` to the
                longest one. If False, all sequences must already have equal length.
            return_tensors: only ``'pt'`` (PyTorch) is supported.

        Returns:
            ``ASCIIBatchEncoding`` with LongTensors ``input_ids`` and ``attention_mask``
            of shape (batch_size, max_len). The mask is 1 for real characters and 0
            for padding.

        Raises:
            ValueError: for an unsupported ``return_tensors``, or for unequal lengths
                with ``padding=False``.
        """
        if return_tensors != 'pt':
            raise ValueError(f"only return_tensors='pt' is supported, got {return_tensors!r}")
        if isinstance(text_list, str):
            text_list = [text_list]
        # An empty (or entirely non-ASCII) string becomes one NUL token, so every row
        # keeps at least one attendable position and `[:, 0, :]` stays meaningful.
        encoded = [ids or [self.pad_token_id] for ids in self.encode(text_list)]
        width = max((len(ids) for ids in encoded), default=0)
        if not padding and any(len(ids) != width for ids in encoded):
            raise ValueError("sequences have different lengths; pass padding=True")
        input_ids = [ids + [self.pad_token_id] * (width - len(ids)) for ids in encoded]
        attention_mask = [[1] * len(ids) + [0] * (width - len(ids)) for ids in encoded]
        return ASCIIBatchEncoding(
            input_ids=torch.tensor(input_ids, dtype=torch.long).view(len(encoded), width),
            attention_mask=torch.tensor(attention_mask, dtype=torch.long).view(len(encoded), width),
        )
# %%
tokenizer = ASCIITokenizer()
# Round-trip a small example to sanity-check the tokenizer.
example_text = ["Hello, world!", "Hello, world!"]
encoded = tokenizer.encode(example_text)
print("Encoded:", encoded)
decoded = tokenizer.decode(encoded)
print("Decoded:", decoded)
# %%
# The model trained below; its vocabulary size matches the tokenizer's ASCII table.
model = CharacterTransformer(num_chars=len(tokenizer.char_to_id))
# Smoke-test a forward pass on 10 random sequences of 50 character ids.
# The example_* names avoid shadowing the built-in `input`. no_grad() is used
# because the result is discarded and no autograd graph is needed.
example_input = torch.randint(0, 128, (10, 50))
with torch.no_grad():
    example_output = model(example_input)
# %%
num_sample_per_class = 10  # group size: the entity name (anchor) plus num_sample_per_class - 1 mentions
batch_size = 64  # groups per batch; the triplet loss sees up to batch_size * num_sample_per_class embeddings
margin = 2
epochs = 200
augment_every = 1  # regenerate the augmented training groups (fresh random corruptions) every N epochs
# Hugging Face alternative (also adjust MODEL_NAME):
# model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model.to(DEVICE)
# Build the optimizer after moving the model so its parameters (and optimizer state)
# live on DEVICE.
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.9, cooldown=5)
model.train()
losses = []  # mean training loss of each epoch
for epoch in tqdm(range(epochs)):
    total_loss = 0.0
    batch_number = 0
    if epoch % augment_every == 0:
        augmented_df = augment_data(df)
        train_entity_id_mentions = make_entity_id_mentions(augmented_df)
        train_entity_id_name = make_entity_id_name(augmented_df)
        data = generate_train_entity_sets(
            train_entity_id_mentions, train_entity_id_name, num_sample_per_class - 1, anchor=True
        )
    random.shuffle(data)
    for x, y in batchGenerator(data, batch_size):
        optimizer.zero_grad()
        inputs = tokenizer(x, padding=True, return_tensors='pt').to(DEVICE)
        outputs = model(**inputs)
        cls = outputs.last_hidden_state[:, 0, :]
        labels = torch.tensor(y, device=DEVICE)
        # First half of training: batch_all_triplet_loss; second half: batch_hard_triplet_loss.
        if epoch < epochs / 2:
            loss, _ = batch_all_triplet_loss(labels, cls, margin, squared=False)
        else:
            loss = batch_hard_triplet_loss(labels, cls, margin, squared=False)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        batch_number += 1
    # Guard against an epoch with no batches (e.g. empty training data).
    epoch_loss = total_loss / batch_number if batch_number else float('nan')
    losses.append(epoch_loss)
    # scheduler.step(epoch_loss)
    # run evaluation on test data
    model.eval()
    with torch.no_grad():
        run_evaluation(model=model, tokenizer=tokenizer)
    model.train()
    tqdm.write(f'epoch loss: {epoch_loss}')
# torch.save(model.state_dict(), './checkpoint/character.pt')
# %%