triplet loss with classification as a regularizer
- best result so far: 82.57% top-1 accuracy on esAppMod
parent 94ee7beba7
commit ac340f6fd2
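The change, in brief: the BERT encoder is trained with an online triplet loss on its [CLS] embeddings, while a linear classification head on the same embeddings adds a cross-entropy term, so classification acts as a regularizer for the metric-learning objective. Below is a minimal, self-contained sketch of the combined objective; the class count, batch size, and random tensors are illustrative stand-ins only, and the real training loop appears in the classification training script further down in this commit (which also adds the triplet-loss helpers imported here).

# minimal sketch of the combined objective (illustrative values, not part of the commit)
import torch
import torch.nn as nn
import torch.nn.functional as F
from loss import batch_all_triplet_loss  # triplet-loss helpers added in this commit

num_classes, hidden = 8, 512                                   # assumed toy sizes
classifier = nn.Linear(hidden, num_classes)

cls_embeddings = torch.randn(32, hidden, requires_grad=True)   # stand-in for BERT CLS vectors
labels = torch.randint(0, num_classes, (32,))                  # contiguous training ids

logits = classifier(cls_embeddings)
class_loss = F.cross_entropy(logits, labels)                   # classification term (the regularizer)
triplet_loss, _ = batch_all_triplet_loss(labels, cls_embeddings, margin=2, squared=False)

loss = class_loss + triplet_loss                               # unweighted sum, as in the training script
loss.backward()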
@@ -0,0 +1,2 @@
__pycache__
checkpoint
@@ -0,0 +1,256 @@
# %%

# from datasets import load_from_disk
import os
import glob

os.environ['NCCL_P2P_DISABLE'] = '1'
os.environ['NCCL_IB_DISABLE'] = '1'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

import re
import torch
from torch.utils.data import DataLoader
import torch.nn as nn

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModel,
    DataCollatorWithPadding,
)
import evaluate
import numpy as np
import pandas as pd
# import matplotlib.pyplot as plt
from datasets import Dataset, DatasetDict

from tqdm import tqdm

torch.set_float32_matmul_precision('high')


BATCH_SIZE = 256

# %%
# construct the target id list
# data_path = '../../../esAppMod_data_import/train.csv'
data_path = '../esAppMod_data_import/train.csv'
train_df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use pattern, we use the real thing and property
entity_ids = train_df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))


# %%
id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
    id2label[idx] = val
    label2id[val] = idx


# introduce pre-processing functions
def preprocess_text(text):
    # 1. make lowercase
    text = text.lower()

    # Substitute digits with '#'
    # text = re.sub(r'\d+', '#', text)

    # standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()

    return text


# processes the dataframe into a list of dictionaries
# each element maps an input to an output
# input: mention
# output: class label
def process_df_to_dict(df):
    output_list = []
    for _, row in df.iterrows():
        desc = row['mention']
        desc = preprocess_text(desc)
        index = row['entity_id']
        element = {
            'text': desc,
            'label': label2id[index],  # ensure labels start from 0
        }
        output_list.append(element)

    return output_list


def create_dataset():
    # test split
    data_path = '../esAppMod_data_import/test.csv'
    test_df = pd.read_csv(data_path, skipinitialspace=True)

    # combined_data = DatasetDict({
    #     'train': Dataset.from_list(process_df_to_dict(train_df)),
    # })
    return Dataset.from_list(process_df_to_dict(test_df))


# %%
def test():

    test_dataset = create_dataset()

    # prepare tokenizer
    MODEL_NAME = 'prajjwal1/bert-small'  # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, return_tensors="pt", clean_up_tokenization_spaces=True)
    # Define additional special tokens
    # additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
    # Add the additional special tokens to the tokenizer
    # tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

    # %%
    # compute max token length
    max_length = 0
    for sample in test_dataset['text']:
        # Tokenize the sample and get the length
        input_ids = tokenizer(sample, truncation=False, add_special_tokens=True)["input_ids"]
        length = len(input_ids)

        # Update max_length if this sample is longer
        if length > max_length:
            max_length = length

    print(max_length)

    # %%
    max_length = 128

    # given a dataset entry, run it through the tokenizer
    def preprocess_function(example):
        input = example['text']
        # text_target sets the corresponding label to inputs
        # there is no need to create a separate 'labels'
        model_inputs = tokenizer(
            input,
            # max_length=max_length,
            # padding='max_length'
        )
        return model_inputs

    # map applies the function to each "row" in the dataset,
    # i.e. the data in the immediate nesting
    datasets = test_dataset.map(
        preprocess_function,
        batched=True,
        num_proc=8,
        remove_columns="text",
    )

    datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    bert_model = AutoModel.from_pretrained(MODEL_NAME)

    class BertForClassificationAndTriplet(nn.Module):
        def __init__(self, bert_model, num_classes):
            super().__init__()
            self.bert = bert_model
            self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)

        def forward(self, input_ids, attention_mask=None):
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            cls_embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token
            logits = self.classifier(cls_embeddings)
            return cls_embeddings, logits

    model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
    state_dict = torch.load('./checkpoint/classification.pt')
    model.load_state_dict(state_dict)

    model = model.eval()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    pred_labels = []
    actual_labels = []

    dataloader = DataLoader(datasets, batch_size=BATCH_SIZE, shuffle=False, collate_fn=data_collator)
    for batch in tqdm(dataloader):
        # Inference in batches
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        # save labels too
        actual_labels.extend(batch['labels'])

        # Move to GPU if available
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Perform inference
        with torch.no_grad():
            cls, logits = model(
                input_ids,
                attention_mask)
            predicted_class_ids = logits.argmax(dim=1).to("cpu")
            pred_labels.extend(predicted_class_ids)

    pred_labels = [tensor.item() for tensor in pred_labels]


    # %%
    from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
    y_true = actual_labels
    y_pred = pred_labels

    # Compute metrics
    accuracy = accuracy_score(y_true, y_pred)
    average_parameter = 'weighted'
    zero_division_parameter = 0
    f1 = f1_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter)
    precision = precision_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter)
    recall = recall_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter)

    with open("results/output.txt", "a") as f:
        print('*' * 80, file=f)
        # Print the results
        print(f'Accuracy: {accuracy:.5f}', file=f)
        print(f'F1 Score: {f1:.5f}', file=f)
        print(f'Precision: {precision:.5f}', file=f)
        print(f'Recall: {recall:.5f}', file=f)

    # export results
    label_list = [id2label[id] for id in pred_labels]
    df = pd.DataFrame({
        'class_prediction': pd.Series(label_list)
    })

    # save the class predictions here
    df.to_csv("results/classify.csv", index=False)


# %%
# reset the output file before writing to it
with open("results/output.txt", "w") as f:
    print('', file=f)

test()
@@ -0,0 +1,124 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import re
import gc

# %%
# Step 2: Load the state dictionary
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased'
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small'  # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

# state_dict = torch.load('./checkpoint/siamese.pt')
# state_dict = torch.load('./checkpoint/siamese_simple.pt')
state_dict = torch.load('./checkpoint/classification.pt')
params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name}

# %%
# Step 3: Apply the state dictionary to the model
model.load_state_dict(params_dict)
model.to(DEVICE)
model.eval()

# %%
def preprocess_text(text):
    # 1. make lowercase
    text = text.lower()

    # standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()

    return text


# %%
with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}


# %%
with open('../esAppMod/infer.json', 'r') as file:
    test = json.load(file)
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
y_test = [d['entity_id'] for _, d in test['data'].items()]
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
train_entities = [preprocess_text(element) for element in train_entities]

def batch_list(data, batch_size):
    """Yield successive n-sized chunks from data."""
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]

batches = batch_list(train_entities, 64)

embedding_list = []
for batch in batches:
    inputs = tokenizer(batch, padding=True, return_tensors='pt')
    outputs = model(
        input_ids=inputs['input_ids'].to(DEVICE),
        attention_mask=inputs['attention_mask'].to(DEVICE)
    )
    output = outputs.last_hidden_state[:, 0, :]
    output = output.detach().cpu().numpy()
    embedding_list.append(output)

cls = np.concatenate(embedding_list)
# %%
gc.collect()
torch.cuda.empty_cache()

# %%
batches = batch_list(x_test, 64)

embedding_list = []
for batch in batches:
    inputs = tokenizer(batch, padding=True, return_tensors='pt')
    outputs = model(
        input_ids=inputs['input_ids'].to(DEVICE),
        attention_mask=inputs['attention_mask'].to(DEVICE)
    )
    output = outputs.last_hidden_state[:, 0, :]
    output = output.detach().cpu().numpy()
    embedding_list.append(output)

cls_test = np.concatenate(embedding_list)


# %%
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
n_neighbors = [1, 3, 5, 10]

with open("results/output.txt", "w") as f:
    for n in n_neighbors:
        distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
        num = 0
        for a, b in zip(y_test, indices):
            b = [labels[i] for i in b]
            if a in b:
                num += 1
        print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
    print(np.min(distances), np.max(distances), file=f)

# %%
@@ -0,0 +1,276 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import pandas as pd
import re
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim


# %%
SHUFFLES = 0
AMPLIFY_FACTOR = 0
LEARNING_RATE = 1e-5


# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    # split entity mentions into groups
    # anchor=False: don't add the entity name to each group, simply treat it as a normal mention
    entity_sets = []
    if anchor:
        for id, mentions in entity_id_mentions.items():
            random.shuffle(mentions)
            positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
            anchor_positive = [([entity_id_name[id]] + p, id) for p in positives]
            entity_sets.extend(anchor_positive)
    else:
        for id, mentions in entity_id_mentions.items():
            group = list(set([entity_id_name[id]] + mentions))
            random.shuffle(group)
            positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
            entity_sets.extend(positives)
    return entity_sets

def batchGenerator(data, batch_size):
    for i in range(0, len(data), batch_size):
        batch = data[i:i + batch_size]
        x, y = [], []
        for t in batch:
            x.extend(t[0])
            y.extend([t[1]] * len(t[0]))
        yield x, y


with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}

# %%
###############
# alternate data import strategy
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use pattern, we use the real thing and property
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))

id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
    id2label[idx] = val
    label2id[val] = idx

df["training_id"] = df["entity_id"].map(label2id)

# %%
##############################################################
# augmentation code

# basic preprocessing
def preprocess_text(text):
    # 1. make lowercase
    text = text.lower()

    # standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()

    return text


def generate_random_shuffles(text, n):
    words = text.split()  # Split the input into words
    shuffled_variations = []

    for _ in range(n):
        shuffled = words[:]  # Copy the word list to avoid in-place modification
        random.shuffle(shuffled)  # Randomly shuffle the words
        shuffled_variations.append(" ".join(shuffled))  # Join the words back into a string

    return shuffled_variations


def shuffle_text(text, n_shuffles=SHUFFLES):
    all_processed = []
    # add the original text
    all_processed.append(text)

    # generate random shuffles
    shuffled_variations = generate_random_shuffles(text, n_shuffles)
    all_processed.extend(shuffled_variations)

    return all_processed

def corrupt_word(word):
    """Corrupt a single word using random corruption techniques."""
    if len(word) <= 1:  # Skip corruption for single-character words
        return word

    corruption_type = random.choice(["delete", "swap"])

    if corruption_type == "delete":
        # Randomly delete a character
        idx = random.randint(0, len(word) - 1)
        word = word[:idx] + word[idx + 1:]

    elif corruption_type == "swap":
        # Swap two adjacent characters
        if len(word) > 1:
            idx = random.randint(0, len(word) - 2)
            word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])

    return word

def corrupt_string(sentence, corruption_probability=0.01):
    """Corrupt each word in the string with a given probability."""
    words = sentence.split()
    corrupted_words = [
        corrupt_word(word) if random.random() < corruption_probability else word
        for word in words
    ]
    return " ".join(corrupted_words)


def create_example(index, mention, entity_name):
    return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}

# augment the whole dataset
def augment_data(df):
    output_list = []

    for idx, row in df.iterrows():
        index = row['entity_id']
        entity_name = row['entity_name']
        parent_desc = row['mention']
        parent_desc = preprocess_text(parent_desc)

        # add basic example
        output_list.append(create_example(index, parent_desc, entity_name))

        # add shuffled strings
        processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
        for desc in processed_descs:
            if desc != parent_desc:
                output_list.append(create_example(index, desc, entity_name))

        # add corrupted strings
        desc = corrupt_string(parent_desc, corruption_probability=0.01)
        if desc != parent_desc:
            output_list.append(create_example(index, desc, entity_name))

        # add example with stripped non-alphanumerics
        desc = re.sub(r'[^\w\s]', ' ', parent_desc)  # retains only alphanumerics and spaces
        if desc != parent_desc:
            output_list.append(create_example(index, desc, entity_name))

        # short sequence amplifier
        # short sequences are rare, and we must compensate by including more examples
        # also, short sequences are not usually affected by shuffling
        words = parent_desc.split()
        word_count = len(words)
        if word_count <= 2:
            for _ in range(AMPLIFY_FACTOR):
                output_list.append(create_example(index, desc, entity_name))

    new_df = pd.DataFrame(output_list)
    return new_df


# %%
def make_entity_id_mentions(df):
    entity_id_mentions = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        entity_id_mentions[entity_id] = df[df['entity_id'] == entity_id]['mention'].to_list()
    return entity_id_mentions

def make_entity_id_name(df):
    entity_id_name = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        # entity_id always matches entity_name, so the first value works
        entity_id_name[entity_id] = df[df['entity_id'] == entity_id]['entity_name'].to_list()[0]
    return entity_id_name


# %%
num_sample_per_class = 10  # samples in each group
batch_size = 16  # number of groups; effective batch size for computing triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small'  # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

model.to(DEVICE)
model.train()

losses = []

for epoch in tqdm(range(epochs)):
    total_loss = 0.0
    batch_number = 0

    augmented_df = augment_data(df)
    train_entity_id_mentions = make_entity_id_mentions(augmented_df)
    train_entity_id_name = make_entity_id_name(augmented_df)

    data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class - 1, anchor=True)
    random.shuffle(data)
    for x, y in batchGenerator(data, batch_size):

        # print(len(x), len(y), end='-->')
        optimizer.zero_grad()

        inputs = tokenizer(x, padding=True, return_tensors='pt')
        inputs.to(DEVICE)
        outputs = model(**inputs)
        cls = outputs.last_hidden_state[:, 0, :]
        y = torch.tensor(y).to(DEVICE)
        # for the first half of training, use the easier batch-all triplets
        if epoch < epochs / 2:
            loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
        # for the second half of training, switch to batch-hard triplets
        else:
            loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().item()
        batch_number += 1

        del x, y, outputs, cls, loss
        torch.cuda.empty_cache()

    # scheduler.step()  # Update the learning rate
    print(f'epoch loss: {total_loss/batch_number}')
    # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
    if epoch % 5 == 0:
        torch.save(model.state_dict(), './checkpoint/siamese_simple.pt')


torch.save(model.state_dict(), './checkpoint/siamese_simple.pt')
# %%
@@ -0,0 +1,305 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import pandas as pd
import re
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

# %%
SHUFFLES = 0
AMPLIFY_FACTOR = 2
LEARNING_RATE = 1e-5


# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    # split entity mentions into groups
    # anchor=False: don't add the entity name to each group, simply treat it as a normal mention
    entity_sets = []
    if anchor:
        for id, mentions in entity_id_mentions.items():
            random.shuffle(mentions)
            positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
            anchor_positive = [([entity_id_name[id]] + p, id) for p in positives]
            entity_sets.extend(anchor_positive)
    else:
        for id, mentions in entity_id_mentions.items():
            group = list(set([entity_id_name[id]] + mentions))
            random.shuffle(group)
            positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
            entity_sets.extend(positives)
    return entity_sets

def batchGenerator(data, batch_size):
    for i in range(0, len(data), batch_size):
        batch = data[i:i + batch_size]
        x, y = [], []
        for t in batch:
            x.extend(t[0])
            y.extend([t[1]] * len(t[0]))
        yield x, y


with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}

# %%
###############
# alternate data import strategy
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use pattern, we use the real thing and property
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))

id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
    id2label[idx] = val
    label2id[val] = idx

df["training_id"] = df["entity_id"].map(label2id)

# %%
##############################################################
# augmentation code

# basic preprocessing
def preprocess_text(text):
    # 1. make lowercase
    text = text.lower()

    # standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()

    return text


def generate_random_shuffles(text, n):
    words = text.split()  # Split the input into words
    shuffled_variations = []

    for _ in range(n):
        shuffled = words[:]  # Copy the word list to avoid in-place modification
        random.shuffle(shuffled)  # Randomly shuffle the words
        shuffled_variations.append(" ".join(shuffled))  # Join the words back into a string

    return shuffled_variations


def shuffle_text(text, n_shuffles=SHUFFLES):
    all_processed = []
    # add the original text
    all_processed.append(text)

    # generate random shuffles
    shuffled_variations = generate_random_shuffles(text, n_shuffles)
    all_processed.extend(shuffled_variations)

    return all_processed

def corrupt_word(word):
    """Corrupt a single word using random corruption techniques."""
    if len(word) <= 1:  # Skip corruption for single-character words
        return word

    corruption_type = random.choice(["delete", "swap"])

    if corruption_type == "delete":
        # Randomly delete a character
        idx = random.randint(0, len(word) - 1)
        word = word[:idx] + word[idx + 1:]

    elif corruption_type == "swap":
        # Swap two adjacent characters
        if len(word) > 1:
            idx = random.randint(0, len(word) - 2)
            word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])

    return word

def corrupt_string(sentence, corruption_probability=0.01):
    """Corrupt each word in the string with a given probability."""
    words = sentence.split()
    corrupted_words = [
        corrupt_word(word) if random.random() < corruption_probability else word
        for word in words
    ]
    return " ".join(corrupted_words)


def create_example(index, mention, entity_name):
    return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}

# augment the whole dataset
def augment_data(df):
    output_list = []

    for idx, row in df.iterrows():
        index = row['entity_id']
        entity_name = row['entity_name']
        parent_desc = row['mention']
        parent_desc = preprocess_text(parent_desc)

        # add basic example
        output_list.append(create_example(index, parent_desc, entity_name))

        # add shuffled strings
        processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
        for desc in processed_descs:
            if desc != parent_desc:
                output_list.append(create_example(index, desc, entity_name))

        # add corrupted strings
        desc = corrupt_string(parent_desc, corruption_probability=0.01)
        if desc != parent_desc:
            output_list.append(create_example(index, desc, entity_name))

        # add example with stripped non-alphanumerics
        desc = re.sub(r'[^\w\s]', ' ', parent_desc)  # retains only alphanumerics and spaces
        if desc != parent_desc:
            output_list.append(create_example(index, desc, entity_name))

        # short sequence amplifier
        # short sequences are rare, and we must compensate by including more examples
        # also, short sequences are not usually affected by shuffling
        words = parent_desc.split()
        word_count = len(words)
        if word_count <= 2:
            for _ in range(AMPLIFY_FACTOR):
                output_list.append(create_example(index, desc, entity_name))

    new_df = pd.DataFrame(output_list)
    return new_df


# %%
def make_entity_id_mentions(df):
    entity_id_mentions = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        entity_id_mentions[entity_id] = df[df['entity_id'] == entity_id]['mention'].to_list()
    return entity_id_mentions

def make_entity_id_name(df):
    entity_id_name = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        # entity_id always matches entity_name, so the first value works
        entity_id_name[entity_id] = df[df['entity_id'] == entity_id]['entity_name'].to_list()[0]
    return entity_id_name


# %%
num_sample_per_class = 10  # samples in each group
batch_size = 16  # number of groups; effective batch size for computing triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small'  # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
bert_model = AutoModel.from_pretrained(MODEL_NAME)

class BertForClassificationAndTriplet(nn.Module):
    def __init__(self, bert_model, num_classes):
        super().__init__()
        self.bert = bert_model
        self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token
        logits = self.classifier(cls_embeddings)
        return cls_embeddings, logits

model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

model.to(DEVICE)
model.train()

losses = []

for epoch in tqdm(range(epochs)):
    total_loss = 0.0
    batch_number = 0

    augmented_df = augment_data(df)
    train_entity_id_mentions = make_entity_id_mentions(augmented_df)
    train_entity_id_name = make_entity_id_name(augmented_df)

    data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class - 1, anchor=True)
    random.shuffle(data)
    for x, y in batchGenerator(data, batch_size):

        # print(len(x), len(y), end='-->')
        optimizer.zero_grad()

        inputs = tokenizer(x, padding=True, return_tensors='pt')
        inputs.to(DEVICE)
        cls, logits = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask']
        )
        # map entity ids to contiguous training ids for the classification head
        labels = y
        labels = [label2id[element] for element in labels]
        labels = torch.tensor(labels).to(DEVICE)
        y = torch.tensor(y).to(DEVICE)

        class_loss = F.cross_entropy(logits, labels)

        # for the first half of training, use the easier batch-all triplets
        if epoch < epochs / 2:
            triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
        # for the second half of training, switch to batch-hard triplets
        else:
            triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False)

        loss = class_loss + triplet_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().item()
        batch_number += 1

        del x, y, cls, logits, loss
        torch.cuda.empty_cache()

    # scheduler.step()  # Update the learning rate
    print(f'epoch loss: {total_loss/batch_number}')
    # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
    if epoch % 5 == 0:
        # torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
        torch.save(model.state_dict(), './checkpoint/classification.pt')


# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
torch.save(model.state_dict(), './checkpoint/classification.pt')
# %%
@@ -0,0 +1,186 @@
# standard functionality for computing triplet loss, code borrowed from
# https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py
import torch
import torch.nn.functional as F

def _pairwise_distances(embeddings, squared=False):
    """Compute the 2D matrix of distances between all the embeddings.
    Args:
        embeddings: tensor of shape (batch_size, embed_dim)
        squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
                 If false, output is the pairwise euclidean distance matrix.
    Returns:
        pairwise_distances: tensor of shape (batch_size, batch_size)
    """
    dot_product = torch.matmul(embeddings, embeddings.t())

    # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
    # This also provides more numerical stability (the diagonal of the result will be exactly 0).
    # shape (batch_size,)
    square_norm = torch.diag(dot_product)

    # Compute the pairwise distance matrix as we have:
    # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
    # shape (batch_size, batch_size)
    distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)

    # Because of computation errors, some distances might be negative so we put everything >= 0.0
    distances[distances < 0] = 0

    if not squared:
        # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
        # we need to add a small epsilon where distances == 0.0
        mask = distances.eq(0).float()
        distances = distances + mask * 1e-16

        distances = (1.0 - mask) * torch.sqrt(distances)

    return distances

def _get_triplet_mask(labels):
    """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
    A triplet (i, j, k) is valid if:
        - i, j, k are distinct
        - labels[i] == labels[j] and labels[i] != labels[k]
    Args:
        labels: int `Tensor` with shape [batch_size]
    """
    # Check that i, j and k are distinct
    indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
    indices_not_equal = ~indices_equal
    i_not_equal_j = indices_not_equal.unsqueeze(2)
    i_not_equal_k = indices_not_equal.unsqueeze(1)
    j_not_equal_k = indices_not_equal.unsqueeze(0)

    distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k

    label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
    i_equal_j = label_equal.unsqueeze(2)
    i_equal_k = label_equal.unsqueeze(1)

    valid_labels = ~i_equal_k & i_equal_j

    return valid_labels & distinct_indices


def _get_anchor_positive_triplet_mask(labels):
    """Return a 2D mask where mask[a, p] is True iff a and p are distinct and have the same label.
    Args:
        labels: int `Tensor` with shape [batch_size]
    Returns:
        mask: bool `Tensor` with shape [batch_size, batch_size]
    """
    # Check that i and j are distinct
    indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
    indices_not_equal = ~indices_equal

    # Check if labels[i] == labels[j]
    # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
    labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)

    return labels_equal & indices_not_equal


def _get_anchor_negative_triplet_mask(labels):
    """Return a 2D mask where mask[a, n] is True iff a and n have distinct labels.
    Args:
        labels: int `Tensor` with shape [batch_size]
    Returns:
        mask: bool `Tensor` with shape [batch_size, batch_size]
    """
    # Check if labels[i] != labels[k]
    # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)

    return ~(labels.unsqueeze(0) == labels.unsqueeze(1))


# Cell
def batch_hard_triplet_loss(labels, embeddings, margin, squared=False):
    """Build the triplet loss over a batch of embeddings.
    For each anchor, we get the hardest positive and hardest negative to form a triplet.
    Args:
        labels: labels of the batch, of size (batch_size,)
        embeddings: tensor of shape (batch_size, embed_dim)
        margin: margin for triplet loss
        squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
                 If false, output is the pairwise euclidean distance matrix.
    Returns:
        triplet_loss: scalar tensor containing the triplet loss
    """
    # Get the pairwise distance matrix
    pairwise_dist = _pairwise_distances(embeddings, squared=squared)

    # For each anchor, get the hardest positive
    # First, we need to get a mask for every valid positive (they should have same label)
    mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()

    # We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
    anchor_positive_dist = mask_anchor_positive * pairwise_dist

    # shape (batch_size, 1)
    hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)

    # For each anchor, get the hardest negative
    # First, we need to get a mask for every valid negative (they should have different labels)
    mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()

    # We add the maximum value in each row to the invalid negatives (label(a) == label(n))
    max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
    anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)

    # shape (batch_size,)
    hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)

    # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
    tl = hardest_positive_dist - hardest_negative_dist + margin
    tl = F.relu(tl)
    triplet_loss = tl.mean()

    return triplet_loss

# Cell
def batch_all_triplet_loss(labels, embeddings, margin, squared=False):
    """Build the triplet loss over a batch of embeddings.
    We generate all the valid triplets and average the loss over the positive ones.
    Args:
        labels: labels of the batch, of size (batch_size,)
        embeddings: tensor of shape (batch_size, embed_dim)
        margin: margin for triplet loss
        squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
                 If false, output is the pairwise euclidean distance matrix.
    Returns:
        triplet_loss: scalar tensor containing the triplet loss
    """
    # Get the pairwise distance matrix
    pairwise_dist = _pairwise_distances(embeddings, squared=squared)

    anchor_positive_dist = pairwise_dist.unsqueeze(2)
    anchor_negative_dist = pairwise_dist.unsqueeze(1)

    # Compute a 3D tensor of size (batch_size, batch_size, batch_size)
    # triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k
    # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
    # and the 2nd (batch_size, 1, batch_size)
    triplet_loss = anchor_positive_dist - anchor_negative_dist + margin

    # Put to zero the invalid triplets
    # (where label(a) != label(p) or label(n) == label(a) or a == p)
    mask = _get_triplet_mask(labels)
    triplet_loss = mask.float() * triplet_loss

    # Remove negative losses (i.e. the easy triplets)
    triplet_loss = F.relu(triplet_loss)

    # Count number of positive triplets (where triplet_loss > 0)
    valid_triplets = triplet_loss[triplet_loss > 1e-16]
    num_positive_triplets = valid_triplets.size(0)
    num_valid_triplets = mask.sum()

    fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)

    # Get final mean triplet loss over the positive valid triplets
    triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)

    return triplet_loss, fraction_positive_triplets
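For reference, a minimal usage sketch of the two loss functions above; the embedding size, batch size, and label range are arbitrary illustrative values and this snippet is not part of the commit.

# minimal usage sketch, not part of the commit
import torch
from loss import batch_all_triplet_loss, batch_hard_triplet_loss

embeddings = torch.randn(32, 512)      # (batch_size, embed_dim), illustrative sizes
labels = torch.randint(0, 8, (32,))    # integer class label per embedding

# batch-all: averages over all valid triplets with positive loss and also
# reports the fraction of valid triplets that are still "hard"
loss_all, frac_positive = batch_all_triplet_loss(labels, embeddings, margin=2, squared=False)

# batch-hard: hardest positive and hardest negative per anchor
loss_hard = batch_hard_triplet_loss(labels, embeddings, margin=2, squared=False)

print(loss_all.item(), frac_positive.item(), loss_hard.item())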
[File diff suppressed because it is too large]
@@ -0,0 +1,5 @@
Top-1 accuracy: 0.8257482574825749
Top-3 accuracy: 0.9106191061910619
Top-5 accuracy: 0.9261992619926199
Top-10 accuracy: 0.942189421894219
0.0 0.82121104
@@ -0,0 +1,227 @@
# this code performs dataloading with text augmentation

# %%
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch
from transformers import (
    AutoTokenizer,
)
from functools import partial
import re
import random

# %%
# PARAMETERS
SAMPLES = 5

# %%
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use pattern, we use the real thing and property
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))

id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
    id2label[idx] = val
    label2id[val] = idx

df["training_id"] = df["entity_id"].map(label2id)

# %%
##############################################################
# augmentation code

# basic preprocessing
def preprocess_text(text):
    # 1. make lowercase
    text = text.lower()

    # standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()

    return text


def shuffle_text(text, prob=0.2):
    if random.random() < prob:
        words = text.split()  # Split the input into words
        shuffled = words[:]  # Copy the word list to avoid in-place modification
        random.shuffle(shuffled)  # Randomly shuffle the words
        shuffled_text = " ".join(shuffled)  # Join the words back into a string
    else:
        shuffled_text = text

    return shuffled_text


def corrupt_word(word):
    """Corrupt a single word using random corruption techniques."""
    if len(word) <= 1:  # Skip corruption for single-character words
        return word

    corruption_type = random.choice(["delete", "swap"])

    if corruption_type == "delete":
        # Randomly delete a character
        idx = random.randint(0, len(word) - 1)
        word = word[:idx] + word[idx + 1:]

    elif corruption_type == "swap":
        # Swap two adjacent characters
        if len(word) > 1:
            idx = random.randint(0, len(word) - 2)
            word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])

    return word

def corrupt_text(sentence, prob=0.05):
    """Corrupt each word in the string with a given probability."""
    words = sentence.split()
    corrupted_words = [
        corrupt_word(word) if random.random() < prob else word
        for word in words
    ]
    return " ".join(corrupted_words)

def strip_nonalphanumerics(desc, prob=0.5):
    # note: prob is accepted but currently unused; stripping is always applied
    desc = re.sub(r'[^\w\s]', ' ', desc)  # retains only alphanumerics and spaces
    return desc


# %%
def augment(row):
    """
    Augment the "mention" string input.
    Returns the string input with slight variation.
    """
    desc = row['mention']
    # we always apply preprocessing
    desc = preprocess_text(desc)

    desc = shuffle_text(desc, prob=0.5)
    desc = corrupt_text(desc, prob=0.5)
    desc = strip_nonalphanumerics(desc, prob=0.5)

    return desc


# %%
# custom dataset
# we want to sample n samples from each class
# sample_size refers to the number of samples per class
def sample_from_df(df, sample_size_per_class=5):
    sampled_df = (df.groupby("training_id")[['training_id', 'mention']]  # explicitly give column names
                    .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x))))
                    .reset_index(drop=True))

    return sampled_df


# %%
class DynamicDataset(Dataset):
    def __init__(self, df, sample_size_per_class):
        """
        Args:
            df (pd.DataFrame): Original DataFrame with class (id) and data columns.
            sample_size_per_class (int): Number of samples to draw per class for each epoch.
        """
        self.df = df
        self.sample_size_per_class = sample_size_per_class
        self.current_data = None
        self.regenerate_data()  # Generate the initial dataset

    def regenerate_data(self):
        """
        Generate a new sampled dataset for the current epoch.

        Each call to this method updates self.current_data so that we can:

        - re-sample the dataframe for a new set of n samples per class
        - generate fresh augmentations

        This allows us to re-sample and re-augment at the start of each epoch.
        """
        # Sample `sample_size_per_class` rows per class
        sampled_df = sample_from_df(self.df, self.sample_size_per_class)

        # Store the sampled data; tokenization happens later in the collate function
        self.current_data = sampled_df

    def __len__(self):
        return len(self.current_data)

    def __getitem__(self, idx):
        # do the transform here
        row = self.current_data.iloc[idx].to_dict()

        # perform text augmentation here
        # independent calls produce independently augmented variations of the same mention
        mention_0 = augment(row)
        mention_1 = augment(row)
        return {
            'training_id': row['training_id'],
            'mention_0': mention_0,
            'mention_1': mention_1,
        }


# %%
dataset = DynamicDataset(df, sample_size_per_class=SAMPLES)
dataset[0]


# %%
def custom_collate_fn(batch, tokenizer):
    # batch is just a list of dictionaries
    label_list = [item['training_id'] for item in batch]
    mention_0_list = [item['mention_0'] for item in batch]
    mention_1_list = [item['mention_1'] for item in batch]

    # we can do the tokenization here
    tokenized_batch_0 = tokenizer(
        mention_0_list,
        truncation=True,
        padding=True,
        return_tensors='pt'
    )

    tokenized_batch_1 = tokenizer(
        mention_1_list,
        truncation=True,
        padding=True,
        return_tensors='pt'
    )

    label_list = torch.tensor(label_list)

    return {
        'input_ids_0': tokenized_batch_0['input_ids'],
        'attention_mask_0': tokenized_batch_0['attention_mask'],
        'input_ids_1': tokenized_batch_1['input_ids'],
        'attention_mask_1': tokenized_batch_1['attention_mask'],
        'labels': label_list,
    }

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", clean_up_tokenization_spaces=False)
custom_collate_fn_with_tokenizer = partial(custom_collate_fn, tokenizer=tokenizer)
dataloader = DataLoader(
    dataset,
    batch_size=8,
    collate_fn=custom_collate_fn_with_tokenizer
)


# %%
next(iter(dataloader))
# %%
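Since DynamicDataset.regenerate_data() re-samples and re-augments, a training loop would typically call it at the start of every epoch. A minimal sketch under that assumption follows; the epoch count and the loss computation are placeholders and are not part of this commit.

# minimal usage sketch, not part of the commit: re-sample and re-augment each epoch
for epoch in range(10):            # epoch count chosen arbitrarily
    dataset.regenerate_data()      # fresh per-class sample + fresh augmentations
    for batch in dataloader:
        # each batch carries two independently augmented views per mention plus the class label,
        # which suits a contrastive or triplet-style objective
        input_ids_0 = batch['input_ids_0']
        input_ids_1 = batch['input_ids_1']
        labels = batch['labels']
        # ... forward pass and loss computation would go here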
@@ -0,0 +1,12 @@
Top-1 accuracy: 0.7107843137254902
Top-3 accuracy: 0.7990196078431373
Top-5 accuracy: 0.8137254901960784
Top-10 accuracy: 0.8529411764705882
Top-1 accuracy: 0.6862745098039216
Top-3 accuracy: 0.7696078431372549
Top-5 accuracy: 0.7990196078431373
Top-10 accuracy: 0.8480392156862745
Top-1 accuracy: 0.696078431372549
Top-3 accuracy: 0.7892156862745098
Top-5 accuracy: 0.8088235294117647
Top-10 accuracy: 0.8382352941176471
@ -14,11 +14,10 @@ from transformers import AutoTokenizer, AutoModel
 from data import generate_train_entity_sets
-from tqdm import tqdm ### need to use ipywidgets==7.7.1 the newest version doesn't work
+from tqdm import tqdm
 from loss import batch_all_triplet_loss, batch_hard_triplet_loss
 from sklearn.neighbors import KNeighborsClassifier
 import numpy as np
-import logging


 def setup(rank, world_size):
@ -73,13 +72,10 @@ def train(rank, epoch, epochs, train_dataloader, model, optimizer, tokenizer, ma
         loss.backward()
         optimizer.step()
-        # logging.info(f'{epoch} {len(x)} {loss.item()}')
         epoch_loss.append(loss.item())
         epoch_len.append(len(x))
         # del inputs, cls, loss
         # torch.cuda.empty_cache()
-    logging.info(f'{DEVICE}{epoch_len}')
-    logging.info(f'{DEVICE}{epoch_loss}')


 def check_label(predicted_cui: str, golden_cui: str) -> int:
     """
@ -127,8 +123,7 @@ def eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, m
     # print(np.min(distances), np.max(distances))


 def save_checkpoint(model, res, epoch, dataName):
-    logging.info(f'Saving model {epoch} {res} ')
-    torch.save(model.state_dict(), './checkpoints/'+dataName+'.pt')
+    torch.save(model.state_dict(), './checkpoint/' + dataName + '.pt')


 class Model(nn.Module):
     def __init__(self,MODEL_NAME):
@ -146,12 +141,12 @@ def main(rank, world_size, config):
     setup(rank, world_size)

     dataName = config['DEFAULT']['dataName']
-    logging.basicConfig(format='%(asctime)s %(message)s', filename=config['train']['ckt_path']+dataName+'.log', filemode='a', level=logging.INFO)

     vocab = defaultdict(set)
-    with open('./data/biomedical/'+dataName+'/'+config['train']['dictionary']) as f:
+    with open('../biomedical/' + dataName + '/' + config['train']['dictionary']) as f:
         for line in f:
-            vocab[line.strip().split('||')[0]].add(line.strip().split('||')[1].lower())
+            line_list = line.strip().split('||')
+            vocab[line_list[0]].add(line_list[1].lower())

     cui_to_id, id_to_cui = {}, {}
     vocab_entity_id_mentions = {}
@ -167,7 +162,7 @@ def main(rank, world_size, config):
         vocab_ids.extend([id]*len(mentions))

     test_mentions, test_cuis = [], []
-    with open('./data/biomedical/'+dataName+'/'+config['train']['test_set']+'/0.concept') as f:
+    with open('../biomedical/'+dataName+'/'+config['train']['test_set']+'/0.concept') as f:
         for line in f:
             test_cuis.append(line.strip().split('||')[-1])
             test_mentions.append(line.strip().split('||')[-2].lower())
@ -188,8 +183,7 @@ def main(rank, world_size, config):
     ddp_model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)

     best = 0
-    if rank == 0:
-        logging.info(f'epochs:{epochs} group_size:{num_sample_per_class} batch_size:{batch_size} %num:1 device:{torch.cuda.get_device_name()} count:{torch.cuda.device_count()} base:{MODEL_NAME}' )
+    best_res = []

     for epoch in tqdm(range(epochs)):
         trainDataLoader.sampler.set_epoch(epoch)
@ -197,11 +191,18 @@ def main(rank, world_size, config):
         # if rank == 0 and epoch % 2 == 0:
         if rank == 0:
             res = eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, ddp_model.module, tokenizer)
-            logging.info(f'{epoch} {res}')
             if res[0] > best:
                 best = res[0]
+                best_res = res
                 save_checkpoint(ddp_model.module, res, epoch, dataName)
+                with open("biomedical_results/output.txt", "a") as f:
+                    print('new best ----', file=f)
+                    for idx,n in enumerate([1,3,5,10]):
+                        print(f'Top-{n:<3} accuracy: {best_res[idx]}', file=f)

         dist.barrier()


     cleanup()

 if __name__ == '__main__':
@ -0,0 +1,27 @@
[DEFAULT]
dataName = ncbi
# dataName = bc5cdr-disease
# dataName = bc5cdr-chemical
# dataName = bc2gm

[train]
epochs = 2
batch_size = 40
lr = 1e-5

dictionary = test_dictionary.txt
# or processed_dev_oneFile Note: bc2gm has no dev split
test_set = processed_test_refined
ckt_path = ./checkpoint/


[model]
margin = 2
model_name = dmis-lab/biobert-v1.1

[eval]
batch_size = 200

[data]
group_size = 4
shuffle = True
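# Not part of the commit: a minimal sketch of how this config maps to the lookups used in
# main() (e.g. config['DEFAULT']['dataName'], config['train']['test_set']). The file name
# 'config.ini' is an assumption for illustration.
import configparser

config = configparser.ConfigParser()
config.read('config.ini')

dataName = config['DEFAULT']['dataName']            # 'ncbi'
epochs = config.getint('train', 'epochs')           # 2
batch_size = config.getint('train', 'batch_size')   # 40
lr = config.getfloat('train', 'lr')                 # 1e-05
test_set = config['train']['test_set']              # 'processed_test_refined'
margin = config.getfloat('model', 'margin')         # 2.0
print(dataName, epochs, batch_size, lr, test_set, margin)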
@ -0,0 +1,84 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm


# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    # split entity mentions into groups
    entity_sets = []
    # anchor means to use the entity_name as the anchor of each group
    if anchor:
        # entity_id_mentions maps each entity_id to its list of mentions
        for id, mentions in entity_id_mentions.items():
            # shuffle the mentions
            random.shuffle(mentions)
            # split the mentions into chunks of at most group_size
            positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
            # the entity_name is prepended as the first element of every group,
            # which is why the caller passes (num_sample_per_class - 1) as group_size
            anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
            entity_sets.extend(anchor_positive)
    else:
        # here the entity_name is mixed into the pool instead of anchoring each group
        for id, mentions in entity_id_mentions.items():
            group = list(set([entity_id_name[id]] + mentions))
            random.shuffle(group)
            positives = [(group[i:i + group_size], id) for i in range(0, len(group), group_size)]
            entity_sets.extend(positives)
    return entity_sets


# the batch generator takes batch_size entries from `data` at a time;
# each entry is a (mentions, id) tuple holding at most num_sample_per_class mentions
def batchGenerator(data, batch_size):
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        x, y = [], []
        for t in batch:
            x.extend(t[0])  # the list of mentions
            y.extend([t[1]]*len(t[0]))  # repeat the single label once per mention
        yield x, y


# %%
with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}

# %%
num_sample_per_class = 10  # samples in each group
batch_size = 16  # number of groups; effective batch size for the triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small'  # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

model.to(DEVICE)
model.train()

losses = []

data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)

# %%
random.shuffle(data)
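# Not part of the commit: a toy illustration of the structures produced above. The
# dictionaries below are hypothetical and only show the shape of the output.
toy_mentions = {1: ['postgres', 'pg', 'postgresql db'], 2: ['redis', 'redis cache']}
toy_names = {1: 'PostgreSQL', 2: 'Redis'}

toy_sets = generate_train_entity_sets(toy_mentions, toy_names, group_size=2, anchor=True)
# each entry is ([entity_name, mention, ...], entity_id), e.g.
# (['PostgreSQL', 'pg', 'postgres'], 1) -- mention order varies because of the shuffle

for x, y in batchGenerator(toy_sets, batch_size=2):
    # x is a flat list of strings, y repeats each entity_id once per mention
    print(x, y)
    break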
@ -0,0 +1,227 @@
# this code performs dataloading with text augmentation

# %%
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch
from transformers import (
    AutoTokenizer,
)
from functools import partial
import re
import random

# %%
# PARAMETERS
SAMPLES = 5

# %%
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use pattern, we use the real thing and property
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))

id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
    id2label[idx] = val
    label2id[val] = idx

df["training_id"] = df["entity_id"].map(label2id)

# %%
##############################################################
# augmentation code

# basic preprocessing
def preprocess_text(text):
    # 1. make everything lowercase
    text = text.lower()

    # standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()

    return text


def shuffle_text(text, prob=0.2):
    if random.random() < prob:
        words = text.split()  # Split the input into words
        shuffled = words[:]  # Copy the word list to avoid in-place modification
        random.shuffle(shuffled)  # Randomly shuffle the words
        shuffled_text = " ".join(shuffled)  # Join the words back into a string
    else:
        shuffled_text = text

    return shuffled_text


def corrupt_word(word):
    """Corrupt a single word using random corruption techniques."""
    if len(word) <= 1:  # Skip corruption for single-character words
        return word

    corruption_type = random.choice(["delete", "swap"])

    if corruption_type == "delete":
        # Randomly delete a character
        idx = random.randint(0, len(word) - 1)
        word = word[:idx] + word[idx + 1:]

    elif corruption_type == "swap":
        # Swap two adjacent characters
        if len(word) > 1:
            idx = random.randint(0, len(word) - 2)
            word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])

    return word


def corrupt_text(sentence, prob=0.05):
    """Corrupt each word in the string with a given probability."""
    words = sentence.split()
    corrupted_words = [
        corrupt_word(word) if random.random() < prob else word
        for word in words
    ]
    return " ".join(corrupted_words)


def strip_nonalphanumerics(desc, prob=0.5):
    if random.random() < prob:
        desc = re.sub(r'[^\w\s]', ' ', desc)  # retain only alphanumerics and spaces
    return desc


# %%
def augment(row):
    """
    Augment the "mention" string of a row.

    Returns the string with a slight random variation.
    """
    desc = row['mention']
    # we always apply preprocessing
    desc = preprocess_text(desc)

    desc = shuffle_text(desc, prob=1.0)
    desc = corrupt_text(desc, prob=1.0)
    desc = strip_nonalphanumerics(desc, prob=0.5)

    return desc
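# Not part of the commit: augment() is stochastic, so each call returns a different variant.
# The row below is hypothetical and the outputs in the comments are only illustrative.
sample_row = {'mention': 'Microsoft SQL Server 2019', 'training_id': 0}
print(augment(sample_row))  # e.g. 'server 2019 sql micosoft'
print(augment(sample_row))  # e.g. '2019 microsoft sq server'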


# %%
# custom dataset
# we want to sample n samples from each class
# sample_size_per_class refers to the number of samples per class
def sample_from_df(df, sample_size_per_class=5):
    sampled_df = (df.groupby("training_id")[['training_id', 'mention']]  # explicitly select the columns
                    .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x))))
                    .reset_index(drop=True))

    return sampled_df


# %%
class DynamicDataset(Dataset):
    def __init__(self, df, sample_size_per_class):
        """
        Args:
            df (pd.DataFrame): Original DataFrame with class (id) and data columns.
            sample_size_per_class (int): Number of samples to draw per class for each epoch.
        """
        self.df = df
        self.sample_size_per_class = sample_size_per_class
        self.current_data = None
        self.regenerate_data()  # Generate the initial dataset

    def regenerate_data(self):
        """
        Generate a new sampled dataset for the current epoch.

        Each call to this method updates current_data, so that we can:

        - re-sample the dataframe for a new set of n samples per class
        - generate fresh augmentations

        This allows us to re-sample and re-augment at the start of each epoch.
        """
        # Sample `sample_size_per_class` rows per class
        sampled_df = sample_from_df(self.df, self.sample_size_per_class)

        # Store the freshly sampled rows together with their labels
        self.current_data = sampled_df

    def __len__(self):
        return len(self.current_data)

    def __getitem__(self, idx):
        # perform the transform here
        row = self.current_data.iloc[idx].to_dict()

        # perform text augmentation here
        # independent calls return two different random augmentations of the same mention
        mention_0 = augment(row)
        mention_1 = augment(row)
        return {
            'training_id': row['training_id'],
            'mention_0': mention_0,
            'mention_1': mention_1,
        }


# %%
dataset = DynamicDataset(df, sample_size_per_class=SAMPLES)
dataset[0]


# %%
def custom_collate_fn(batch, tokenizer):
    # batch is just a list of dictionaries
    label_list = [item['training_id'] for item in batch]
    mention_0_list = [item['mention_0'] for item in batch]
    mention_1_list = [item['mention_1'] for item in batch]

    # tokenize both augmented views here
    tokenized_batch_0 = tokenizer(
        mention_0_list,
        truncation=True,
        padding=True,
        return_tensors='pt'
    )

    tokenized_batch_1 = tokenizer(
        mention_1_list,
        truncation=True,
        padding=True,
        return_tensors='pt'
    )

    label_list = torch.tensor(label_list)

    return {
        'input_ids_0': tokenized_batch_0['input_ids'],
        'attention_mask_0': tokenized_batch_0['attention_mask'],
        'input_ids_1': tokenized_batch_1['input_ids'],
        'attention_mask_1': tokenized_batch_1['attention_mask'],
        'labels': label_list,
    }


tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", clean_up_tokenization_spaces=False)
custom_collate_fn_with_tokenizer = partial(custom_collate_fn, tokenizer=tokenizer)
dataloader = DataLoader(
    dataset,
    batch_size=8,
    collate_fn=custom_collate_fn_with_tokenizer
)


# %%
next(iter(dataloader))

# %%
@ -0,0 +1,5 @@
Top-1 accuracy: 0.01845018450184502
Top-3 accuracy: 0.02870028700287003
Top-5 accuracy: 0.03936039360393604
Top-10 accuracy: 0.08200082000820008
0.0 0.7957323
@ -0,0 +1,387 @@
# %%
# %%
from torch.utils.data import Dataset, DataLoader

# from datasets import load_from_disk
import os
import json

os.environ['NCCL_P2P_DISABLE'] = '1'
os.environ['NCCL_IB_DISABLE'] = '1'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

from dataclasses import dataclass
import re
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModel,
    DataCollatorWithPadding,
    Trainer,
    EarlyStoppingCallback,
    TrainingArguments,
    TrainerCallback
)
import evaluate
import numpy as np
import pandas as pd
from functools import partial
import warnings

from tqdm import tqdm

from dataload import DynamicDataset, custom_collate_fn

from sklearn.neighbors import KNeighborsClassifier

torch.set_float32_matmul_precision('high')


def set_seed(seed):
    """
    Set the random seed for reproducibility.
    """
    random.seed(seed)  # Python random module
    np.random.seed(seed)  # NumPy random
    torch.manual_seed(seed)  # PyTorch CPU
    torch.cuda.manual_seed(seed)  # PyTorch GPU
    torch.cuda.manual_seed_all(seed)  # If using multiple GPUs
    torch.backends.cudnn.deterministic = True  # Ensure deterministic behavior
    torch.backends.cudnn.benchmark = False  # Disable optimization for reproducibility


set_seed(42)

SAMPLES = 40
BATCH_SIZE = 256

# %%
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use pattern, we use the real thing and property
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))

id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
    id2label[idx] = val
    label2id[val] = idx

df["training_id"] = df["entity_id"].map(label2id)


# %%
# make our dataset and dataloader
# MODEL_NAME = "distilbert/distilbert-base-uncased"
MODEL_NAME = 'prajjwal1/bert-small'  # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, clean_up_tokenization_spaces=False)
dataset = DynamicDataset(df, sample_size_per_class=SAMPLES)
custom_collate_fn_with_tokenizer = partial(custom_collate_fn, tokenizer=tokenizer)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=custom_collate_fn_with_tokenizer
)

# %%
# enable BERT with projection layer

class VICRegProjection(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(VICRegProjection, self).__init__()
        self.projection = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.projection(x)


class BertWithVICReg(nn.Module):
    def __init__(self, bert_model, projection_dim=256):
        super(BertWithVICReg, self).__init__()
        self.bert = bert_model
        hidden_size = bert_model.config.hidden_size
        self.projection = VICRegProjection(input_dim=hidden_size, hidden_dim=hidden_size, output_dim=projection_dim)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        projected_embeddings = self.projection(pooled_output)
        return projected_embeddings

#######################################################
# %%
@dataclass
class Hyperparameters:
    loss_constant_factor: float = 1
    invariance_loss_weight: float = 25.0
    variance_loss_weight: float = 25.0
    covariance_loss_weight: float = 1.0
    variance_loss_epsilon: float = 1e-5


# compute vicreg loss
def get_vicreg_loss(z_a, z_b, hparams):
    assert z_a.shape == z_b.shape and len(z_a.shape) == 2

    # invariance loss
    loss_inv = F.mse_loss(z_a, z_b)

    # variance loss
    std_z_a = torch.sqrt(z_a.var(dim=0) + hparams.variance_loss_epsilon)
    std_z_b = torch.sqrt(z_b.var(dim=0) + hparams.variance_loss_epsilon)
    loss_v_a = torch.mean(F.relu(1 - std_z_a))  # differentiable max
    loss_v_b = torch.mean(F.relu(1 - std_z_b))
    loss_var = loss_v_a + loss_v_b

    # covariance loss
    N, D = z_a.shape
    z_a = z_a - z_a.mean(dim=0)
    z_b = z_b - z_b.mean(dim=0)
    cov_z_a = ((z_a.T @ z_a) / (N - 1)).square()  # DxD
    cov_z_b = ((z_b.T @ z_b) / (N - 1)).square()  # DxD
    loss_c_a = (cov_z_a.sum() - cov_z_a.diagonal().sum()) / D
    loss_c_b = (cov_z_b.sum() - cov_z_b.diagonal().sum()) / D
    loss_cov = loss_c_a + loss_c_b

    weighted_inv = loss_inv * hparams.invariance_loss_weight
    weighted_var = loss_var * hparams.variance_loss_weight
    weighted_cov = loss_cov * hparams.covariance_loss_weight

    loss = weighted_inv + weighted_var + weighted_cov

    return loss
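# Not part of the commit: a quick sanity check of get_vicreg_loss on random embeddings.
# Shapes and numbers are illustrative only; the default Hyperparameters above are used.
_z_a = torch.randn(32, 128)
_z_b = _z_a + 0.1 * torch.randn(32, 128)  # a slightly perturbed "second view"
_loss = get_vicreg_loss(_z_a, _z_b, hparams=Hyperparameters())
print(_loss.item())  # a scalar; the invariance term is small because the two views are close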


# %%
#
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
bert_model = AutoModel.from_pretrained(MODEL_NAME)
# bert_hidden_size = bert_model.config.hidden_size
# projection_model = VICRegProjection(
#     input_dim=bert_hidden_size,
#     hidden_dim=bert_hidden_size,
#     output_dim=256
# )
# need to allocate individual components of the model
# bert_model.to(DEVICE)
# projection_model.to(DEVICE)
model = BertWithVICReg(bert_model, projection_dim=128)
model.to(DEVICE)

# params = list(bert_model.parameters()) + list(projection_model.parameters())
params = model.parameters()
optimizer = torch.optim.AdamW(params, lr=5e-6)
hparams = Hyperparameters()

losses = []

# # %%
# batch = next(iter(dataloader))
# input_ids_0 = batch['input_ids_0'].to(DEVICE)
# attention_mask_0 = batch['attention_mask_0'].to(DEVICE)
#
# # %%
# # outputs from the projection layer
# bert_output = model(
#     input_ids=input_ids_0,
#     attention_mask=attention_mask_0
# )


# %%
# parameters
epochs = 80

for epoch in tqdm(range(epochs)):
    dataset.regenerate_data()
    for batch in dataloader:
        optimizer.zero_grad()

        # compute cls 0
        input_ids_0 = batch['input_ids_0'].to(DEVICE)
        attention_mask_0 = batch['attention_mask_0'].to(DEVICE)
        # outputs from the projection layer
        outputs_0 = model(
            input_ids=input_ids_0,
            attention_mask=attention_mask_0
        )

        # compute cls 1
        input_ids_1 = batch['input_ids_1'].to(DEVICE)
        attention_mask_1 = batch['attention_mask_1'].to(DEVICE)
        # outputs from the projection layer
        outputs_1 = model(
            input_ids=input_ids_1,
            attention_mask=attention_mask_1
        )

        loss = get_vicreg_loss(outputs_0, outputs_1, hparams=hparams)

        loss.backward()
        optimizer.step()
        # print(epoch, loss)
        losses.append(loss.detach().item())  # store a float, not the graph-attached tensor
        torch.cuda.empty_cache()

    print(loss.detach().item())


# %%
torch.save(model.state_dict(), './checkpoint/simple.pt')

####################################################
# %%
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
state_dict = torch.load('./checkpoint/simple.pt')
model = BertWithVICReg(bert_model, projection_dim=256)
model.load_state_dict(state_dict)

# %%
# Step 2: Load the state dictionary
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# state_dict = torch.load('./checkpoint/siamese.pt')
# model = model.bert
# %%

# Step 3: Apply the state dictionary to the model
model.to(DEVICE)
model.eval()

# %%
with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}


# %%
def preprocess_text(text):
    # 1. make everything lowercase
    text = text.lower()

    # standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()

    return text


with open('../esAppMod/infer.json', 'r') as file:
    test = json.load(file)
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
y_test = [d['entity_id'] for _, d in test['data'].items()]
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
train_entities = [preprocess_text(element) for element in train_entities]


def batch_list(data, batch_size):
    """Yield successive n-sized chunks from data."""
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]


batches = batch_list(train_entities, 64)

embedding_list = []
for batch in batches:
    inputs = tokenizer(batch, padding=True, return_tensors='pt')
    outputs = model(
        input_ids=inputs['input_ids'].to(DEVICE),
        attention_mask=inputs['attention_mask'].to(DEVICE)
    )
    # output = outputs.last_hidden_state[:,0,:]
    outputs = outputs.detach().cpu().numpy()
    embedding_list.append(outputs)

cls = np.concatenate(embedding_list)
# %%
torch.cuda.empty_cache()

# %%
batches = batch_list(x_test, 64)

embedding_list = []
for batch in batches:
    inputs = tokenizer(batch, padding=True, return_tensors='pt')
    outputs = model(
        input_ids=inputs['input_ids'].to(DEVICE),
        attention_mask=inputs['attention_mask'].to(DEVICE)
    )
    # output = outputs.last_hidden_state[:,0,:]
    outputs = outputs.detach().cpu().numpy()
    embedding_list.append(outputs)

cls_test = np.concatenate(embedding_list)


# %%
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
n_neighbors = [1, 3, 5, 10]

for n in n_neighbors:
    distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
    num = 0
    for a, b in zip(y_test, indices):
        b = [labels[i] for i in b]
        if a in b:
            num += 1
    print(f'Top-{n:<3} accuracy: {num / len(y_test)}')
print(np.min(distances), np.max(distances))


# with open("results/output.txt", "w") as f:
#     for n in n_neighbors:
#         distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
#         num = 0
#         for a,b in zip(y_test, indices):
#             b = [labels[i] for i in b]
#             if a in b:
#                 num += 1
#         print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
#         print(np.min(distances), np.max(distances), file=f)

# %%
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# %%
# Reduce dimensions with t-SNE
tsne = TSNE(n_components=2, random_state=42)
embeddings = cls
embeddings_reduced = tsne.fit_transform(embeddings)

plt.figure(figsize=(10, 8))
scatter = plt.scatter(embeddings_reduced[:, 0], embeddings_reduced[:, 1], c=labels, cmap='viridis', alpha=0.6)
plt.colorbar(scatter)
plt.xlabel('Component 1')
plt.ylabel('Component 2')
plt.title('Visualization of Embeddings')
plt.show()

# %%