triplet loss with classification as a regularizer

- best record of 82.57 for esAppMod
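- in short (as implemented in the training script in this diff): loss = F.cross_entropy(logits, labels) + triplet_loss(labels, cls_embeddings, margin), with batch-all triplet mining for the first half of training and batch-hard mining for the second half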
This commit is contained in:
Richard Wong 2025-01-18 23:53:08 +09:00
parent 94ee7beba7
commit ac340f6fd2
16 changed files with 4578 additions and 14 deletions

2 cosines_with_augmentations/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
__pycache__
checkpoint

View File

@ -0,0 +1,256 @@
# %%
# from datasets import load_from_disk
import os
import glob
os.environ['NCCL_P2P_DISABLE'] = '1'
os.environ['NCCL_IB_DISABLE'] = '1'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import re
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
AutoModel,
DataCollatorWithPadding,
)
import evaluate
import numpy as np
import pandas as pd
# import matplotlib.pyplot as plt
from datasets import Dataset, DatasetDict
from tqdm import tqdm
torch.set_float32_matmul_precision('high')
BATCH_SIZE = 256
# %%
# construct the target id list
# data_path = '../../../esAppMod_data_import/train.csv'
data_path = '../esAppMod_data_import/train.csv'
train_df = pd.read_csv(data_path, skipinitialspace=True)
# collect the unique entity ids to define the label space
entity_ids = train_df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))
# %%
id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
id2label[idx] = val
label2id[val] = idx
# introduce pre-processing functions
def preprocess_text(text):
# 1. Make all lowercase
text = text.lower()
# Substitute digits with '#'
# text = re.sub(r'\d+', '#', text)
# standardize spacing
text = re.sub(r'\s+', ' ', text).strip()
return text
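# e.g. preprocess_text('  IBM   DB2 ') -> 'ibm db2'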
# processes the dataframe into a list of dictionaries
# each element maps an input to an output
# input: mention text
# output: class label
def process_df_to_dict(df):
output_list = []
for _, row in df.iterrows():
desc = row['mention']
desc = preprocess_text(desc)
index = row['entity_id']
element = {
'text' : desc,
'label': label2id[index], # ensure labels start from 0
}
output_list.append(element)
return output_list
def create_dataset():
# test
data_path = '../esAppMod_data_import/test.csv'
test_df = pd.read_csv(data_path, skipinitialspace=True)
# combined_data = DatasetDict({
# 'train': Dataset.from_list(process_df_to_dict(train_df)),
# })
return Dataset.from_list(process_df_to_dict(test_df))
# %%
def test():
test_dataset = create_dataset()
# prepare tokenizer
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, return_tensors="pt", clean_up_tokenization_spaces=True)
# Define additional special tokens
# additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
# Add the additional special tokens to the tokenizer
# tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
# %%
# compute max token length
max_length = 0
for sample in test_dataset['text']:
# Tokenize the sample and get the length
input_ids = tokenizer(sample, truncation=False, add_special_tokens=True)["input_ids"]
length = len(input_ids)
# Update max_length if this sample is longer
if length > max_length:
max_length = length
print(max_length)
# %%
max_length = 128
# given a dataset entry, run it through the tokenizer
def preprocess_function(example):
input = example['text']
# only tokenize the text; the existing 'label' column is carried over unchanged,
# so there is no need to create a separate 'labels' field here
model_inputs = tokenizer(
input,
# max_length=max_length,
# padding='max_length'
)
return model_inputs
# map applies the function to each row of the dataset
# (batched here, so the function receives lists of rows)
datasets = test_dataset.map(
preprocess_function,
batched=True,
num_proc=8,
remove_columns="text",
)
datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
bert_model = AutoModel.from_pretrained(MODEL_NAME)
class BertForClassificationAndTriplet(nn.Module):
def __init__(self, bert_model, num_classes):
super().__init__()
self.bert = bert_model
self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)
def forward(self, input_ids, attention_mask=None):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token
logits = self.classifier(cls_embeddings)
return cls_embeddings, logits
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
state_dict = torch.load('./checkpoint/classification.pt')
model.load_state_dict(state_dict)
model = model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
pred_labels = []
actual_labels = []
dataloader = DataLoader(datasets, batch_size=BATCH_SIZE, shuffle=False, collate_fn=data_collator)
for batch in tqdm(dataloader):
# Inference in batches
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
# save labels too
actual_labels.extend(batch['labels'])
# Move to GPU if available
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
# Perform inference
with torch.no_grad():
cls, logits = model(
input_ids,
attention_mask)
predicted_class_ids = logits.argmax(dim=1).to("cpu")
pred_labels.extend(predicted_class_ids)
pred_labels = [tensor.item() for tensor in pred_labels]
# %%
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
y_true = actual_labels
y_pred = pred_labels
# Compute metrics
accuracy = accuracy_score(y_true, y_pred)
average_parameter = 'weighted'
zero_division_parameter = 0
f1 = f1_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter)
precision = precision_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter)
recall = recall_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter)
with open("results/output.txt", "a") as f:
print('*' * 80, file=f)
# Print the results
print(f'Accuracy: {accuracy:.5f}', file=f)
print(f'F1 Score: {f1:.5f}', file=f)
print(f'Precision: {precision:.5f}', file=f)
print(f'Recall: {recall:.5f}', file=f)
# export result
label_list = [id2label[id] for id in pred_labels]
df = pd.DataFrame({
'class_prediction': pd.Series(label_list)
})
# save the predicted class labels here
df.to_csv("results/classify.csv", index=False)
# %%
# reset file before writing to it
with open("results/output.txt", "w") as f:
print('', file=f)
test()

View File

@ -0,0 +1,124 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import re
import gc
# %%
# Step 2: Load the state dictionary
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
# state_dict = torch.load('./checkpoint/siamese.pt')
# state_dict = torch.load('./checkpoint/siamese_simple.pt')
state_dict = torch.load('./checkpoint/classification.pt')
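# classification.pt holds the full BertForClassificationAndTriplet state dict: encoder keys
# carry a 'bert.' prefix and a classifier head is included, so strip the prefix and drop the
# classifier weights before loading into a bare AutoModel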
params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name}
# %%
# Step 3: Apply the state dictionary to the model
model.load_state_dict(params_dict)
model.to(DEVICE)
model.eval()
# %%
def preprocess_text(text):
# 1. Make all lowercase
text = text.lower()
# standardize spacing
text = re.sub(r'\s+', ' ', text).strip()
return text
# %%
with open('../esAppMod/tca_entities.json', 'r') as file:
entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
with open('../esAppMod/train.json', 'r') as file:
train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
# %%
with open('../esAppMod/infer.json', 'r') as file:
test = json.load(file)
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
y_test = [d['entity_id'] for _, d in test['data'].items()]
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
train_entities = [preprocess_text(element) for element in train_entities]
def batch_list(data, batch_size):
"""Yield successive n-sized chunks from data."""
for i in range(0, len(data), batch_size):
yield data[i:i + batch_size]
batches = batch_list(train_entities, 64)
embedding_list = []
for batch in batches:
inputs = tokenizer(batch, padding=True, return_tensors='pt')
outputs = model(
input_ids=inputs['input_ids'].to(DEVICE),
attention_mask=inputs['attention_mask'].to(DEVICE)
)
output = outputs.last_hidden_state[:,0,:]
output = output.detach().cpu().numpy()
embedding_list.append(output)
cls = np.concatenate(embedding_list)
# %%
gc.collect()
torch.cuda.empty_cache()
# %%
batches = batch_list(x_test, 64)
embedding_list = []
for batch in batches:
inputs = tokenizer(batch, padding=True, return_tensors='pt')
outputs = model(
input_ids=inputs['input_ids'].to(DEVICE),
attention_mask=inputs['attention_mask'].to(DEVICE)
)
output = outputs.last_hidden_state[:,0,:]
output = output.detach().cpu().numpy()
embedding_list.append(output)
cls_test = np.concatenate(embedding_list)
# %%
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
n_neighbors = [1, 3, 5, 10]
with open("results/output.txt", "w") as f:
for n in n_neighbors:
distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
num = 0
for a,b in zip(y_test, indices):
b = [labels[i] for i in b]
if a in b:
num += 1
print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
print(np.min(distances), np.max(distances), file=f)
# %%

View File

@ -0,0 +1,276 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import pandas as pd
import re
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
# %%
SHUFFLES=0
AMPLIFY_FACTOR=0
LEARNING_RATE=1e-5
# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
# split entity mentions into groups
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
entity_sets = []
if anchor:
for id, mentions in entity_id_mentions.items():
random.shuffle(mentions)
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
entity_sets.extend(anchor_positive)
else:
for id, mentions in entity_id_mentions.items():
group = list(set([entity_id_name[id]] + mentions))
random.shuffle(group)
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
entity_sets.extend(positives)
return entity_sets
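# illustrative example (hypothetical data), with group_size=2 and anchor=True:
#   entity_id_mentions = {42: ['db2', 'ibm db2', 'db2 udb']}
#   entity_id_name     = {42: 'IBM DB2'}
#   -> [(['IBM DB2', 'db2', 'db2 udb'], 42), (['IBM DB2', 'ibm db2'], 42)]  (grouping depends on the shuffle)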
def batchGenerator(data, batch_size):
for i in range(0, len(data), batch_size):
batch = data[i:i+batch_size]
x, y = [], []
for t in batch:
x.extend(t[0])
y.extend([t[1]]*len(t[0]))
yield x, y
with open('../esAppMod/tca_entities.json', 'r') as file:
entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
with open('../esAppMod/train.json', 'r') as file:
train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
# %%
###############
# alternate data import strategy
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# collect the unique entity ids to define the label space
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))
id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
id2label[idx] = val
label2id[val] = idx
df["training_id"] = df["entity_id"].map(label2id)
# %%
##############################################################
# augmentation code
# basic preprocessing
def preprocess_text(text):
# 1. Make all lowercase
text = text.lower()
# standardize spacing
text = re.sub(r'\s+', ' ', text).strip()
return text
def generate_random_shuffles(text, n):
words = text.split() # Split the input into words
shuffled_variations = []
for _ in range(n):
shuffled = words[:] # Copy the word list to avoid in-place modification
random.shuffle(shuffled) # Randomly shuffle the words
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
return shuffled_variations
def shuffle_text(text, n_shuffles=SHUFFLES):
all_processed = []
# add the original text
all_processed.append(text)
# Generate random shuffles
shuffled_variations = generate_random_shuffles(text, n_shuffles)
all_processed.extend(shuffled_variations)
return all_processed
def corrupt_word(word):
"""Corrupt a single word using random corruption techniques."""
if len(word) <= 1: # Skip corruption for single-character words
return word
corruption_type = random.choice(["delete", "swap"])
if corruption_type == "delete":
# Randomly delete a character
idx = random.randint(0, len(word) - 1)
word = word[:idx] + word[idx + 1:]
elif corruption_type == "swap":
# Swap two adjacent characters
if len(word) > 1:
idx = random.randint(0, len(word) - 2)
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
return word
def corrupt_string(sentence, corruption_probability=0.01):
"""Corrupt each word in the string with a given probability."""
words = sentence.split()
corrupted_words = [
corrupt_word(word) if random.random() < corruption_probability else word
for word in words
]
return " ".join(corrupted_words)
def create_example(index, mention, entity_name):
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
# augment whole dataset
def augment_data(df):
output_list = []
for idx,row in df.iterrows():
index = row['entity_id']
entity_name = row['entity_name']
parent_desc = row['mention']
parent_desc = preprocess_text(parent_desc)
# add basic example
output_list.append(create_example(index, parent_desc, entity_name))
# add shuffled strings
processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
for desc in processed_descs:
if (desc != parent_desc):
output_list.append(create_example(index, desc, entity_name))
# add corrupted strings
desc = corrupt_string(parent_desc, corruption_probability=0.01)
if (desc != parent_desc):
output_list.append(create_example(index, desc, entity_name))
# add example with stripped non-alphanumerics
desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
if (desc != parent_desc):
output_list.append(create_example(index, desc, entity_name))
# short sequence amplifier
# short sequences are rare, and we must compensate by including more examples
# also, short sequences are not much affected by shuffling
words = parent_desc.split()
word_count = len(words)
if word_count <= 2:
for _ in range(AMPLIFY_FACTOR):
output_list.append(create_example(index, desc, entity_name))
new_df = pd.DataFrame(output_list)
return new_df
# %%
def make_entity_id_mentions(df):
entity_id_mentions = {}
entity_id_list = list(set(df['entity_id']))
for entity_id in entity_id_list:
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
return entity_id_mentions
def make_entity_id_name(df):
entity_id_name = {}
entity_id_list = list(set(df['entity_id']))
for entity_id in entity_id_list:
# each entity_id maps to a single entity_name, so the first value suffices
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
return entity_id_name
# %%
num_sample_per_class = 10 # samples in each group
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
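# e.g. with the defaults here: 16 groups x 10 samples = up to 160 mentions per triplet-loss batch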
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
model.to(DEVICE)
model.train()
losses = []
for epoch in tqdm(range(epochs)):
total_loss = 0.0
batch_number = 0
augmented_df = augment_data(df)
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
train_entity_id_name = make_entity_id_name(augmented_df)
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
random.shuffle(data)
for x,y in batchGenerator(data, batch_size):
# print(len(x), len(y), end='-->')
optimizer.zero_grad()
inputs = tokenizer(x, padding=True, return_tensors='pt')
inputs.to(DEVICE)
outputs = model(**inputs)
cls = outputs.last_hidden_state[:,0,:]
# for the first half of training, mine all valid (easier) triplets
y = torch.tensor(y).to(DEVICE)
if epoch < epochs / 2:
loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
# for the second half, mine only the hardest triplets
else:
loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
loss.backward()
optimizer.step()
total_loss += loss.detach().item()
batch_number += 1
del x, y, outputs, cls, loss
torch.cuda.empty_cache()
# scheduler.step() # Update the learning rate
print(f'epoch loss: {total_loss/batch_number}')
# print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
if epoch % 5 == 0:
torch.save(model.state_dict(), './checkpoint/siamese_simple.pt')
torch.save(model.state_dict(), './checkpoint/siamese_simple.pt')
# %%

View File

@ -0,0 +1,305 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import pandas as pd
import re
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
# %%
SHUFFLES=0
AMPLIFY_FACTOR=2
LEARNING_RATE=1e-5
# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
# split entity mentions into groups
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
entity_sets = []
if anchor:
for id, mentions in entity_id_mentions.items():
random.shuffle(mentions)
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
entity_sets.extend(anchor_positive)
else:
for id, mentions in entity_id_mentions.items():
group = list(set([entity_id_name[id]] + mentions))
random.shuffle(group)
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
entity_sets.extend(positives)
return entity_sets
def batchGenerator(data, batch_size):
for i in range(0, len(data), batch_size):
batch = data[i:i+batch_size]
x, y = [], []
for t in batch:
x.extend(t[0])
y.extend([t[1]]*len(t[0]))
yield x, y
with open('../esAppMod/tca_entities.json', 'r') as file:
entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
with open('../esAppMod/train.json', 'r') as file:
train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
# %%
###############
# alternate data import strategy
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# collect the unique entity ids to define the label space
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))
id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
id2label[idx] = val
label2id[val] = idx
df["training_id"] = df["entity_id"].map(label2id)
# %%
##############################################################
# augmentation code
# basic preprocessing
def preprocess_text(text):
# 1. Make all lowercase
text = text.lower()
# standardize spacing
text = re.sub(r'\s+', ' ', text).strip()
return text
def generate_random_shuffles(text, n):
words = text.split() # Split the input into words
shuffled_variations = []
for _ in range(n):
shuffled = words[:] # Copy the word list to avoid in-place modification
random.shuffle(shuffled) # Randomly shuffle the words
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
return shuffled_variations
def shuffle_text(text, n_shuffles=SHUFFLES):
all_processed = []
# add the original text
all_processed.append(text)
# Generate random shuffles
shuffled_variations = generate_random_shuffles(text, n_shuffles)
all_processed.extend(shuffled_variations)
return all_processed
def corrupt_word(word):
"""Corrupt a single word using random corruption techniques."""
if len(word) <= 1: # Skip corruption for single-character words
return word
corruption_type = random.choice(["delete", "swap"])
if corruption_type == "delete":
# Randomly delete a character
idx = random.randint(0, len(word) - 1)
word = word[:idx] + word[idx + 1:]
elif corruption_type == "swap":
# Swap two adjacent characters
if len(word) > 1:
idx = random.randint(0, len(word) - 2)
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
return word
def corrupt_string(sentence, corruption_probability=0.01):
"""Corrupt each word in the string with a given probability."""
words = sentence.split()
corrupted_words = [
corrupt_word(word) if random.random() < corruption_probability else word
for word in words
]
return " ".join(corrupted_words)
def create_example(index, mention, entity_name):
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
# augment whole dataset
def augment_data(df):
output_list = []
for idx,row in df.iterrows():
index = row['entity_id']
entity_name = row['entity_name']
parent_desc = row['mention']
parent_desc = preprocess_text(parent_desc)
# add basic example
output_list.append(create_example(index, parent_desc, entity_name))
# add shuffled strings
processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
for desc in processed_descs:
if (desc != parent_desc):
output_list.append(create_example(index, desc, entity_name))
# add corrupted strings
desc = corrupt_string(parent_desc, corruption_probability=0.01)
if (desc != parent_desc):
output_list.append(create_example(index, desc, entity_name))
# add example with stripped non-alphanumerics
desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
if (desc != parent_desc):
output_list.append(create_example(index, desc, entity_name))
# short sequence amplifier
# short sequences are rare, and we must compensate by including more examples
# also, short sequences are not much affected by shuffling
words = parent_desc.split()
word_count = len(words)
if word_count <= 2:
for _ in range(AMPLIFY_FACTOR):
output_list.append(create_example(index, desc, entity_name))
new_df = pd.DataFrame(output_list)
return new_df
# %%
def make_entity_id_mentions(df):
entity_id_mentions = {}
entity_id_list = list(set(df['entity_id']))
for entity_id in entity_id_list:
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
return entity_id_mentions
def make_entity_id_name(df):
entity_id_name = {}
entity_id_list = list(set(df['entity_id']))
for entity_id in entity_id_list:
# each entity_id maps to a single entity_name, so the first value suffices
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
return entity_id_name
# %%
num_sample_per_class = 10 # samples in each group
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
bert_model = AutoModel.from_pretrained(MODEL_NAME)
class BertForClassificationAndTriplet(nn.Module):
def __init__(self, bert_model, num_classes):
super().__init__()
self.bert = bert_model
self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)
def forward(self, input_ids, attention_mask=None):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token
logits = self.classifier(cls_embeddings)
return cls_embeddings, logits
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
model.to(DEVICE)
model.train()
losses = []
for epoch in tqdm(range(epochs)):
total_loss = 0.0
batch_number = 0
augmented_df = augment_data(df)
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
train_entity_id_name = make_entity_id_name(augmented_df)
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
random.shuffle(data)
for x,y in batchGenerator(data, batch_size):
# print(len(x), len(y), end='-->')
optimizer.zero_grad()
inputs = tokenizer(x, padding=True, return_tensors='pt')
inputs.to(DEVICE)
cls, logits = model(
input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask']
)
# map entity ids to contiguous class labels for the classification head; triplet mining below switches from batch-all (first half of training) to batch-hard (second half)
labels = y
labels = [label2id[element] for element in labels]
labels = torch.tensor(labels).to(DEVICE)
y = torch.tensor(y).to(DEVICE)
class_loss = F.cross_entropy(logits, labels)
if epoch < epochs / 2:
triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
# second half of training: mine only the hardest triplets
else:
triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
loss = class_loss + triplet_loss
loss.backward()
optimizer.step()
total_loss += loss.detach().item()
batch_number += 1
del x, y, cls, logits, loss
torch.cuda.empty_cache()
# scheduler.step() # Update the learning rate
print(f'epoch loss: {total_loss/batch_number}')
# print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
if epoch % 5 == 0:
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
torch.save(model.state_dict(), './checkpoint/classification.pt')
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
torch.save(model.state_dict(), './checkpoint/classification.pt')
# %%

View File

@ -0,0 +1,186 @@
# standard functions for computing triplet loss, code borrowed from
# https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py
import torch
import torch.nn.functional as F
def _pairwise_distances(embeddings, squared=False):
"""Compute the 2D matrix of distances between all the embeddings.
Args:
embeddings: tensor of shape (batch_size, embed_dim)
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
If false, output is the pairwise euclidean distance matrix.
Returns:
pairwise_distances: tensor of shape (batch_size, batch_size)
"""
dot_product = torch.matmul(embeddings, embeddings.t())
# Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
# This also provides more numerical stability (the diagonal of the result will be exactly 0).
# shape (batch_size,)
square_norm = torch.diag(dot_product)
# Compute the pairwise distance matrix as we have:
# ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
# shape (batch_size, batch_size)
distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)
# Because of computation errors, some distances might be negative so we put everything >= 0.0
distances[distances < 0] = 0
if not squared:
# Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
# we need to add a small epsilon where distances == 0.0
mask = distances.eq(0).float()
distances = distances + mask * 1e-16
distances = (1.0 - mask) * torch.sqrt(distances)
return distances
def _get_triplet_mask(labels):
"""Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
A triplet (i, j, k) is valid if:
- i, j, k are distinct
- labels[i] == labels[j] and labels[i] != labels[k]
Args:
labels: int `Tensor` with shape [batch_size]
"""
# Check that i, j and k are distinct
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
indices_not_equal = ~indices_equal
i_not_equal_j = indices_not_equal.unsqueeze(2)
i_not_equal_k = indices_not_equal.unsqueeze(1)
j_not_equal_k = indices_not_equal.unsqueeze(0)
distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k
label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
i_equal_j = label_equal.unsqueeze(2)
i_equal_k = label_equal.unsqueeze(1)
valid_labels = ~i_equal_k & i_equal_j
return valid_labels & distinct_indices
def _get_anchor_positive_triplet_mask(labels):
"""Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label.
Args:
labels: int `Tensor` with shape [batch_size]
Returns:
mask: bool `Tensor` with shape [batch_size, batch_size]
"""
# Check that i and j are distinct
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
indices_not_equal = ~indices_equal
# Check if labels[i] == labels[j]
# Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
return labels_equal & indices_not_equal
def _get_anchor_negative_triplet_mask(labels):
"""Return a 2D mask where mask[a, n] is True iff a and n have distinct labels.
Args:
labels: int `Tensor` with shape [batch_size]
Returns:
mask: bool `Tensor` with shape [batch_size, batch_size]
"""
# Check if labels[i] != labels[k]
# Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
return ~(labels.unsqueeze(0) == labels.unsqueeze(1))
# Cell
def batch_hard_triplet_loss(labels, embeddings, margin, squared=False):
"""Build the triplet loss over a batch of embeddings.
For each anchor, we get the hardest positive and hardest negative to form a triplet.
Args:
labels: labels of the batch, of size (batch_size,)
embeddings: tensor of shape (batch_size, embed_dim)
margin: margin for triplet loss
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
If false, output is the pairwise euclidean distance matrix.
Returns:
triplet_loss: scalar tensor containing the triplet loss
"""
# Get the pairwise distance matrix
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
# For each anchor, get the hardest positive
# First, we need to get a mask for every valid positive (they should have same label)
mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()
# We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
anchor_positive_dist = mask_anchor_positive * pairwise_dist
# shape (batch_size, 1)
hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)
# For each anchor, get the hardest negative
# First, we need to get a mask for every valid negative (they should have different labels)
mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()
# We add the maximum value in each row to the invalid negatives (label(a) == label(n))
max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
# shape (batch_size,)
hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)
# Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
tl = hardest_positive_dist - hardest_negative_dist + margin
tl = F.relu(tl)
triplet_loss = tl.mean()
return triplet_loss
# Cell
def batch_all_triplet_loss(labels, embeddings, margin, squared=False):
"""Build the triplet loss over a batch of embeddings.
We generate all the valid triplets and average the loss over the positive ones.
Args:
labels: labels of the batch, of size (batch_size,)
embeddings: tensor of shape (batch_size, embed_dim)
margin: margin for triplet loss
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
If false, output is the pairwise euclidean distance matrix.
Returns:
triplet_loss: scalar tensor containing the triplet loss
"""
# Get the pairwise distance matrix
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
anchor_positive_dist = pairwise_dist.unsqueeze(2)
anchor_negative_dist = pairwise_dist.unsqueeze(1)
# Compute a 3D tensor of size (batch_size, batch_size, batch_size)
# triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k
# Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
# and the 2nd (batch_size, 1, batch_size)
triplet_loss = anchor_positive_dist - anchor_negative_dist + margin
# Put to zero the invalid triplets
# (where label(a) != label(p) or label(n) == label(a) or a == p)
mask = _get_triplet_mask(labels)
triplet_loss = mask.float() * triplet_loss
# Remove negative losses (i.e. the easy triplets)
triplet_loss = F.relu(triplet_loss)
# Count number of positive triplets (where triplet_loss > 0)
valid_triplets = triplet_loss[triplet_loss > 1e-16]
num_positive_triplets = valid_triplets.size(0)
num_valid_triplets = mask.sum()
fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)
# Get final mean triplet loss over the positive valid triplets
triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
return triplet_loss, fraction_positive_triplets
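# minimal usage sketch (illustrative, not part of the training scripts):
#   labels = torch.tensor([0, 0, 1, 1])
#   embeddings = torch.randn(4, 64)
#   loss_all, frac_positive = batch_all_triplet_loss(labels, embeddings, margin=2.0)
#   loss_hard = batch_hard_triplet_loss(labels, embeddings, margin=2.0)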

File diff suppressed because it is too large

View File

@ -0,0 +1,5 @@
Top-1 accuracy: 0.8257482574825749
Top-3 accuracy: 0.9106191061910619
Top-5 accuracy: 0.9261992619926199
Top-10 accuracy: 0.942189421894219
0.0 0.82121104

View File

@ -0,0 +1,227 @@
# this code performs dataloading with text augmentation
# %%
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch
from transformers import (
AutoTokenizer,
)
from functools import partial
import re
import random
# %%
# PARAMETERS
SAMPLES=5
# %%
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# collect the unique entity ids to define the label space
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))
id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
id2label[idx] = val
label2id[val] = idx
df["training_id"] = df["entity_id"].map(label2id)
# %%
##############################################################
# augmentation code
# basic preprocessing
def preprocess_text(text):
# 1. Make all lowercase
text = text.lower()
# standardize spacing
text = re.sub(r'\s+', ' ', text).strip()
return text
def shuffle_text(text, prob=0.2):
if random.random() < prob:
words = text.split() # Split the input into words
shuffled = words[:] # Copy the word list to avoid in-place modification
random.shuffle(shuffled) # Randomly shuffle the words
shuffled_text = " ".join(shuffled) # Join the words back into a string
else:
shuffled_text = text
return shuffled_text
def corrupt_word(word):
"""Corrupt a single word using random corruption techniques."""
if len(word) <= 1: # Skip corruption for single-character words
return word
corruption_type = random.choice(["delete", "swap"])
if corruption_type == "delete":
# Randomly delete a character
idx = random.randint(0, len(word) - 1)
word = word[:idx] + word[idx + 1:]
elif corruption_type == "swap":
# Swap two adjacent characters
if len(word) > 1:
idx = random.randint(0, len(word) - 2)
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
return word
def corrupt_text(sentence, prob=0.05):
"""Corrupt each word in the string with a given probability."""
words = sentence.split()
corrupted_words = [
corrupt_word(word) if random.random() < prob else word
for word in words
]
return " ".join(corrupted_words)
def strip_nonalphanumerics(desc, prob=0.5):
desc = re.sub(r'[^\w\s]', ' ', desc) # Retains only alphanumerics and spaces; note: prob is currently unused, so stripping is always applied
return desc
# %%
def augment(row):
"""
function to augment "mention" string input
returns the string input with slight variation
"""
desc = row['mention']
# we always apply preprocess
desc = preprocess_text(desc)
desc = shuffle_text(desc, prob=0.5)
desc = corrupt_text(desc, prob=0.5)
desc = strip_nonalphanumerics(desc, prob=0.5)
return desc
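# illustrative: ' IBM   DB2 ' is always normalized to 'ibm db2'; with probability 0.5 each,
# the words may additionally be shuffled, characters corrupted, and punctuation stripped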
# %%
# custom dataset
# we want to sample n samples from each class
# sample_size refers to the number of samples per class
def sample_from_df(df, sample_size_per_class=5):
sampled_df = (df.groupby("training_id")[['training_id', 'mention']] # explicitly give column names
.apply(lambda x: x.sample(n=min(sample_size_per_class, len(x))))
.reset_index(drop=True))
return sampled_df
# %%
class DynamicDataset(Dataset):
def __init__(self, df, sample_size_per_class):
"""
Args:
df (pd.DataFrame): Original DataFrame with class (id) and data columns.
sample_size_per_class (int): Number of samples to draw per class for each epoch.
"""
self.df = df
self.sample_size_per_class = sample_size_per_class
self.current_data = None
self.regenerate_data() # Generate the initial dataset
def regenerate_data(self):
"""
Generate a new sampled dataset for the current epoch.
Calling this method refreshes self.current_data so that we can:
- re-sample the dataframe for a new set of samples per class
- serve fresh augmentations (these are applied lazily in __getitem__)
This lets us re-sample and re-augment at the start of each epoch.
"""
# Sample `sample_size_per_class` rows per class
sampled_df = sample_from_df(self.df, self.sample_size_per_class)
# Store the tokenized data with labels
self.current_data = sampled_df
def __len__(self):
return len(self.current_data)
def __getitem__(self, idx):
# do the transform here
row = self.current_data.iloc[idx].to_dict()
# perform text augmentation here
# independent function calls might introduce changes
mention_0 = augment(row)
mention_1 = augment(row)
return {
'training_id': row['training_id'],
'mention_0': mention_0,
'mention_1': mention_1,
}
# %%
dataset = DynamicDataset(df, sample_size_per_class=SAMPLES)
dataset[0]
# %%
def custom_collate_fn(batch, tokenizer):
# batch is just a list of dictionaries
label_list = [item['training_id'] for item in batch]
mention_0_list = [item['mention_0'] for item in batch]
mention_1_list = [item['mention_1'] for item in batch]
# we can do the tokenization here
tokenized_batch_0 = tokenizer(
mention_0_list,
truncation=True,
padding=True,
return_tensors='pt'
)
tokenized_batch_1 = tokenizer(
mention_1_list,
truncation=True,
padding=True,
return_tensors='pt'
)
label_list = torch.tensor(label_list)
return {
'input_ids_0': tokenized_batch_0['input_ids'],
'attention_mask_0': tokenized_batch_0['attention_mask'],
'input_ids_1': tokenized_batch_1['input_ids'],
'attention_mask_1': tokenized_batch_1['attention_mask'],
'labels': label_list,
}
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", clean_up_tokenization_spaces=False)
custom_collate_fn_with_tokenizer = partial(custom_collate_fn, tokenizer=tokenizer)
dataloader = DataLoader(
dataset,
batch_size=8,
collate_fn=custom_collate_fn_with_tokenizer
)
# %%
next(iter(dataloader))
# %%

View File

@ -0,0 +1,12 @@
Top-1 accuracy: 0.7107843137254902
Top-3 accuracy: 0.7990196078431373
Top-5 accuracy: 0.8137254901960784
Top-10 accuracy: 0.8529411764705882
Top-1 accuracy: 0.6862745098039216
Top-3 accuracy: 0.7696078431372549
Top-5 accuracy: 0.7990196078431373
Top-10 accuracy: 0.8480392156862745
Top-1 accuracy: 0.696078431372549
Top-3 accuracy: 0.7892156862745098
Top-5 accuracy: 0.8088235294117647
Top-10 accuracy: 0.8382352941176471

View File

@ -14,11 +14,10 @@ from transformers import AutoTokenizer, AutoModel
from data import generate_train_entity_sets
from tqdm import tqdm ### need to use ipywidgets==7.7.1 the newest version doesn't work
from tqdm import tqdm
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import logging
def setup(rank, world_size):
@ -73,13 +72,10 @@ def train(rank, epoch, epochs, train_dataloader, model, optimizer, tokenizer, ma
loss.backward()
optimizer.step()
# logging.info(f'{epoch} {len(x)} {loss.item()}')
epoch_loss.append(loss.item())
epoch_len.append(len(x))
# del inputs, cls, loss
# torch.cuda.empty_cache()
logging.info(f'{DEVICE}{epoch_len}')
logging.info(f'{DEVICE}{epoch_loss}')
def check_label(predicted_cui: str, golden_cui: str) -> int:
"""
@ -127,8 +123,7 @@ def eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, m
# print(np.min(distances), np.max(distances))
def save_checkpoint(model, res, epoch, dataName):
logging.info(f'Saving model {epoch} {res} ')
torch.save(model.state_dict(), './checkpoints/'+dataName+'.pt')
torch.save(model.state_dict(), './checkpoint/' + dataName + '.pt')
class Model(nn.Module):
def __init__(self,MODEL_NAME):
@ -146,12 +141,12 @@ def main(rank, world_size, config):
setup(rank, world_size)
dataName = config['DEFAULT']['dataName']
logging.basicConfig(format='%(asctime)s %(message)s', filename=config['train']['ckt_path']+dataName+'.log', filemode='a', level=logging.INFO)
vocab = defaultdict(set)
with open('./data/biomedical/'+dataName+'/'+config['train']['dictionary']) as f:
with open('../biomedical/' + dataName + '/' + config['train']['dictionary']) as f:
for line in f:
vocab[line.strip().split('||')[0]].add(line.strip().split('||')[1].lower())
line_list = line.strip().split('||')
vocab[line_list[0]].add(line_list[1].lower())
cui_to_id, id_to_cui = {}, {}
vocab_entity_id_mentions = {}
@ -167,7 +162,7 @@ def main(rank, world_size, config):
vocab_ids.extend([id]*len(mentions))
test_mentions, test_cuis = [], []
with open('./data/biomedical/'+dataName+'/'+config['train']['test_set']+'/0.concept') as f:
with open('../biomedical/'+dataName+'/'+config['train']['test_set']+'/0.concept') as f:
for line in f:
test_cuis.append(line.strip().split('||')[-1])
test_mentions.append(line.strip().split('||')[-2].lower())
@ -188,8 +183,7 @@ def main(rank, world_size, config):
ddp_model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)
best = 0
if rank == 0:
logging.info(f'epochs:{epochs} group_size:{num_sample_per_class} batch_size:{batch_size} %num:1 device:{torch.cuda.get_device_name()} count:{torch.cuda.device_count()} base:{MODEL_NAME}' )
best_res = []
for epoch in tqdm(range(epochs)):
trainDataLoader.sampler.set_epoch(epoch)
@ -197,11 +191,18 @@ def main(rank, world_size, config):
# if rank == 0 and epoch % 2 == 0:
if rank == 0:
res = eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, ddp_model.module, tokenizer)
logging.info(f'{epoch} {res}')
if res[0] > best:
best = res[0]
best_res = res
save_checkpoint(ddp_model.module, res, epoch, dataName)
with open("biomedical_results/output.txt", "a") as f:
print('new best ----', file=f)
for idx,n in enumerate([1,3,5,10]):
print(f'Top-{n:<3} accuracy: {best_res[idx]}', file=f)
dist.barrier()
cleanup()
if __name__ == '__main__':

View File

@ -0,0 +1,27 @@
[DEFAULT]
dataName = ncbi
# dataName = bc5cdr-disease
# dataName = bc5cdr-chemical
# dataName = bc2gm
[train]
epochs = 2
batch_size = 40
lr = 1e-5
dictionary = test_dictionary.txt
# or processed_dev_oneFile (note: bc2gm has no dev split)
test_set = processed_test_refined
ckt_path = ./checkpoint/
[model]
margin = 2
model_name = dmis-lab/biobert-v1.1
[eval]
batch_size = 200
[data]
group_size = 4
shuffle = True

View File

@ -0,0 +1,84 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
# split entity mentions into groups
entity_sets = []
# anchor means to use the entity_name as an anchor
if anchor:
# entity_id_mentions is a dict mapping each entity id to its list of mentions
for id, mentions in entity_id_mentions.items():
# shuffle the mentions
random.shuffle(mentions)
# chunk the mentions into groups of at most group_size (9 here, so each group plus its anchor holds at most 10 items)
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
# the first element is always the entity_name
# this is why the caller passes num_sample_per_class - 1 as group_size
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
entity_sets.extend(anchor_positive)
else:
# in this case, there is no entity_name in each group
for id, mentions in entity_id_mentions.items():
group = list(set([entity_id_name[id]] + mentions))
random.shuffle(group)
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
entity_sets.extend(positives)
return entity_sets
# the batch generator takes batch_size entries from "data";
# each entry is a group of at most 10 mentions together with its label
def batchGenerator(data, batch_size):
for i in range(0, len(data), batch_size):
batch = data[i:i+batch_size]
x, y = [], []
for t in batch:
x.extend(t[0]) # this is the list of mentions
y.extend([t[1]]*len(t[0])) # this multiplies a single label by the number of mentions
yield x, y
# %%
with open('../esAppMod/tca_entities.json', 'r') as file:
entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
with open('../esAppMod/train.json', 'r') as file:
train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
# %%
num_sample_per_class = 10 # samples in each group
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model.to(DEVICE)
model.train()
losses = []
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
# %%
random.shuffle(data)

227 vicreg/dataload.py Normal file
View File

@ -0,0 +1,227 @@
# this code performs dataloading with text augmentation
# %%
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch
from transformers import (
AutoTokenizer,
)
from functools import partial
import re
import random
# %%
# PARAMETERS
SAMPLES=5
# %%
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# collect the unique entity ids to define the label space
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))
id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
id2label[idx] = val
label2id[val] = idx
df["training_id"] = df["entity_id"].map(label2id)
# %%
##############################################################
# augmentation code
# basic preprocessing
def preprocess_text(text):
# 1. Make all lowercase
text = text.lower()
# standardize spacing
text = re.sub(r'\s+', ' ', text).strip()
return text
def shuffle_text(text, prob=0.2):
if random.random() < prob:
words = text.split() # Split the input into words
shuffled = words[:] # Copy the word list to avoid in-place modification
random.shuffle(shuffled) # Randomly shuffle the words
shuffled_text = " ".join(shuffled) # Join the words back into a string
else:
shuffled_text = text
return shuffled_text
def corrupt_word(word):
"""Corrupt a single word using random corruption techniques."""
if len(word) <= 1: # Skip corruption for single-character words
return word
corruption_type = random.choice(["delete", "swap"])
if corruption_type == "delete":
# Randomly delete a character
idx = random.randint(0, len(word) - 1)
word = word[:idx] + word[idx + 1:]
elif corruption_type == "swap":
# Swap two adjacent characters
if len(word) > 1:
idx = random.randint(0, len(word) - 2)
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
return word
def corrupt_text(sentence, prob=0.05):
"""Corrupt each word in the string with a given probability."""
words = sentence.split()
corrupted_words = [
corrupt_word(word) if random.random() < prob else word
for word in words
]
return " ".join(corrupted_words)
def strip_nonalphanumerics(desc, prob=0.5):
desc = re.sub(r'[^\w\s]', ' ', desc) # Retains only alphanumerics and spaces; note: prob is currently unused, so stripping is always applied
return desc
# %%
def augment(row):
"""
function to augment "mention" string input
returns the string input with slight variation
"""
desc = row['mention']
# we always apply preprocess
desc = preprocess_text(desc)
desc = shuffle_text(desc, prob=1.0)
desc = corrupt_text(desc, prob=1.0)
desc = strip_nonalphanumerics(desc, prob=0.5)
return desc
# %%
# custom dataset
# we want to sample n samples from each class
# sample_size refers to the number of samples per class
def sample_from_df(df, sample_size_per_class=5):
sampled_df = (df.groupby("training_id")[['training_id', 'mention']] # explicitly give column names
.apply(lambda x: x.sample(n=min(sample_size_per_class, len(x))))
.reset_index(drop=True))
return sampled_df
# %%
class DynamicDataset(Dataset):
def __init__(self, df, sample_size_per_class):
"""
Args:
df (pd.DataFrame): Original DataFrame with class (id) and data columns.
sample_size_per_class (int): Number of samples to draw per class for each epoch.
"""
self.df = df
self.sample_size_per_class = sample_size_per_class
self.current_data = None
self.regenerate_data() # Generate the initial dataset
def regenerate_data(self):
"""
Generate a new sampled dataset for the current epoch.
Calling this method refreshes self.current_data so that we can:
- re-sample the dataframe for a new set of samples per class
- serve fresh augmentations (these are applied lazily in __getitem__)
This lets us re-sample and re-augment at the start of each epoch.
"""
# Sample `sample_size_per_class` rows per class
sampled_df = sample_from_df(self.df, self.sample_size_per_class)
# Store the tokenized data with labels
self.current_data = sampled_df
def __len__(self):
return len(self.current_data)
def __getitem__(self, idx):
# do the transform here
row = self.current_data.iloc[idx].to_dict()
# perform text augmentation here
# independent function calls might introduce changes
mention_0 = augment(row)
mention_1 = augment(row)
return {
'training_id': row['training_id'],
'mention_0': mention_0,
'mention_1': mention_1,
}
# %%
dataset = DynamicDataset(df, sample_size_per_class=SAMPLES)
dataset[0]
# %%
def custom_collate_fn(batch, tokenizer):
# batch is just a list of dictionaries
label_list = [item['training_id'] for item in batch]
mention_0_list = [item['mention_0'] for item in batch]
mention_1_list = [item['mention_1'] for item in batch]
# we can do the tokenization here
tokenized_batch_0 = tokenizer(
mention_0_list,
truncation=True,
padding=True,
return_tensors='pt'
)
tokenized_batch_1 = tokenizer(
mention_1_list,
truncation=True,
padding=True,
return_tensors='pt'
)
label_list = torch.tensor(label_list)
return {
'input_ids_0': tokenized_batch_0['input_ids'],
'attention_mask_0': tokenized_batch_0['attention_mask'],
'input_ids_1': tokenized_batch_1['input_ids'],
'attention_mask_1': tokenized_batch_1['attention_mask'],
'labels': label_list,
}
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", clean_up_tokenization_spaces=False)
custom_collate_fn_with_tokenizer = partial(custom_collate_fn, tokenizer=tokenizer)
dataloader = DataLoader(
dataset,
batch_size=8,
collate_fn=custom_collate_fn_with_tokenizer
)
# %%
next(iter(dataloader))
# %%

View File

@ -0,0 +1,5 @@
Top-1 accuracy: 0.01845018450184502
Top-3 accuracy: 0.02870028700287003
Top-5 accuracy: 0.03936039360393604
Top-10 accuracy: 0.08200082000820008
0.0 0.7957323

387 vicreg/train.py Normal file
View File

@ -0,0 +1,387 @@
# %%
from torch.utils.data import Dataset, DataLoader
# from datasets import load_from_disk
import os
import json
os.environ['NCCL_P2P_DISABLE'] = '1'
os.environ['NCCL_IB_DISABLE'] = '1'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
from dataclasses import dataclass
import re
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
AutoTokenizer,
AutoModel,
DataCollatorWithPadding,
Trainer,
EarlyStoppingCallback,
TrainingArguments,
TrainerCallback
)
import evaluate
import numpy as np
import pandas as pd
from functools import partial
import warnings
from tqdm import tqdm
from dataload import DynamicDataset, custom_collate_fn
from sklearn.neighbors import KNeighborsClassifier
torch.set_float32_matmul_precision('high')
def set_seed(seed):
"""
Set the random seed for reproducibility.
"""
random.seed(seed) # Python random module
np.random.seed(seed) # NumPy random
torch.manual_seed(seed) # PyTorch CPU
torch.cuda.manual_seed(seed) # PyTorch GPU
torch.cuda.manual_seed_all(seed) # If using multiple GPUs
torch.backends.cudnn.deterministic = True # Ensure deterministic behavior
torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility
set_seed(42)
SAMPLES=40
BATCH_SIZE=256
# %%
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# collect the unique entity ids to define the label space
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))
id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
id2label[idx] = val
label2id[val] = idx
df["training_id"] = df["entity_id"].map(label2id)
# %%
# make our dataset and dataloader
# MODEL_NAME = "distilbert/distilbert-base-uncased"
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, clean_up_tokenization_spaces=False)
dataset = DynamicDataset(df, sample_size_per_class=SAMPLES)
custom_collate_fn_with_tokenizer = partial(custom_collate_fn, tokenizer=tokenizer)
dataloader = DataLoader(
dataset,
batch_size=BATCH_SIZE,
shuffle=True,
collate_fn=custom_collate_fn_with_tokenizer
)
# %%
# enable BERT with projection layer
class VICRegProjection(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(VICRegProjection, self).__init__()
self.projection = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim)
)
def forward(self, x):
return self.projection(x)
class BertWithVICReg(nn.Module):
def __init__(self, bert_model, projection_dim=256):
super(BertWithVICReg, self).__init__()
self.bert = bert_model
hidden_size = bert_model.config.hidden_size
self.projection = VICRegProjection(input_dim=hidden_size, hidden_dim=hidden_size, output_dim=projection_dim)
def forward(self, input_ids, attention_mask=None):
outputs = self.bert(input_ids, attention_mask=attention_mask)
pooled_output = outputs.last_hidden_state[:,0,:]
projected_embeddings = self.projection(pooled_output)
return projected_embeddings
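# note: the module above pools the [CLS] token embedding and passes it through
# a two-layer MLP; with 'prajjwal1/bert-small' (hidden_size 512, if memory
# serves) and projection_dim=128, a (B, T) batch of token ids maps to (B, 512)
# pooled embeddings and then to (B, 128) projected embeddings consumed by the
# VICReg loss below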
#######################################################
# %%
@dataclass
class Hyperparameters:
loss_constant_factor: float = 1
invariance_loss_weight: float = 25.0
variance_loss_weight: float = 25.0
covariance_loss_weight: float = 1.0
variance_loss_epsilon: float = 1e-5
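# note: these defaults (25 / 25 / 1 for invariance / variance / covariance)
# follow the weighting reported in the VICReg paper (Bardes et al., 2022);
# treat them as a starting point rather than values tuned for this dataset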
# compute vicreg loss
def get_vicreg_loss(z_a, z_b, hparams):
assert z_a.shape == z_b.shape and len(z_a.shape) == 2
# invariance loss
loss_inv = F.mse_loss(z_a, z_b)
# variance loss
std_z_a = torch.sqrt(z_a.var(dim=0) + hparams.variance_loss_epsilon)
std_z_b = torch.sqrt(z_b.var(dim=0) + hparams.variance_loss_epsilon)
loss_v_a = torch.mean(F.relu(1 - std_z_a)) # differentiable max
loss_v_b = torch.mean(F.relu(1 - std_z_b))
loss_var = loss_v_a + loss_v_b
# covariance loss
N, D = z_a.shape
z_a = z_a - z_a.mean(dim=0)
z_b = z_b - z_b.mean(dim=0)
cov_z_a = ((z_a.T @ z_a) / (N - 1)).square() # DxD
cov_z_b = ((z_b.T @ z_b) / (N - 1)).square() # DxD
loss_c_a = (cov_z_a.sum() - cov_z_a.diagonal().sum()) / D
loss_c_b = (cov_z_b.sum() - cov_z_b.diagonal().sum()) / D
loss_cov = loss_c_a + loss_c_b
weighted_inv = loss_inv * hparams.invariance_loss_weight
weighted_var = loss_var * hparams.variance_loss_weight
weighted_cov = loss_cov * hparams.covariance_loss_weight
loss = weighted_inv + weighted_var + weighted_cov
return loss
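# %%
# shape sanity check for the loss above (a minimal sketch on random tensors;
# the result should be a positive scalar tensor)
# _z_a = torch.randn(256, 128)
# _z_b = _z_a + 0.01 * torch.randn(256, 128)
# print(get_vicreg_loss(_z_a, _z_b, hparams=Hyperparameters()))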
# %%
#
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
bert_model = AutoModel.from_pretrained(MODEL_NAME)
# bert_hidden_size = bert_model.config.hidden_size
# projection_model = VICRegProjection(
# input_dim=bert_hidden_size,
# hidden_dim=bert_hidden_size,
# output_dim=256
# )
# need to allocate individual component of the model
# bert_model.to(DEVICE)
# projection_model.to(DEVICE)
model = BertWithVICReg(bert_model, projection_dim=128)
model.to(DEVICE)
# params = list(bert_model.parameters()) + list(projection_model.parameters())
params = model.parameters()
optimizer = torch.optim.AdamW(params, lr=5e-6)
hparams = Hyperparameters()
losses = []
# # %%
# batch = next(iter(dataloader))
# input_ids_0 = batch['input_ids_0'].to(DEVICE)
# attention_mask_0 = batch['attention_mask_0'].to(DEVICE)
#
# # %%
# # outputs from reprojection layer
# bert_output = model(
# input_ids=input_ids_0,
# attention_mask=attention_mask_0
# )
# %%
# parameters
epochs = 80
for epoch in tqdm(range(epochs)):
dataset.regenerate_data()
for batch in dataloader:
optimizer.zero_grad()
# compute cls 0
input_ids_0 = batch['input_ids_0'].to(DEVICE)
attention_mask_0 = batch['attention_mask_0'].to(DEVICE)
# outputs from reprojection layer
outputs_0 = model(
input_ids=input_ids_0,
attention_mask=attention_mask_0
)
# compute cls 1
input_ids_1 = batch['input_ids_1'].to(DEVICE)
attention_mask_1 = batch['attention_mask_1'].to(DEVICE)
# outputs from reprojection layer
outputs_1 = model(
input_ids=input_ids_1,
attention_mask=attention_mask_1
)
loss = get_vicreg_loss(outputs_0, outputs_1, hparams=hparams)
loss.backward()
optimizer.step()
# print(epoch, loss)
losses.append(loss.detach().item())
torch.cuda.empty_cache()
print(loss.detach().item())
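# %%
# optional: inspect the loss curve collected above (a sketch; assumes
# matplotlib is available and that `losses` holds plain floats)
# import matplotlib.pyplot as plt
# plt.plot(losses)
# plt.xlabel('step')
# plt.ylabel('VICReg loss')
# plt.show()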
# %%
os.makedirs('./checkpoint', exist_ok=True)
torch.save(model.state_dict(), './checkpoint/simple.pt')
####################################################
# %%
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
state_dict = torch.load('./checkpoint/simple.pt')
model = BertWithVICReg(bert_model, projection_dim=128) # must match the projection_dim used at training time
model.load_state_dict(state_dict)
# %%
# Step 2: reload the tokenizer for inference
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# state_dict = torch.load('./checkpoint/siamese.pt')
# model = model.bert
# %%
# Step 3: move the model to the device and switch to eval mode
model.to(DEVICE)
model.eval()
# %%
with open('../esAppMod/tca_entities.json', 'r') as file:
entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
with open('../esAppMod/train.json', 'r') as file:
train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
# %%
def preprocess_text(text):
# 1. Make all lowercase
text = text.lower()
# standardize spacing
text = re.sub(r'\s+', ' ', text).strip()
return text
with open('../esAppMod/infer.json', 'r') as file:
test = json.load(file)
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
y_test = [d['entity_id'] for _, d in test['data'].items()]
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
train_entities = [preprocess_text(element) for element in train_entities]
def batch_list(data, batch_size):
"""Yield successive n-sized chunks from data."""
for i in range(0, len(data), batch_size):
yield data[i:i + batch_size]
batches = batch_list(train_entities, 64)
embedding_list = []
for batch in batches:
inputs = tokenizer(batch, padding=True, return_tensors='pt')
outputs = model(
input_ids=inputs['input_ids'].to(DEVICE),
attention_mask=inputs['attention_mask'].to(DEVICE)
)
# output = outputs.last_hidden_state[:,0,:]
outputs = outputs.detach().cpu().numpy()
embedding_list.append(outputs)
cls = np.concatenate(embedding_list)
# %%
torch.cuda.empty_cache()
# %%
batches = batch_list(x_test, 64)
embedding_list = []
for batch in batches:
inputs = tokenizer(batch, padding=True, return_tensors='pt')
outputs = model(
input_ids=inputs['input_ids'].to(DEVICE),
attention_mask=inputs['attention_mask'].to(DEVICE)
)
# output = outputs.last_hidden_state[:,0,:]
outputs = outputs.detach().cpu().numpy()
embedding_list.append(outputs)
cls_test = np.concatenate(embedding_list)
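# %%
# the two embedding loops above share the same logic; a helper like the sketch
# below (reusing tokenizer, model, DEVICE and batch_list from above, and adding
# torch.no_grad() so no autograd graph is built at inference time) could
# replace both
# def embed(texts, batch_size=64):
#     chunks = []
#     with torch.no_grad():
#         for chunk in batch_list(texts, batch_size):
#             inputs = tokenizer(chunk, padding=True, return_tensors='pt')
#             out = model(
#                 input_ids=inputs['input_ids'].to(DEVICE),
#                 attention_mask=inputs['attention_mask'].to(DEVICE),
#             )
#             chunks.append(out.cpu().numpy())
#     return np.concatenate(chunks)
# cls = embed(train_entities)
# cls_test = embed(x_test)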
# %%
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
n_neighbors = [1, 3, 5, 10]
for n in n_neighbors:
distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
num = 0
for a,b in zip(y_test, indices):
b = [labels[i] for i in b]
if a in b:
num += 1
print(f'Top-{n:<3} accuracy: {num / len(y_test)}')
print(np.min(distances), np.max(distances))
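# note: the last print reports the cosine-distance range of the retrieved
# neighbours (0 means identical direction, 2 means opposite), which is the
# "0.0 0.7957323" style line seen in the results file above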
# with open("results/output.txt", "w") as f:
# for n in n_neighbors:
# distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
# num = 0
# for a,b in zip(y_test, indices):
# b = [labels[i] for i in b]
# if a in b:
# num += 1
# print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
# print(np.min(distances), np.max(distances), file=f)
# %%
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
# %%
# Reduce dimensions with t-SNE
tsne = TSNE(n_components=2, random_state=42)
embeddings = cls
embeddings_reduced = tsne.fit_transform(embeddings)
plt.figure(figsize=(10, 8))
scatter = plt.scatter(embeddings_reduced[:, 0], embeddings_reduced[:, 1], c=labels, cmap='viridis', alpha=0.6)
plt.colorbar(scatter)
plt.xlabel('Component 1')
plt.ylabel('Component 2')
plt.title('Visualization of Embeddings')
plt.show()
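# %%
# the same projection can be applied to the test-mention embeddings (a sketch,
# reusing cls_test and y_test from above; colours are arbitrary since entity
# ids are categorical)
# test_reduced = TSNE(n_components=2, random_state=42).fit_transform(cls_test)
# plt.figure(figsize=(10, 8))
# plt.scatter(test_reduced[:, 0], test_reduced[:, 1], c=y_test, cmap='viridis', alpha=0.6)
# plt.title('t-SNE of test mention embeddings')
# plt.show()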
# %%