triplet loss with classification as a regularizer
- best record of 82.57 for esAppMod
This commit is contained in:
parent
94ee7beba7
commit
ac340f6fd2
|
@ -0,0 +1,2 @@
|
|||
__pycache__
|
||||
checkpoint
|
|
@ -0,0 +1,256 @@
|
|||
# %%
|
||||
|
||||
# from datasets import load_from_disk
|
||||
import os
|
||||
import glob
|
||||
|
||||
os.environ['NCCL_P2P_DISABLE'] = '1'
|
||||
os.environ['NCCL_IB_DISABLE'] = '1'
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
|
||||
|
||||
import re
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModel,
|
||||
DataCollatorWithPadding,
|
||||
)
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
# import matplotlib.pyplot as plt
|
||||
from datasets import Dataset, DatasetDict
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
|
||||
BATCH_SIZE = 256
|
||||
|
||||
# %%
|
||||
# construct the target id list
|
||||
# data_path = '../../../esAppMod_data_import/train.csv'
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
train_df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than use pattern, we use the real thing and property
|
||||
entity_ids = train_df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
|
||||
# %%
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
|
||||
# introduce pre-processing functions
|
||||
def preprocess_text(text):
|
||||
# 1. Make all uppercase
|
||||
text = text.lower()
|
||||
|
||||
# Substitute digits with '#'
|
||||
# text = re.sub(r'\d+', '#', text)
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
|
||||
|
||||
# outputs a list of dictionaries
|
||||
# processes dataframe into lists of dictionaries
|
||||
# each element maps input to output
|
||||
# input: tag_description
|
||||
# output: class label
|
||||
def process_df_to_dict(df):
|
||||
output_list = []
|
||||
for _, row in df.iterrows():
|
||||
desc = row['mention']
|
||||
desc = preprocess_text(desc)
|
||||
index = row['entity_id']
|
||||
element = {
|
||||
'text' : desc,
|
||||
'label': label2id[index], # ensure labels starts from 0
|
||||
}
|
||||
output_list.append(element)
|
||||
|
||||
return output_list
|
||||
|
||||
|
||||
def create_dataset():
|
||||
# train
|
||||
data_path = '../esAppMod_data_import/test.csv'
|
||||
test_df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
|
||||
|
||||
# combined_data = DatasetDict({
|
||||
# 'train': Dataset.from_list(process_df_to_dict(train_df)),
|
||||
# })
|
||||
return Dataset.from_list(process_df_to_dict(test_df))
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
def test():
|
||||
|
||||
test_dataset = create_dataset()
|
||||
|
||||
# prepare tokenizer
|
||||
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, return_tensors="pt", clean_up_tokenization_spaces=True)
|
||||
# Define additional special tokens
|
||||
# additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
|
||||
# Add the additional special tokens to the tokenizer
|
||||
# tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
|
||||
|
||||
# %%
|
||||
# compute max token length
|
||||
max_length = 0
|
||||
for sample in test_dataset['text']:
|
||||
# Tokenize the sample and get the length
|
||||
input_ids = tokenizer(sample, truncation=False, add_special_tokens=True)["input_ids"]
|
||||
length = len(input_ids)
|
||||
|
||||
# Update max_length if this sample is longer
|
||||
if length > max_length:
|
||||
max_length = length
|
||||
|
||||
print(max_length)
|
||||
|
||||
# %%
|
||||
|
||||
max_length = 128
|
||||
|
||||
# given a dataset entry, run it through the tokenizer
|
||||
def preprocess_function(example):
|
||||
input = example['text']
|
||||
# text_target sets the corresponding label to inputs
|
||||
# there is no need to create a separate 'labels'
|
||||
model_inputs = tokenizer(
|
||||
input,
|
||||
# max_length=max_length,
|
||||
# padding='max_length'
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
# map maps function to each "row" in the dataset
|
||||
# aka the data in the immediate nesting
|
||||
datasets = test_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=8,
|
||||
remove_columns="text",
|
||||
)
|
||||
|
||||
|
||||
datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
|
||||
|
||||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
bert_model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
class BertForClassificationAndTriplet(nn.Module):
|
||||
def __init__(self, bert_model, num_classes):
|
||||
super().__init__()
|
||||
self.bert = bert_model
|
||||
self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None):
|
||||
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
||||
cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token
|
||||
logits = self.classifier(cls_embeddings)
|
||||
return cls_embeddings, logits
|
||||
|
||||
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
|
||||
state_dict = torch.load('./checkpoint/classification.pt')
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
model = model.eval()
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
model.to(device)
|
||||
|
||||
pred_labels = []
|
||||
actual_labels = []
|
||||
|
||||
|
||||
dataloader = DataLoader(datasets, batch_size=BATCH_SIZE, shuffle=False, collate_fn=data_collator)
|
||||
for batch in tqdm(dataloader):
|
||||
# Inference in batches
|
||||
input_ids = batch['input_ids']
|
||||
attention_mask = batch['attention_mask']
|
||||
# save labels too
|
||||
actual_labels.extend(batch['labels'])
|
||||
|
||||
|
||||
# Move to GPU if available
|
||||
input_ids = input_ids.to(device)
|
||||
attention_mask = attention_mask.to(device)
|
||||
|
||||
# Perform inference
|
||||
with torch.no_grad():
|
||||
cls, logits = model(
|
||||
input_ids,
|
||||
attention_mask)
|
||||
predicted_class_ids = logits.argmax(dim=1).to("cpu")
|
||||
pred_labels.extend(predicted_class_ids)
|
||||
|
||||
pred_labels = [tensor.item() for tensor in pred_labels]
|
||||
|
||||
|
||||
# %%
|
||||
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
|
||||
y_true = actual_labels
|
||||
y_pred = pred_labels
|
||||
|
||||
# Compute metrics
|
||||
accuracy = accuracy_score(y_true, y_pred)
|
||||
average_parameter = 'weighted'
|
||||
zero_division_parameter = 0
|
||||
f1 = f1_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter)
|
||||
precision = precision_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter)
|
||||
recall = recall_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter)
|
||||
|
||||
with open("results/output.txt", "a") as f:
|
||||
|
||||
print('*' * 80, file=f)
|
||||
# Print the results
|
||||
print(f'Accuracy: {accuracy:.5f}', file=f)
|
||||
print(f'F1 Score: {f1:.5f}', file=f)
|
||||
print(f'Precision: {precision:.5f}', file=f)
|
||||
print(f'Recall: {recall:.5f}', file=f)
|
||||
|
||||
# export result
|
||||
label_list = [id2label[id] for id in pred_labels]
|
||||
df = pd.DataFrame({
|
||||
'class_prediction': pd.Series(label_list)
|
||||
})
|
||||
|
||||
# we can save the t5 generation output here
|
||||
df.to_csv(f"results/classify.csv", index=False)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
# reset file before writing to it
|
||||
with open("results/output.txt", "w") as f:
|
||||
print('', file=f)
|
||||
test()
|
|
@ -0,0 +1,124 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
import gc
|
||||
|
||||
# %%
|
||||
# Step 2: Load the state dictionary
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
# state_dict = torch.load('./checkpoint/siamese.pt')
|
||||
# state_dict = torch.load('./checkpoint/siamese_simple.pt')
|
||||
state_dict = torch.load('./checkpoint/classification.pt')
|
||||
params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name}
|
||||
|
||||
# %%
|
||||
# Step 3: Apply the state dictionary to the model
|
||||
model.load_state_dict(params_dict)
|
||||
model.to(DEVICE)
|
||||
model.eval()
|
||||
|
||||
# %%
|
||||
def preprocess_text(text):
|
||||
# 1. Make all uppercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in test['data'].items()]
|
||||
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
|
||||
train_entities = [preprocess_text(element) for element in train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
# %%
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# %%
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
|
||||
# %%
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
|
||||
n_neighbors = [1, 3, 5, 10]
|
||||
|
||||
|
||||
with open("results/output.txt", "w") as f:
|
||||
for n in n_neighbors:
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
|
||||
print(np.min(distances), np.max(distances), file=f)
|
||||
|
||||
# %%
|
|
@ -0,0 +1,276 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
|
||||
|
||||
# %%
|
||||
SHUFFLES=0
|
||||
AMPLIFY_FACTOR=0
|
||||
LEARNING_RATE=1e-5
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than use pattern, we use the real thing and property
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. Make all uppercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# add shuffled strings
|
||||
processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
for desc in processed_descs:
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add corrupted strings
|
||||
desc = corrupt_string(parent_desc, corruption_probability=0.01)
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add example with stripped non-alphanumerics
|
||||
desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# short sequence amplifier
|
||||
# short sequences are rare, and we must compensate by including more examples
|
||||
# also, short sequence don't usually get affected by shuffle
|
||||
words = parent_desc.split()
|
||||
word_count = len(words)
|
||||
if word_count <= 2:
|
||||
for _ in range(AMPLIFY_FACTOR):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# entity_id always matches entity_name, so first value would work
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
batch_number = 0
|
||||
|
||||
augmented_df = augment_data(df)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = tokenizer(x, padding=True, return_tensors='pt')
|
||||
inputs.to(DEVICE)
|
||||
outputs = model(**inputs)
|
||||
cls = outputs.last_hidden_state[:,0,:]
|
||||
# for training less than half the time, train on easy
|
||||
y = torch.tensor(y).to(DEVICE)
|
||||
if epoch < epochs / 2:
|
||||
loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
# for training after half the time, train on hard
|
||||
else:
|
||||
loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
del x, y, outputs, cls, loss
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
print(f'epoch loss: {total_loss/batch_number}')
|
||||
# print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
|
||||
if epoch % 5 == 0:
|
||||
torch.save(model.state_dict(), './checkpoint/siamese_simple.pt')
|
||||
|
||||
|
||||
torch.save(model.state_dict(), './checkpoint/siamese_simple.pt')
|
||||
# %%
|
|
@ -0,0 +1,305 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# %%
|
||||
SHUFFLES=0
|
||||
AMPLIFY_FACTOR=2
|
||||
LEARNING_RATE=1e-5
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than use pattern, we use the real thing and property
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. Make all uppercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# add shuffled strings
|
||||
processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
for desc in processed_descs:
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add corrupted strings
|
||||
desc = corrupt_string(parent_desc, corruption_probability=0.01)
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add example with stripped non-alphanumerics
|
||||
desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# short sequence amplifier
|
||||
# short sequences are rare, and we must compensate by including more examples
|
||||
# also, short sequence don't usually get affected by shuffle
|
||||
words = parent_desc.split()
|
||||
word_count = len(words)
|
||||
if word_count <= 2:
|
||||
for _ in range(AMPLIFY_FACTOR):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# entity_id always matches entity_name, so first value would work
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
bert_model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
class BertForClassificationAndTriplet(nn.Module):
|
||||
def __init__(self, bert_model, num_classes):
|
||||
super().__init__()
|
||||
self.bert = bert_model
|
||||
self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None):
|
||||
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
||||
cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token
|
||||
logits = self.classifier(cls_embeddings)
|
||||
return cls_embeddings, logits
|
||||
|
||||
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
batch_number = 0
|
||||
|
||||
augmented_df = augment_data(df)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = tokenizer(x, padding=True, return_tensors='pt')
|
||||
inputs.to(DEVICE)
|
||||
cls, logits = model(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask']
|
||||
)
|
||||
# for training less than half the time, train on easy
|
||||
labels = y
|
||||
labels = [label2id[element] for element in labels]
|
||||
labels = torch.tensor(labels).to(DEVICE)
|
||||
y = torch.tensor(y).to(DEVICE)
|
||||
|
||||
|
||||
class_loss = F.cross_entropy(logits, labels)
|
||||
|
||||
if epoch < epochs / 2:
|
||||
triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
# for training after half the time, train on hard
|
||||
else:
|
||||
triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
|
||||
loss = class_loss + triplet_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
del x, y, cls, logits, loss
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
print(f'epoch loss: {total_loss/batch_number}')
|
||||
# print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
|
||||
if epoch % 5 == 0:
|
||||
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
torch.save(model.state_dict(), './checkpoint/classification.pt')
|
||||
|
||||
|
||||
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
torch.save(model.state_dict(), './checkpoint/classification.pt')
|
||||
# %%
|
|
@ -0,0 +1,186 @@
|
|||
# stardard functionalities for computing triplet loss, borrow code from
|
||||
# https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
def _pairwise_distances(embeddings, squared=False):
|
||||
"""Compute the 2D matrix of distances between all the embeddings.
|
||||
Args:
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
pairwise_distances: tensor of shape (batch_size, batch_size)
|
||||
"""
|
||||
dot_product = torch.matmul(embeddings, embeddings.t())
|
||||
|
||||
# Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
|
||||
# This also provides more numerical stability (the diagonal of the result will be exactly 0).
|
||||
# shape (batch_size,)
|
||||
square_norm = torch.diag(dot_product)
|
||||
|
||||
# Compute the pairwise distance matrix as we have:
|
||||
# ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
|
||||
# shape (batch_size, batch_size)
|
||||
distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)
|
||||
|
||||
# Because of computation errors, some distances might be negative so we put everything >= 0.0
|
||||
distances[distances < 0] = 0
|
||||
|
||||
if not squared:
|
||||
# Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
|
||||
# we need to add a small epsilon where distances == 0.0
|
||||
mask = distances.eq(0).float()
|
||||
distances = distances + mask * 1e-16
|
||||
|
||||
distances = (1.0 -mask) * torch.sqrt(distances)
|
||||
|
||||
return distances
|
||||
|
||||
def _get_triplet_mask(labels):
|
||||
"""Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
|
||||
A triplet (i, j, k) is valid if:
|
||||
- i, j, k are distinct
|
||||
- labels[i] == labels[j] and labels[i] != labels[k]
|
||||
Args:
|
||||
labels: tf.int32 `Tensor` with shape [batch_size]
|
||||
"""
|
||||
# Check that i, j and k are distinct
|
||||
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
|
||||
indices_not_equal = ~indices_equal
|
||||
i_not_equal_j = indices_not_equal.unsqueeze(2)
|
||||
i_not_equal_k = indices_not_equal.unsqueeze(1)
|
||||
j_not_equal_k = indices_not_equal.unsqueeze(0)
|
||||
|
||||
distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k
|
||||
|
||||
|
||||
label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
|
||||
i_equal_j = label_equal.unsqueeze(2)
|
||||
i_equal_k = label_equal.unsqueeze(1)
|
||||
|
||||
valid_labels = ~i_equal_k & i_equal_j
|
||||
|
||||
return valid_labels & distinct_indices
|
||||
|
||||
|
||||
def _get_anchor_positive_triplet_mask(labels):
|
||||
"""Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label.
|
||||
Args:
|
||||
labels: tf.int32 `Tensor` with shape [batch_size]
|
||||
Returns:
|
||||
mask: tf.bool `Tensor` with shape [batch_size, batch_size]
|
||||
"""
|
||||
# Check that i and j are distinct
|
||||
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
|
||||
indices_not_equal = ~indices_equal
|
||||
|
||||
# Check if labels[i] == labels[j]
|
||||
# Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
|
||||
labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
|
||||
|
||||
return labels_equal & indices_not_equal
|
||||
|
||||
|
||||
def _get_anchor_negative_triplet_mask(labels):
|
||||
"""Return a 2D mask where mask[a, n] is True iff a and n have distinct labels.
|
||||
Args:
|
||||
labels: tf.int32 `Tensor` with shape [batch_size]
|
||||
Returns:
|
||||
mask: tf.bool `Tensor` with shape [batch_size, batch_size]
|
||||
"""
|
||||
# Check if labels[i] != labels[k]
|
||||
# Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
|
||||
|
||||
return ~(labels.unsqueeze(0) == labels.unsqueeze(1))
|
||||
|
||||
|
||||
# Cell
|
||||
def batch_hard_triplet_loss(labels, embeddings, margin, squared=False):
|
||||
"""Build the triplet loss over a batch of embeddings.
|
||||
For each anchor, we get the hardest positive and hardest negative to form a triplet.
|
||||
Args:
|
||||
labels: labels of the batch, of size (batch_size,)
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
margin: margin for triplet loss
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
triplet_loss: scalar tensor containing the triplet loss
|
||||
"""
|
||||
# Get the pairwise distance matrix
|
||||
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
|
||||
|
||||
# For each anchor, get the hardest positive
|
||||
# First, we need to get a mask for every valid positive (they should have same label)
|
||||
mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()
|
||||
|
||||
# We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
|
||||
anchor_positive_dist = mask_anchor_positive * pairwise_dist
|
||||
|
||||
# shape (batch_size, 1)
|
||||
hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)
|
||||
|
||||
# For each anchor, get the hardest negative
|
||||
# First, we need to get a mask for every valid negative (they should have different labels)
|
||||
mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()
|
||||
|
||||
# We add the maximum value in each row to the invalid negatives (label(a) == label(n))
|
||||
max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
|
||||
anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
|
||||
|
||||
# shape (batch_size,)
|
||||
hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)
|
||||
|
||||
# Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
|
||||
tl = hardest_positive_dist - hardest_negative_dist + margin
|
||||
tl = F.relu(tl)
|
||||
triplet_loss = tl.mean()
|
||||
|
||||
return triplet_loss
|
||||
|
||||
# Cell
|
||||
def batch_all_triplet_loss(labels, embeddings, margin, squared=False):
|
||||
"""Build the triplet loss over a batch of embeddings.
|
||||
We generate all the valid triplets and average the loss over the positive ones.
|
||||
Args:
|
||||
labels: labels of the batch, of size (batch_size,)
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
margin: margin for triplet loss
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
triplet_loss: scalar tensor containing the triplet loss
|
||||
"""
|
||||
# Get the pairwise distance matrix
|
||||
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
|
||||
|
||||
anchor_positive_dist = pairwise_dist.unsqueeze(2)
|
||||
anchor_negative_dist = pairwise_dist.unsqueeze(1)
|
||||
|
||||
# Compute a 3D tensor of size (batch_size, batch_size, batch_size)
|
||||
# triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k
|
||||
# Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
|
||||
# and the 2nd (batch_size, 1, batch_size)
|
||||
triplet_loss = anchor_positive_dist - anchor_negative_dist + margin
|
||||
|
||||
|
||||
|
||||
# Put to zero the invalid triplets
|
||||
# (where label(a) != label(p) or label(n) == label(a) or a == p)
|
||||
mask = _get_triplet_mask(labels)
|
||||
triplet_loss = mask.float() * triplet_loss
|
||||
|
||||
# Remove negative losses (i.e. the easy triplets)
|
||||
triplet_loss = F.relu(triplet_loss)
|
||||
|
||||
# Count number of positive triplets (where triplet_loss > 0)
|
||||
valid_triplets = triplet_loss[triplet_loss > 1e-16]
|
||||
num_positive_triplets = valid_triplets.size(0)
|
||||
num_valid_triplets = mask.sum()
|
||||
|
||||
fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)
|
||||
|
||||
# Get final mean triplet loss over the positive valid triplets
|
||||
triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
|
||||
|
||||
return triplet_loss, fraction_positive_triplets
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,5 @@
|
|||
Top-1 accuracy: 0.8257482574825749
|
||||
Top-3 accuracy: 0.9106191061910619
|
||||
Top-5 accuracy: 0.9261992619926199
|
||||
Top-10 accuracy: 0.942189421894219
|
||||
0.0 0.82121104
|
|
@ -0,0 +1,227 @@
|
|||
# this code performs dataloading with text augmentation
|
||||
|
||||
# %%
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import pandas as pd
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
)
|
||||
from functools import partial
|
||||
import re
|
||||
import random
|
||||
|
||||
# %%
|
||||
# PARAMETERS
|
||||
SAMPLES=5
|
||||
|
||||
# %%
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than use pattern, we use the real thing and property
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. Make all uppercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def shuffle_text(text, prob=0.2):
|
||||
if random.random() < prob:
|
||||
words = text.split() # Split the input into words
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_text = " ".join(shuffled) # Join the words back into a string
|
||||
else:
|
||||
shuffled_text = text
|
||||
|
||||
return shuffled_text
|
||||
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_text(sentence, prob=0.05):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < prob else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
def strip_nonalphanumerics(desc, prob=0.5):
|
||||
desc = re.sub(r'[^\w\s]', ' ', desc) # Retains only alphanumeric and spaces
|
||||
return desc
|
||||
|
||||
|
||||
# %%
|
||||
def augment(row):
|
||||
"""
|
||||
function to augment "mention" string input
|
||||
returns the string input with slight variation
|
||||
"""
|
||||
desc = row['mention']
|
||||
# we always apply preprocess
|
||||
desc = preprocess_text(desc)
|
||||
|
||||
desc = shuffle_text(desc, prob=0.5)
|
||||
desc = corrupt_text(desc, prob=0.5)
|
||||
desc = strip_nonalphanumerics(desc, prob=0.5)
|
||||
|
||||
return desc
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
# custom dataset
|
||||
# we want to sample n samples from each class
|
||||
# sample_size refers to the number of samples per class
|
||||
def sample_from_df(df, sample_size_per_class=5):
|
||||
sampled_df = (df.groupby( "training_id")[['training_id', 'mention']] # explicit give column names
|
||||
.apply(lambda x: x.sample(n=min(sample_size_per_class, len(x))))
|
||||
.reset_index(drop=True))
|
||||
|
||||
return sampled_df
|
||||
|
||||
|
||||
# %%
|
||||
class DynamicDataset(Dataset):
|
||||
def __init__(self, df, sample_size_per_class):
|
||||
"""
|
||||
Args:
|
||||
df (pd.DataFrame): Original DataFrame with class (id) and data columns.
|
||||
sample_size_per_class (int): Number of samples to draw per class for each epoch.
|
||||
"""
|
||||
self.df = df
|
||||
self.sample_size_per_class = sample_size_per_class
|
||||
self.current_data = None
|
||||
self.regenerate_data() # Generate the initial dataset
|
||||
|
||||
def regenerate_data(self):
|
||||
"""
|
||||
Generate a new sampled dataset for the current epoch.
|
||||
|
||||
dynamic callback function to regenerate data each time we call this
|
||||
method, it updates the current_data we can:
|
||||
|
||||
- re-sample the dataframe for a new set of n_samples
|
||||
- generate fresh augmentations this effectively
|
||||
|
||||
This allows us to re-sample and re-augment at the start of each epoch
|
||||
"""
|
||||
# Sample `sample_size_per_class` rows per class
|
||||
sampled_df = sample_from_df(self.df, self.sample_size_per_class)
|
||||
|
||||
# Store the tokenized data with labels
|
||||
self.current_data = sampled_df
|
||||
|
||||
def __len__(self):
|
||||
return len(self.current_data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# do the transform here
|
||||
row = self.current_data.iloc[idx].to_dict()
|
||||
|
||||
# perform text augmentation here
|
||||
# independent function calls might introduce changes
|
||||
mention_0 = augment(row)
|
||||
mention_1 = augment(row)
|
||||
return {
|
||||
'training_id': row['training_id'],
|
||||
'mention_0': mention_0,
|
||||
'mention_1': mention_1,
|
||||
}
|
||||
|
||||
|
||||
# %%
|
||||
dataset = DynamicDataset(df, sample_size_per_class=SAMPLES)
|
||||
dataset[0]
|
||||
|
||||
|
||||
# %%
|
||||
def custom_collate_fn(batch, tokenizer):
|
||||
# batch is just a list of dictionaries
|
||||
label_list = [item['training_id'] for item in batch]
|
||||
mention_0_list = [item['mention_0'] for item in batch]
|
||||
mention_1_list = [item['mention_1'] for item in batch]
|
||||
|
||||
# we can do the tokenization here
|
||||
tokenized_batch_0 = tokenizer(
|
||||
mention_0_list,
|
||||
truncation=True,
|
||||
padding=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
|
||||
tokenized_batch_1 = tokenizer(
|
||||
mention_1_list,
|
||||
truncation=True,
|
||||
padding=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
|
||||
|
||||
label_list = torch.tensor(label_list)
|
||||
|
||||
return {
|
||||
'input_ids_0': tokenized_batch_0['input_ids'],
|
||||
'attention_mask_0': tokenized_batch_0['attention_mask'],
|
||||
'input_ids_1': tokenized_batch_1['input_ids'],
|
||||
'attention_mask_1': tokenized_batch_1['attention_mask'],
|
||||
'labels': label_list,
|
||||
}
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", clean_up_tokenization_spaces=False)
|
||||
custom_collate_fn_with_tokenizer = partial(custom_collate_fn, tokenizer=tokenizer)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=8,
|
||||
collate_fn=custom_collate_fn_with_tokenizer
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
next(iter(dataloader))
|
||||
# %%
|
|
@ -0,0 +1,12 @@
|
|||
Top-1 accuracy: 0.7107843137254902
|
||||
Top-3 accuracy: 0.7990196078431373
|
||||
Top-5 accuracy: 0.8137254901960784
|
||||
Top-10 accuracy: 0.8529411764705882
|
||||
Top-1 accuracy: 0.6862745098039216
|
||||
Top-3 accuracy: 0.7696078431372549
|
||||
Top-5 accuracy: 0.7990196078431373
|
||||
Top-10 accuracy: 0.8480392156862745
|
||||
Top-1 accuracy: 0.696078431372549
|
||||
Top-3 accuracy: 0.7892156862745098
|
||||
Top-5 accuracy: 0.8088235294117647
|
||||
Top-10 accuracy: 0.8382352941176471
|
|
@ -14,11 +14,10 @@ from transformers import AutoTokenizer, AutoModel
|
|||
|
||||
from data import generate_train_entity_sets
|
||||
|
||||
from tqdm import tqdm ### need to use ipywidgets==7.7.1 the newest version doesn't work
|
||||
from tqdm import tqdm
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
|
||||
def setup(rank, world_size):
|
||||
|
@ -73,13 +72,10 @@ def train(rank, epoch, epochs, train_dataloader, model, optimizer, tokenizer, ma
|
|||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# logging.info(f'{epoch} {len(x)} {loss.item()}')
|
||||
epoch_loss.append(loss.item())
|
||||
epoch_len.append(len(x))
|
||||
# del inputs, cls, loss
|
||||
# torch.cuda.empty_cache()
|
||||
logging.info(f'{DEVICE}{epoch_len}')
|
||||
logging.info(f'{DEVICE}{epoch_loss}')
|
||||
|
||||
def check_label(predicted_cui: str, golden_cui: str) -> int:
|
||||
"""
|
||||
|
@ -127,8 +123,7 @@ def eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, m
|
|||
# print(np.min(distances), np.max(distances))
|
||||
|
||||
def save_checkpoint(model, res, epoch, dataName):
|
||||
logging.info(f'Saving model {epoch} {res} ')
|
||||
torch.save(model.state_dict(), './checkpoints/'+dataName+'.pt')
|
||||
torch.save(model.state_dict(), './checkpoint/' + dataName + '.pt')
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self,MODEL_NAME):
|
||||
|
@ -146,12 +141,12 @@ def main(rank, world_size, config):
|
|||
setup(rank, world_size)
|
||||
|
||||
dataName = config['DEFAULT']['dataName']
|
||||
logging.basicConfig(format='%(asctime)s %(message)s', filename=config['train']['ckt_path']+dataName+'.log', filemode='a', level=logging.INFO)
|
||||
|
||||
vocab = defaultdict(set)
|
||||
with open('./data/biomedical/'+dataName+'/'+config['train']['dictionary']) as f:
|
||||
with open('../biomedical/' + dataName + '/' + config['train']['dictionary']) as f:
|
||||
for line in f:
|
||||
vocab[line.strip().split('||')[0]].add(line.strip().split('||')[1].lower())
|
||||
line_list = line.strip().split('||')
|
||||
vocab[line_list[0]].add(line_list[1].lower())
|
||||
|
||||
cui_to_id, id_to_cui = {}, {}
|
||||
vocab_entity_id_mentions = {}
|
||||
|
@ -167,7 +162,7 @@ def main(rank, world_size, config):
|
|||
vocab_ids.extend([id]*len(mentions))
|
||||
|
||||
test_mentions, test_cuis = [], []
|
||||
with open('./data/biomedical/'+dataName+'/'+config['train']['test_set']+'/0.concept') as f:
|
||||
with open('../biomedical/'+dataName+'/'+config['train']['test_set']+'/0.concept') as f:
|
||||
for line in f:
|
||||
test_cuis.append(line.strip().split('||')[-1])
|
||||
test_mentions.append(line.strip().split('||')[-2].lower())
|
||||
|
@ -188,8 +183,7 @@ def main(rank, world_size, config):
|
|||
ddp_model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)
|
||||
|
||||
best = 0
|
||||
if rank == 0:
|
||||
logging.info(f'epochs:{epochs} group_size:{num_sample_per_class} batch_size:{batch_size} %num:1 device:{torch.cuda.get_device_name()} count:{torch.cuda.device_count()} base:{MODEL_NAME}' )
|
||||
best_res = []
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
trainDataLoader.sampler.set_epoch(epoch)
|
||||
|
@ -197,11 +191,18 @@ def main(rank, world_size, config):
|
|||
# if rank == 0 and epoch % 2 == 0:
|
||||
if rank == 0:
|
||||
res = eval(rank, vocab_mentions, vocab_ids, test_mentions, test_cuis, id_to_cui, ddp_model.module, tokenizer)
|
||||
logging.info(f'{epoch} {res}')
|
||||
if res[0] > best:
|
||||
best = res[0]
|
||||
best_res = res
|
||||
save_checkpoint(ddp_model.module, res, epoch, dataName)
|
||||
with open("biomedical_results/output.txt", "a") as f:
|
||||
print('new best ----', file=f)
|
||||
for idx,n in enumerate([1,3,5,10]):
|
||||
print(f'Top-{n:<3} accuracy: {best_res[idx]}', file=f)
|
||||
|
||||
dist.barrier()
|
||||
|
||||
|
||||
cleanup()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
[DEFAULT]
|
||||
dataName = ncbi
|
||||
# dataName = bc5cdr-disease
|
||||
# dataName = bc5cdr-chemical
|
||||
# dataName = bc2gm
|
||||
|
||||
[train]
|
||||
epochs = 2
|
||||
batch_size = 40
|
||||
lr = 1e-5
|
||||
|
||||
dictionary = test_dictionary.txt
|
||||
# or processed_dev_oneFile Note: bc2gm has no dev split
|
||||
test_set = processed_test_refined
|
||||
ckt_path = ./checkpoint/
|
||||
|
||||
|
||||
[model]
|
||||
margin = 2
|
||||
model_name = dmis-lab/biobert-v1.1
|
||||
|
||||
[eval]
|
||||
batch_size = 200
|
||||
|
||||
[data]
|
||||
group_size = 4
|
||||
shuffle = True
|
|
@ -0,0 +1,84 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
entity_sets = []
|
||||
# anchor means to use the entity_name as an anchor
|
||||
if anchor:
|
||||
# entity_id_mentions is a list of dicts, each dict is an id and a list of mentions
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
# shuffle the mentions
|
||||
random.shuffle(mentions)
|
||||
# make batches of at most group size - 10
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
# the first element is always the entity_name
|
||||
# this is why we use (group_size - 1)
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
# in this case, there is no entity_name in each group
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
# the batch generator selects batch_size entries from the "data"
|
||||
# but actually the "data" is a list of 10 items or less
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0]) # this is the list of mentions
|
||||
y.extend([t[1]]*len(t[0])) # this multiplies a single label by the number of mentions
|
||||
yield x, y
|
||||
|
||||
# %%
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
|
||||
# %%
|
||||
random.shuffle(data)
|
||||
|
|
@ -0,0 +1,227 @@
|
|||
# this code performs dataloading with text augmentation
|
||||
|
||||
# %%
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import pandas as pd
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
)
|
||||
from functools import partial
|
||||
import re
|
||||
import random
|
||||
|
||||
# %%
|
||||
# PARAMETERS
|
||||
SAMPLES=5
|
||||
|
||||
# %%
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than use pattern, we use the real thing and property
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. Make all uppercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def shuffle_text(text, prob=0.2):
|
||||
if random.random() < prob:
|
||||
words = text.split() # Split the input into words
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_text = " ".join(shuffled) # Join the words back into a string
|
||||
else:
|
||||
shuffled_text = text
|
||||
|
||||
return shuffled_text
|
||||
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_text(sentence, prob=0.05):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < prob else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
def strip_nonalphanumerics(desc, prob=0.5):
|
||||
desc = re.sub(r'[^\w\s]', ' ', desc) # Retains only alphanumeric and spaces
|
||||
return desc
|
||||
|
||||
|
||||
# %%
|
||||
def augment(row):
|
||||
"""
|
||||
function to augment "mention" string input
|
||||
returns the string input with slight variation
|
||||
"""
|
||||
desc = row['mention']
|
||||
# we always apply preprocess
|
||||
desc = preprocess_text(desc)
|
||||
|
||||
desc = shuffle_text(desc, prob=1.0)
|
||||
desc = corrupt_text(desc, prob=1.0)
|
||||
desc = strip_nonalphanumerics(desc, prob=0.5)
|
||||
|
||||
return desc
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
# custom dataset
|
||||
# we want to sample n samples from each class
|
||||
# sample_size refers to the number of samples per class
|
||||
def sample_from_df(df, sample_size_per_class=5):
|
||||
sampled_df = (df.groupby( "training_id")[['training_id', 'mention']] # explicit give column names
|
||||
.apply(lambda x: x.sample(n=min(sample_size_per_class, len(x))))
|
||||
.reset_index(drop=True))
|
||||
|
||||
return sampled_df
|
||||
|
||||
|
||||
# %%
|
||||
class DynamicDataset(Dataset):
|
||||
def __init__(self, df, sample_size_per_class):
|
||||
"""
|
||||
Args:
|
||||
df (pd.DataFrame): Original DataFrame with class (id) and data columns.
|
||||
sample_size_per_class (int): Number of samples to draw per class for each epoch.
|
||||
"""
|
||||
self.df = df
|
||||
self.sample_size_per_class = sample_size_per_class
|
||||
self.current_data = None
|
||||
self.regenerate_data() # Generate the initial dataset
|
||||
|
||||
def regenerate_data(self):
|
||||
"""
|
||||
Generate a new sampled dataset for the current epoch.
|
||||
|
||||
dynamic callback function to regenerate data each time we call this
|
||||
method, it updates the current_data we can:
|
||||
|
||||
- re-sample the dataframe for a new set of n_samples
|
||||
- generate fresh augmentations this effectively
|
||||
|
||||
This allows us to re-sample and re-augment at the start of each epoch
|
||||
"""
|
||||
# Sample `sample_size_per_class` rows per class
|
||||
sampled_df = sample_from_df(self.df, self.sample_size_per_class)
|
||||
|
||||
# Store the tokenized data with labels
|
||||
self.current_data = sampled_df
|
||||
|
||||
def __len__(self):
|
||||
return len(self.current_data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# do the transform here
|
||||
row = self.current_data.iloc[idx].to_dict()
|
||||
|
||||
# perform text augmentation here
|
||||
# independent function calls might introduce changes
|
||||
mention_0 = augment(row)
|
||||
mention_1 = augment(row)
|
||||
return {
|
||||
'training_id': row['training_id'],
|
||||
'mention_0': mention_0,
|
||||
'mention_1': mention_1,
|
||||
}
|
||||
|
||||
|
||||
# %%
|
||||
dataset = DynamicDataset(df, sample_size_per_class=SAMPLES)
|
||||
dataset[0]
|
||||
|
||||
|
||||
# %%
|
||||
def custom_collate_fn(batch, tokenizer):
|
||||
# batch is just a list of dictionaries
|
||||
label_list = [item['training_id'] for item in batch]
|
||||
mention_0_list = [item['mention_0'] for item in batch]
|
||||
mention_1_list = [item['mention_1'] for item in batch]
|
||||
|
||||
# we can do the tokenization here
|
||||
tokenized_batch_0 = tokenizer(
|
||||
mention_0_list,
|
||||
truncation=True,
|
||||
padding=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
|
||||
tokenized_batch_1 = tokenizer(
|
||||
mention_1_list,
|
||||
truncation=True,
|
||||
padding=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
|
||||
|
||||
label_list = torch.tensor(label_list)
|
||||
|
||||
return {
|
||||
'input_ids_0': tokenized_batch_0['input_ids'],
|
||||
'attention_mask_0': tokenized_batch_0['attention_mask'],
|
||||
'input_ids_1': tokenized_batch_1['input_ids'],
|
||||
'attention_mask_1': tokenized_batch_1['attention_mask'],
|
||||
'labels': label_list,
|
||||
}
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", clean_up_tokenization_spaces=False)
|
||||
custom_collate_fn_with_tokenizer = partial(custom_collate_fn, tokenizer=tokenizer)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=8,
|
||||
collate_fn=custom_collate_fn_with_tokenizer
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
next(iter(dataloader))
|
||||
# %%
|
|
@ -0,0 +1,5 @@
|
|||
Top-1 accuracy: 0.01845018450184502
|
||||
Top-3 accuracy: 0.02870028700287003
|
||||
Top-5 accuracy: 0.03936039360393604
|
||||
Top-10 accuracy: 0.08200082000820008
|
||||
0.0 0.7957323
|
|
@ -0,0 +1,387 @@
|
|||
# %%
|
||||
# %%
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
# from datasets import load_from_disk
|
||||
import os
|
||||
import json
|
||||
|
||||
os.environ['NCCL_P2P_DISABLE'] = '1'
|
||||
os.environ['NCCL_IB_DISABLE'] = '1'
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
|
||||
|
||||
from dataclasses import dataclass
|
||||
import re
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModel,
|
||||
DataCollatorWithPadding,
|
||||
Trainer,
|
||||
EarlyStoppingCallback,
|
||||
TrainingArguments,
|
||||
TrainerCallback
|
||||
)
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from functools import partial
|
||||
import warnings
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from dataload import DynamicDataset, custom_collate_fn
|
||||
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
def set_seed(seed):
|
||||
"""
|
||||
Set the random seed for reproducibility.
|
||||
"""
|
||||
random.seed(seed) # Python random module
|
||||
np.random.seed(seed) # NumPy random
|
||||
torch.manual_seed(seed) # PyTorch CPU
|
||||
torch.cuda.manual_seed(seed) # PyTorch GPU
|
||||
torch.cuda.manual_seed_all(seed) # If using multiple GPUs
|
||||
torch.backends.cudnn.deterministic = True # Ensure deterministic behavior
|
||||
torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility
|
||||
|
||||
set_seed(42)
|
||||
|
||||
SAMPLES=40
|
||||
BATCH_SIZE=256
|
||||
|
||||
# %%
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than use pattern, we use the real thing and property
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
|
||||
# %%
|
||||
# make our dataset and dataloader
|
||||
# MODEL_NAME = "distilbert/distilbert-base-uncased"
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, clean_up_tokenization_spaces=False)
|
||||
dataset = DynamicDataset(df, sample_size_per_class=SAMPLES)
|
||||
custom_collate_fn_with_tokenizer = partial(custom_collate_fn, tokenizer=tokenizer)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=True,
|
||||
collate_fn=custom_collate_fn_with_tokenizer
|
||||
)
|
||||
|
||||
# %%
|
||||
# enable BERT with projection layer
|
||||
|
||||
class VICRegProjection(nn.Module):
|
||||
def __init__(self, input_dim, hidden_dim, output_dim):
|
||||
super(VICRegProjection, self).__init__()
|
||||
self.projection = nn.Sequential(
|
||||
nn.Linear(input_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, output_dim)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.projection(x)
|
||||
|
||||
|
||||
class BertWithVICReg(nn.Module):
|
||||
def __init__(self, bert_model, projection_dim=256):
|
||||
super(BertWithVICReg, self).__init__()
|
||||
self.bert = bert_model
|
||||
hidden_size = bert_model.config.hidden_size
|
||||
self.projection = VICRegProjection(input_dim=hidden_size, hidden_dim=hidden_size, output_dim=projection_dim)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None):
|
||||
outputs = self.bert(input_ids, attention_mask=attention_mask)
|
||||
pooled_output = outputs.last_hidden_state[:,0,:]
|
||||
projected_embeddings = self.projection(pooled_output)
|
||||
return projected_embeddings
|
||||
|
||||
#######################################################
|
||||
# %%
|
||||
@dataclass
|
||||
class Hyperparameters:
|
||||
loss_constant_factor: float = 1
|
||||
invariance_loss_weight: float = 25.0
|
||||
variance_loss_weight: float = 25.0
|
||||
covariance_loss_weight: float = 1.0
|
||||
variance_loss_epsilon: float = 1e-5
|
||||
|
||||
|
||||
# compute vicreg loss
|
||||
def get_vicreg_loss(z_a, z_b, hparams):
|
||||
assert z_a.shape == z_b.shape and len(z_a.shape) == 2
|
||||
|
||||
# invariance loss
|
||||
loss_inv = F.mse_loss(z_a, z_b)
|
||||
|
||||
# variance loss
|
||||
std_z_a = torch.sqrt(z_a.var(dim=0) + hparams.variance_loss_epsilon)
|
||||
std_z_b = torch.sqrt(z_b.var(dim=0) + hparams.variance_loss_epsilon)
|
||||
loss_v_a = torch.mean(F.relu(1 - std_z_a)) # differentiable max
|
||||
loss_v_b = torch.mean(F.relu(1 - std_z_b))
|
||||
loss_var = loss_v_a + loss_v_b
|
||||
|
||||
# covariance loss
|
||||
N, D = z_a.shape
|
||||
z_a = z_a - z_a.mean(dim=0)
|
||||
z_b = z_b - z_b.mean(dim=0)
|
||||
cov_z_a = ((z_a.T @ z_a) / (N - 1)).square() # DxD
|
||||
cov_z_b = ((z_b.T @ z_b) / (N - 1)).square() # DxD
|
||||
loss_c_a = (cov_z_a.sum() - cov_z_a.diagonal().sum()) / D
|
||||
loss_c_b = (cov_z_b.sum() - cov_z_b.diagonal().sum()) / D
|
||||
loss_cov = loss_c_a + loss_c_b
|
||||
|
||||
weighted_inv = loss_inv * hparams.invariance_loss_weight
|
||||
weighted_var = loss_var * hparams.variance_loss_weight
|
||||
weighted_cov = loss_cov * hparams.covariance_loss_weight
|
||||
|
||||
loss = weighted_inv + weighted_var + weighted_cov
|
||||
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
#
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
bert_model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
# bert_hidden_size = bert_model.config.hidden_size
|
||||
# projection_model = VICRegProjection(
|
||||
# input_dim=bert_hidden_size,
|
||||
# hidden_dim=bert_hidden_size,
|
||||
# output_dim=256
|
||||
# )
|
||||
# need to allocate individual component of the model
|
||||
# bert_model.to(DEVICE)
|
||||
# projection_model.to(DEVICE)
|
||||
model = BertWithVICReg(bert_model, projection_dim=128)
|
||||
model.to(DEVICE)
|
||||
|
||||
# params = list(bert_model.parameters()) + list(projection_model.parameters())
|
||||
params = model.parameters()
|
||||
optimizer = torch.optim.AdamW(params, lr=5e-6)
|
||||
hparams = Hyperparameters()
|
||||
|
||||
losses = []
|
||||
|
||||
# # %%
|
||||
# batch = next(iter(dataloader))
|
||||
# input_ids_0 = batch['input_ids_0'].to(DEVICE)
|
||||
# attention_mask_0 = batch['attention_mask_0'].to(DEVICE)
|
||||
#
|
||||
# # %%
|
||||
# # outputs from reprojection layer
|
||||
# bert_output = model(
|
||||
# input_ids=input_ids_0,
|
||||
# attention_mask=attention_mask_0
|
||||
# )
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
# parameters
|
||||
epochs = 80
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
dataset.regenerate_data()
|
||||
for batch in dataloader:
|
||||
optimizer.zero_grad()
|
||||
|
||||
# compute cls 0
|
||||
input_ids_0 = batch['input_ids_0'].to(DEVICE)
|
||||
attention_mask_0 = batch['attention_mask_0'].to(DEVICE)
|
||||
# outputs from reprojection layer
|
||||
outputs_0 = model(
|
||||
input_ids=input_ids_0,
|
||||
attention_mask=attention_mask_0
|
||||
)
|
||||
|
||||
# compute cls 1
|
||||
input_ids_1 = batch['input_ids_1'].to(DEVICE)
|
||||
attention_mask_1 = batch['attention_mask_1'].to(DEVICE)
|
||||
# outputs from reprojection layer
|
||||
outputs_1 = model(
|
||||
input_ids=input_ids_1,
|
||||
attention_mask=attention_mask_1
|
||||
)
|
||||
|
||||
loss = get_vicreg_loss(outputs_0, outputs_1, hparams=hparams)
|
||||
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# print(epoch, loss)
|
||||
losses.append(loss)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
print(loss.detach().item())
|
||||
|
||||
|
||||
# %%
|
||||
torch.save(model.state_dict(), './checkpoint/simple.pt')
|
||||
|
||||
####################################################
|
||||
# %%
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
state_dict = torch.load('./checkpoint/simple.pt')
|
||||
model = BertWithVICReg(bert_model, projection_dim=256)
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# %%
|
||||
# Step 2: Load the state dictionary
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
# state_dict = torch.load('./checkpoint/siamese.pt')
|
||||
# model = model.bert
|
||||
# %%
|
||||
|
||||
# Step 3: Apply the state dictionary to the model
|
||||
model.to(DEVICE)
|
||||
model.eval()
|
||||
|
||||
# %%
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
|
||||
# %%
|
||||
def preprocess_text(text):
|
||||
# 1. Make all uppercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in test['data'].items()]
|
||||
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
|
||||
train_entities = [preprocess_text(element) for element in train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
# output = outputs.last_hidden_state[:,0,:]
|
||||
outputs = outputs.detach().cpu().numpy()
|
||||
embedding_list.append(outputs)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
# %%
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# %%
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
# output = outputs.last_hidden_state[:,0,:]
|
||||
outputs = outputs.detach().cpu().numpy()
|
||||
embedding_list.append(outputs)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
|
||||
# %%
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
|
||||
n_neighbors = [1, 3, 5, 10]
|
||||
|
||||
for n in n_neighbors:
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'Top-{n:<3} accuracy: {num / len(y_test)}')
|
||||
print(np.min(distances), np.max(distances))
|
||||
|
||||
|
||||
|
||||
# with open("results/output.txt", "w") as f:
|
||||
# for n in n_neighbors:
|
||||
# distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
|
||||
# num = 0
|
||||
# for a,b in zip(y_test, indices):
|
||||
# b = [labels[i] for i in b]
|
||||
# if a in b:
|
||||
# num += 1
|
||||
# print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
|
||||
# print(np.min(distances), np.max(distances), file=f)
|
||||
|
||||
# %%
|
||||
|
||||
from sklearn.manifold import TSNE
|
||||
import matplotlib.pyplot as plt
|
||||
# %%
|
||||
|
||||
# Reduce dimensions with t-SNE
|
||||
tsne = TSNE(n_components=2, random_state=42)
|
||||
embeddings= cls
|
||||
embeddings_reduced = tsne.fit_transform(embeddings)
|
||||
|
||||
|
||||
plt.figure(figsize=(10, 8))
|
||||
scatter = plt.scatter(embeddings_reduced[:, 0], embeddings_reduced[:, 1], c=labels, cmap='viridis', alpha=0.6)
|
||||
plt.colorbar(scatter)
|
||||
plt.xlabel('Component 1')
|
||||
plt.ylabel('Component 2')
|
||||
plt.title('Visualization of Embeddings')
|
||||
plt.show()
|
||||
|
||||
# %%
|
Loading…
Reference in New Issue