added experiments with triplet loss and augmentations
- includes experiments on character-level BERT
parent ac340f6fd2
commit 182760b7a2
@@ -0,0 +1,56 @@
# %%
import pandas as pd
import json

# %%
data_path = '../loss_comparisons_without_augmentation/results/predictions.txt'
df = pd.read_csv(data_path, header=None)
df = df.rename(columns={0: 'actual', 1: 'predicted'})

# %%
with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}

# %%
df['predicted_name'] = df['predicted'].map(all_entity_id_name)

# %%
# import test file
data_path = '../esAppMod_data_import/test.csv'
# data_path = '../esAppMod_data_import/parent_test.csv'
test_df = pd.read_csv(data_path)

# %%
df_out = pd.concat([test_df, df], axis=1)

# %%
mask1 = (df['predicted'] != df['actual'])

# %%
print(df_out[mask1].sort_values(by=['entity_id']).to_markdown())

# %%
data_path = '../loss_comparisons_with_augmentations/results/predictions.txt'
df2 = pd.read_csv(data_path, header=None)
df2 = df2.rename(columns={0: 'actual', 1: 'predicted'})
mask2 = df2['actual'] != df2['predicted']

# %%
# find entries that were:
# - predicted correctly in the run without augmentation (mask1 is False)
# - predicted wrongly in the run with augmentation (mask2 is True)
mask_left = ~mask1 & mask2

predicted_entity = df2['predicted'].map(all_entity_id_name)
df_out = pd.concat([test_df, df2, predicted_entity], axis=1)
print(df_out[mask_left].sort_values(by=['entity_id']).to_markdown())
# %%
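Side note on the mask logic above: `mask1`/`mask2` flag wrong predictions per run, so `~mask1 & mask2` selects rows that the no-augmentation run got right but the augmented run got wrong. A minimal toy sketch of the same pattern (the frames below are made up and only stand in for the two prediction files):

import pandas as pd

run1 = pd.DataFrame({'actual': [1, 2, 3], 'predicted': [1, 9, 3]})  # hypothetical run without augmentation
run2 = pd.DataFrame({'actual': [1, 2, 3], 'predicted': [8, 9, 3]})  # hypothetical run with augmentation
m1 = run1['predicted'] != run1['actual']   # wrong in run 1 -> [False, True, False]
m2 = run2['predicted'] != run2['actual']   # wrong in run 2 -> [True, True, False]
regressions = ~m1 & m2                     # correct before, wrong after -> [True, False, False]
print(run2[regressions])                   # only the first row survives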
@@ -0,0 +1,59 @@
# %%
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# %%
data_path = '../loss_comparisons_without_augmentation/top1_curves/baseline_output.txt'
df = pd.read_csv(data_path, header=None)
y = df[0]
plt.plot(y)

# Find the max value
max_y = np.max(y)     # Max value
max_x = np.argmax(y)  # x value corresponding to the max y
# Annotate the max value on the plot
# plt.annotate(f'Max: {max_y:.5f}',            # Text to display
#              xy=(max_x, max_y),              # Point to annotate
#              xytext=(max_x+0.7, max_y-0.3),  # Location of text
#              arrowprops=dict(facecolor='black', arrowstyle='->'),
#              bbox=dict(boxstyle="round,pad=0.3", edgecolor='black', facecolor='yellow'))


# data_path = '../experimental/top1_curves/character_output.txt'
# df = pd.read_csv(data_path, header=None)
# y = df[0]
# plt.plot(y)
# max_y = np.max(y)     # Max value
# max_x = np.argmax(y)  # x value corresponding to the max y
# # Annotate the max value on the plot
# plt.annotate(f'Max: {max_y:.5f}',            # Text to display
#              xy=(max_x, max_y),              # Point to annotate
#              xytext=(max_x+0.7, max_y-0.2),  # Location of text
#              arrowprops=dict(facecolor='black', arrowstyle='->'),
#              bbox=dict(boxstyle="round,pad=0.3", edgecolor='black', facecolor='yellow'))

data_path = '../experimental/top1_curves/character_knn.txt'
df = pd.read_csv(data_path, header=None)
y = df[0]
plt.plot(y)
max_y = np.max(y)     # Max value
max_x = np.argmax(y)  # x value corresponding to the max y
# Annotate the max value on the plot
plt.annotate(f'Max: {max_y:.5f}',            # Text to display
             xy=(max_x, max_y),              # Point to annotate
             xytext=(max_x+0.7, max_y-0.4),  # Location of text
             arrowprops=dict(facecolor='black', arrowstyle='->'),
             bbox=dict(boxstyle="round,pad=0.3", edgecolor='black', facecolor='yellow'))

plt.ylim(0.4, 1)

# data_path = '../loss_comparisons_with_augmentations/top1_curves/smooth_output.txt'
# df = pd.read_csv(data_path, header=None)
# plt.plot(df[0])

# %%
@@ -0,0 +1,125 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import re
import gc

# %%
# Step 2: Load the state dictionary
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

# state_dict = torch.load('./checkpoint/siamese.pt')
# state_dict = torch.load('./checkpoint/siamese_simple.pt')
# state_dict = torch.load('./checkpoint/classification.pt')
state_dict = torch.load('./checkpoint/baseline.pt')
# params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name}

# %%
# Step 3: Apply the state dictionary to the model
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()

# %%
def preprocess_text(text):
    # 1. Make all lowercase
    text = text.lower()

    # standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()

    return text


# %%
with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}

# %%
with open('../esAppMod/infer.json', 'r') as file:
    test = json.load(file)
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
y_test = [d['entity_id'] for _, d in test['data'].items()]
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
train_entities = [preprocess_text(element) for element in train_entities]

def batch_list(data, batch_size):
    """Yield successive batch_size-sized chunks from data."""
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]

batches = batch_list(train_entities, 64)

embedding_list = []
for batch in batches:
    inputs = tokenizer(batch, padding=True, return_tensors='pt')
    outputs = model(
        input_ids=inputs['input_ids'].to(DEVICE),
        attention_mask=inputs['attention_mask'].to(DEVICE)
    )
    output = outputs.last_hidden_state[:, 0, :]
    output = output.detach().cpu().numpy()
    embedding_list.append(output)

cls = np.concatenate(embedding_list)
# %%
gc.collect()
torch.cuda.empty_cache()

# %%
batches = batch_list(x_test, 64)

embedding_list = []
for batch in batches:
    inputs = tokenizer(batch, padding=True, return_tensors='pt')
    outputs = model(
        input_ids=inputs['input_ids'].to(DEVICE),
        attention_mask=inputs['attention_mask'].to(DEVICE)
    )
    output = outputs.last_hidden_state[:, 0, :]
    output = output.detach().cpu().numpy()
    embedding_list.append(output)

cls_test = np.concatenate(embedding_list)

# %%
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
n_neighbors = [1, 3, 5, 10]

with open("results/output.txt", "w") as f:
    for n in n_neighbors:
        distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
        num = 0
        for a, b in zip(y_test, indices):
            b = [labels[i] for i in b]
            if a in b:
                num += 1
        print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
        print(np.min(distances), np.max(distances), file=f)

# %%
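A note on the top-k evaluation above: `knn.kneighbors` returns row indices into the fitted embedding matrix, not entity ids, which is why each returned index is mapped back through `labels` before the membership check. A minimal sketch of the same pattern on synthetic vectors (all values below are made up):

import numpy as np
from sklearn.neighbors import KNeighborsClassifier

emb = np.array([[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]])   # toy "training" embeddings
labels = [501, 502, 501]                                # entity ids, parallel to emb rows
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(emb, labels)

query = np.array([[0.95, 0.05]])                        # toy "test" embedding
distances, indices = knn.kneighbors(query, n_neighbors=2)
top_k_ids = [labels[i] for i in indices[0]]             # map row indices back to entity ids
print(top_k_ids)                                        # e.g. [501, 501]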
@@ -0,0 +1,277 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import pandas as pd
import re
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

# %%
SHUFFLES=0
AMPLIFY_FACTOR=0
LEARNING_RATE=1e-5

# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    # split entity mentions into groups
    # anchor = False: don't add the entity name to each group, simply treat it as a normal mention
    entity_sets = []
    if anchor:
        for id, mentions in entity_id_mentions.items():
            random.shuffle(mentions)
            positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
            anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
            entity_sets.extend(anchor_positive)
    else:
        for id, mentions in entity_id_mentions.items():
            group = list(set([entity_id_name[id]] + mentions))
            random.shuffle(group)
            positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
            entity_sets.extend(positives)
    return entity_sets

def batchGenerator(data, batch_size):
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        x, y = [], []
        for t in batch:
            x.extend(t[0])
            y.extend([t[1]]*len(t[0]))
        yield x, y


with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}

# %%
###################################################
# alternate data import strategy
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use pattern, we use the real thing and property
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))

id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
    id2label[idx] = val
    label2id[val] = idx

df["training_id"] = df["entity_id"].map(label2id)

# %%
##############################################################
# augmentation code

# basic preprocessing
def preprocess_text(text):
    # 1. Make all lowercase
    text = text.lower()

    # standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()

    return text


def generate_random_shuffles(text, n):
    words = text.split()  # Split the input into words
    shuffled_variations = []

    for _ in range(n):
        shuffled = words[:]  # Copy the word list to avoid in-place modification
        random.shuffle(shuffled)  # Randomly shuffle the words
        shuffled_variations.append(" ".join(shuffled))  # Join the words back into a string

    return shuffled_variations


def shuffle_text(text, n_shuffles=SHUFFLES):
    all_processed = []
    # add the original text
    all_processed.append(text)

    # Generate random shuffles
    shuffled_variations = generate_random_shuffles(text, n_shuffles)
    all_processed.extend(shuffled_variations)

    return all_processed

def corrupt_word(word):
    """Corrupt a single word using random corruption techniques."""
    if len(word) <= 1:  # Skip corruption for single-character words
        return word

    corruption_type = random.choice(["delete", "swap"])

    if corruption_type == "delete":
        # Randomly delete a character
        idx = random.randint(0, len(word) - 1)
        word = word[:idx] + word[idx + 1:]

    elif corruption_type == "swap":
        # Swap two adjacent characters
        if len(word) > 1:
            idx = random.randint(0, len(word) - 2)
            word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])

    return word

def corrupt_string(sentence, corruption_probability=0.01):
    """Corrupt each word in the string with a given probability."""
    words = sentence.split()
    corrupted_words = [
        corrupt_word(word) if random.random() < corruption_probability else word
        for word in words
    ]
    return " ".join(corrupted_words)


def create_example(index, mention, entity_name):
    return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}

# augment whole dataset
def augment_data(df):
    output_list = []

    for idx, row in df.iterrows():
        index = row['entity_id']
        entity_name = row['entity_name']
        parent_desc = row['mention']
        parent_desc = preprocess_text(parent_desc)

        # add basic example
        output_list.append(create_example(index, parent_desc, entity_name))

        # all augmentations disabled
        # # add shuffled strings
        # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
        # for desc in processed_descs:
        #     if (desc != parent_desc):
        #         output_list.append(create_example(index, desc, entity_name))

        # # add corrupted strings
        # desc = corrupt_string(parent_desc, corruption_probability=0.01)
        # if (desc != parent_desc):
        #     output_list.append(create_example(index, desc, entity_name))

        # # add example with stripped non-alphanumerics
        # desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
        # if (desc != parent_desc):
        #     output_list.append(create_example(index, desc, entity_name))

        # # short sequence amplifier
        # # short sequences are rare, and we must compensate by including more examples
        # # also, short sequences don't usually get affected by shuffle
        # words = parent_desc.split()
        # word_count = len(words)
        # if word_count <= 2:
        #     for _ in range(AMPLIFY_FACTOR):
        #         output_list.append(create_example(index, desc, entity_name))

    new_df = pd.DataFrame(output_list)
    return new_df


# %%
def make_entity_id_mentions(df):
    entity_id_mentions = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
    return entity_id_mentions

def make_entity_id_name(df):
    entity_id_name = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        # entity_id always matches entity_name, so the first value works
        entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
    return entity_id_name


# %%
num_sample_per_class = 10 # samples in each group
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

model.to(DEVICE)
model.train()

losses = []

for epoch in tqdm(range(epochs)):
    total_loss = 0.0
    batch_number = 0

    augmented_df = augment_data(df)
    train_entity_id_mentions = make_entity_id_mentions(augmented_df)
    train_entity_id_name = make_entity_id_name(augmented_df)

    data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
    random.shuffle(data)
    for x, y in batchGenerator(data, batch_size):
        # print(len(x), len(y), end='-->')
        optimizer.zero_grad()

        inputs = tokenizer(x, padding=True, return_tensors='pt')
        inputs.to(DEVICE)
        outputs = model(**inputs)
        cls = outputs.last_hidden_state[:,0,:]
        y = torch.tensor(y).to(DEVICE)
        # for the first half of training, train on easy (batch-all) triplets
        if epoch < epochs / 2:
            loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
        # for the second half of training, train on hard triplets
        else:
            loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().item()
        batch_number += 1

        del x, y, outputs, cls, loss
        torch.cuda.empty_cache()

    # scheduler.step() # Update the learning rate
    print(f'epoch loss: {total_loss/batch_number}')
    # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
    if epoch % 5 == 0:
        torch.save(model.state_dict(), './checkpoint/baseline.pt')

torch.save(model.state_dict(), './checkpoint/baseline.pt')
# %%
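To make the group/batch bookkeeping above concrete: `generate_train_entity_sets` yields `(list_of_mentions, entity_id)` tuples and `batchGenerator` flattens `batch_size` of them into parallel `x`/`y` lists, so the effective triplet batch is roughly `batch_size * num_sample_per_class` texts. A small usage sketch, assuming the two functions defined above are in scope (the mentions and ids below are hypothetical):

toy_mentions = {101: ['postgres', 'postgresql', 'pg db', 'postgres 14'], 202: ['websphere', 'was']}
toy_names = {101: 'PostgreSQL', 202: 'WebSphere Application Server'}

sets = generate_train_entity_sets(toy_mentions, toy_names, group_size=3, anchor=True)
# e.g. [(['PostgreSQL', 'postgres', 'pg db', 'postgres 14'], 101), (['WebSphere Application Server', 'websphere', 'was'], 202)]

for x, y in batchGenerator(sets, batch_size=2):
    print(len(x), y)  # flattened mention texts and one label per mention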
@@ -15,8 +15,8 @@ import gc
# Step 2: Load the state dictionary
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
@@ -109,7 +109,9 @@ def test():
# prepare tokenizer

MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
# MODEL_NAME = 'distilbert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, return_tensors="pt", clean_up_tokenization_spaces=True)
# Define additional special tokens
# additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
@@ -0,0 +1,316 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import pandas as pd
import re
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

# %%
SHUFFLES=0
AMPLIFY_FACTOR=0
LEARNING_RATE=1e-5

# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    # split entity mentions into groups
    # anchor = False: don't add the entity name to each group, simply treat it as a normal mention
    entity_sets = []
    if anchor:
        for id, mentions in entity_id_mentions.items():
            random.shuffle(mentions)
            positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
            anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
            entity_sets.extend(anchor_positive)
    else:
        for id, mentions in entity_id_mentions.items():
            group = list(set([entity_id_name[id]] + mentions))
            random.shuffle(group)
            positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
            entity_sets.extend(positives)
    return entity_sets

def batchGenerator(data, batch_size):
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        x, y = [], []
        for t in batch:
            x.extend(t[0])
            y.extend([t[1]]*len(t[0]))
        yield x, y


with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}

# %%
###################################################
# alternate data import strategy
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use pattern, we use the real thing and property
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))

id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
    id2label[idx] = val
    label2id[val] = idx

df["training_id"] = df["entity_id"].map(label2id)

# %%
##############################################################
# augmentation code

# basic preprocessing
def preprocess_text(text):
    # 1. Make all lowercase
    text = text.lower()

    # standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()

    return text


def generate_random_shuffles(text, n):
    words = text.split()  # Split the input into words
    shuffled_variations = []

    for _ in range(n):
        shuffled = words[:]  # Copy the word list to avoid in-place modification
        random.shuffle(shuffled)  # Randomly shuffle the words
        shuffled_variations.append(" ".join(shuffled))  # Join the words back into a string

    return shuffled_variations


def shuffle_text(text, n_shuffles=SHUFFLES):
    all_processed = []
    # add the original text
    all_processed.append(text)

    # Generate random shuffles
    shuffled_variations = generate_random_shuffles(text, n_shuffles)
    all_processed.extend(shuffled_variations)

    return all_processed

def corrupt_word(word):
    """Corrupt a single word using random corruption techniques."""
    if len(word) <= 1:  # Skip corruption for single-character words
        return word

    corruption_type = random.choice(["delete", "swap"])

    if corruption_type == "delete":
        # Randomly delete a character
        idx = random.randint(0, len(word) - 1)
        word = word[:idx] + word[idx + 1:]

    elif corruption_type == "swap":
        # Swap two adjacent characters
        if len(word) > 1:
            idx = random.randint(0, len(word) - 2)
            word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])

    return word

def corrupt_string(sentence, corruption_probability=0.01):
    """Corrupt each word in the string with a given probability."""
    words = sentence.split()
    corrupted_words = [
        corrupt_word(word) if random.random() < corruption_probability else word
        for word in words
    ]
    return " ".join(corrupted_words)


def create_example(index, mention, entity_name):
    return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}

# augment whole dataset
def augment_data(df):
    output_list = []

    for idx, row in df.iterrows():
        index = row['entity_id']
        entity_name = row['entity_name']
        parent_desc = row['mention']
        parent_desc = preprocess_text(parent_desc)

        # add basic example
        output_list.append(create_example(index, parent_desc, entity_name))

        # disable augmentations
        # # add shuffled strings
        # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
        # for desc in processed_descs:
        #     if (desc != parent_desc):
        #         output_list.append(create_example(index, desc, entity_name))

        # # add corrupted strings
        # desc = corrupt_string(parent_desc, corruption_probability=0.01)
        # if (desc != parent_desc):
        #     output_list.append(create_example(index, desc, entity_name))

        # # add example with stripped non-alphanumerics
        # desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
        # if (desc != parent_desc):
        #     output_list.append(create_example(index, desc, entity_name))

        # # short sequence amplifier
        # # short sequences are rare, and we must compensate by including more examples
        # # also, short sequences don't usually get affected by shuffle
        # words = parent_desc.split()
        # word_count = len(words)
        # if word_count <= 2:
        #     for _ in range(AMPLIFY_FACTOR):
        #         output_list.append(create_example(index, desc, entity_name))

    new_df = pd.DataFrame(output_list)
    return new_df


# %%
def make_entity_id_mentions(df):
    entity_id_mentions = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
    return entity_id_mentions

def make_entity_id_name(df):
    entity_id_name = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        # entity_id always matches entity_name, so the first value works
        entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
    return entity_id_name


# %%
num_sample_per_class = 10 # samples in each group
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
bert_model = AutoModel.from_pretrained(MODEL_NAME)

class BertForClassificationAndTriplet(nn.Module):
    def __init__(self, bert_model, num_classes):
        super().__init__()
        self.bert = bert_model
        self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token
        logits = self.classifier(cls_embeddings)
        return cls_embeddings, logits

model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

model.to(DEVICE)
model.train()

losses = []

def linear_decay(epoch, max_epochs, initial_lr, final_lr):
    """Calculate the linearly decayed learning rate."""
    return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr)


for epoch in tqdm(range(epochs)):
    total_loss = 0.0
    batch_number = 0

    # lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6)

    # # Update optimizer's learning rate
    # for param_group in optimizer.param_groups:
    #     param_group['lr'] = lr

    augmented_df = augment_data(df)
    train_entity_id_mentions = make_entity_id_mentions(augmented_df)
    train_entity_id_name = make_entity_id_name(augmented_df)

    data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
    random.shuffle(data)
    for x, y in batchGenerator(data, batch_size):
        # print(len(x), len(y), end='-->')
        optimizer.zero_grad()

        inputs = tokenizer(x, padding=True, return_tensors='pt')
        inputs.to(DEVICE)
        cls, logits = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask']
        )
        # map raw entity ids to contiguous training ids for the classifier head
        labels = y
        labels = [label2id[element] for element in labels]
        labels = torch.tensor(labels).to(DEVICE)
        # y = torch.tensor(y).to(DEVICE)

        class_loss = F.cross_entropy(logits, labels)

        # if epoch < epochs / 2:
        #     triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
        # # for training after half the time, train on hard
        # else:
        #     triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False)

        loss = class_loss # + triplet_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().item()
        batch_number += 1

        del x, y, cls, logits, loss
        torch.cuda.empty_cache()

    # scheduler.step() # Update the learning rate
    print(f'epoch loss: {total_loss/batch_number}')
    # print(f"Epoch {epoch+1}: lr={lr}")
    if epoch % 5 == 0:
        # torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
        torch.save(model.state_dict(), './checkpoint/classification.pt')


# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
torch.save(model.state_dict(), './checkpoint/classification.pt')
# %%
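One detail worth spelling out in the classification variant above: `F.cross_entropy` expects class indices in `[0, num_classes)`, so the raw entity ids are first mapped to contiguous ids via `label2id` before building the label tensor. A small self-contained sketch of that mapping (toy ids, not the real entity catalogue):

import torch
import torch.nn.functional as F

entity_ids = [505, 130, 505, 942]                  # hypothetical raw entity ids in a batch
target_id_list = sorted(set(entity_ids))           # [130, 505, 942]
label2id = {val: idx for idx, val in enumerate(target_id_list)}
labels = torch.tensor([label2id[e] for e in entity_ids])   # tensor([1, 0, 1, 2])

logits = torch.randn(4, len(label2id))             # stand-in for the classifier head output
loss = F.cross_entropy(logits, labels)
print(labels, loss.item())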
@@ -163,30 +163,31 @@ def augment_data(df):
        # add basic example
        output_list.append(create_example(index, parent_desc, entity_name))

        # add shuffled strings
        processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
        for desc in processed_descs:
            if (desc != parent_desc):
                output_list.append(create_example(index, desc, entity_name))
        # all augmentations disabled
        # # add shuffled strings
        # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
        # for desc in processed_descs:
        #     if (desc != parent_desc):
        #         output_list.append(create_example(index, desc, entity_name))

        # add corrupted strings
        desc = corrupt_string(parent_desc, corruption_probability=0.01)
        if (desc != parent_desc):
            output_list.append(create_example(index, desc, entity_name))
        # # add corrupted strings
        # desc = corrupt_string(parent_desc, corruption_probability=0.01)
        # if (desc != parent_desc):
        #     output_list.append(create_example(index, desc, entity_name))

        # add example with stripped non-alphanumerics
        desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
        if (desc != parent_desc):
            output_list.append(create_example(index, desc, entity_name))
        # # add example with stripped non-alphanumerics
        # desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
        # if (desc != parent_desc):
        #     output_list.append(create_example(index, desc, entity_name))

        # short sequence amplifier
        # short sequences are rare, and we must compensate by including more examples
        # also, short sequences don't usually get affected by shuffle
        words = parent_desc.split()
        word_count = len(words)
        if word_count <= 2:
            for _ in range(AMPLIFY_FACTOR):
                output_list.append(create_example(index, desc, entity_name))
        # # short sequence amplifier
        # # short sequences are rare, and we must compensate by including more examples
        # # also, short sequences don't usually get affected by shuffle
        # words = parent_desc.split()
        # word_count = len(words)
        # if word_count <= 2:
        #     for _ in range(AMPLIFY_FACTOR):
        #         output_list.append(create_example(index, desc, entity_name))

    new_df = pd.DataFrame(output_list)
    return new_df
@@ -215,7 +216,7 @@ num_sample_per_class = 10 # samples in each group
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
@@ -269,8 +270,8 @@ for epoch in tqdm(range(epochs)):
    print(f'epoch loss: {total_loss/batch_number}')
    # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
    if epoch % 5 == 0:
        torch.save(model.state_dict(), './checkpoint/siamese_simple.pt')
        torch.save(model.state_dict(), './checkpoint/baseline.pt')


torch.save(model.state_dict(), './checkpoint/siamese_simple.pt')
torch.save(model.state_dict(), './checkpoint/baseline.pt')
# %%
@@ -217,8 +217,8 @@ batch_size = 16 # number of groups, effective batch_size for computing triplet loss
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
bert_model = AutoModel.from_pretrained(MODEL_NAME)
@@ -245,12 +245,22 @@ model.train()
losses = []

def linear_decay(epoch, max_epochs, initial_lr, final_lr):
    """Calculate the linearly decayed learning rate."""
    return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr)


for epoch in tqdm(range(epochs)):
    total_loss = 0.0
    batch_number = 0

    lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6)

    # Update optimizer's learning rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    augmented_df = augment_data(df)
    train_entity_id_mentions = make_entity_id_mentions(augmented_df)
    train_entity_id_name = make_entity_id_name(augmented_df)
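For reference, the linear decay introduced above interpolates the learning rate from `initial_lr` at epoch 0 toward `final_lr` at `max_epochs`; with `initial_lr=1e-5`, `final_lr=5e-6` and 200 epochs, epoch 100 gives 1e-5 - 0.5 * 5e-6 = 7.5e-6. A quick standalone check (re-defining the helper so the snippet runs on its own):

def linear_decay(epoch, max_epochs, initial_lr, final_lr):
    return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr)

for epoch in (0, 100, 199):
    print(epoch, linear_decay(epoch, 200, initial_lr=1e-5, final_lr=5e-6))
# 0 -> 1e-05, 100 -> 7.5e-06, 199 -> ~5.025e-06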
@@ -294,7 +304,7 @@ for epoch in tqdm(range(epochs)):
    # scheduler.step() # Update the learning rate
    print(f'epoch loss: {total_loss/batch_number}')
    # print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
    print(f"Epoch {epoch+1}: lr={lr}")
    if epoch % 5 == 0:
        # torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
        torch.save(model.state_dict(), './checkpoint/classification.pt')
@@ -0,0 +1,124 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import re
import gc

# %%
# Step 2: Load the state dictionary
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

# state_dict = torch.load('./checkpoint/siamese.pt')
# state_dict = torch.load('./checkpoint/siamese_simple.pt')
state_dict = torch.load('./checkpoint/hybrid.pt')
params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name}

# %%
# Step 3: Apply the state dictionary to the model
model.load_state_dict(params_dict)
model.to(DEVICE)
model.eval()

# %%
def preprocess_text(text):
    # 1. Make all lowercase
    text = text.lower()

    # standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()

    return text


# %%
with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}

# %%
with open('../esAppMod/infer.json', 'r') as file:
    test = json.load(file)
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
y_test = [d['entity_id'] for _, d in test['data'].items()]
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
train_entities = [preprocess_text(element) for element in train_entities]

def batch_list(data, batch_size):
    """Yield successive batch_size-sized chunks from data."""
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]

batches = batch_list(train_entities, 64)

embedding_list = []
for batch in batches:
    inputs = tokenizer(batch, padding=True, return_tensors='pt')
    outputs = model(
        input_ids=inputs['input_ids'].to(DEVICE),
        attention_mask=inputs['attention_mask'].to(DEVICE)
    )
    output = outputs.last_hidden_state[:, 0, :]
    output = output.detach().cpu().numpy()
    embedding_list.append(output)

cls = np.concatenate(embedding_list)
# %%
gc.collect()
torch.cuda.empty_cache()

# %%
batches = batch_list(x_test, 64)

embedding_list = []
for batch in batches:
    inputs = tokenizer(batch, padding=True, return_tensors='pt')
    outputs = model(
        input_ids=inputs['input_ids'].to(DEVICE),
        attention_mask=inputs['attention_mask'].to(DEVICE)
    )
    output = outputs.last_hidden_state[:, 0, :]
    output = output.detach().cpu().numpy()
    embedding_list.append(output)

cls_test = np.concatenate(embedding_list)

# %%
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
n_neighbors = [1, 3, 5, 10]

with open("results/output.txt", "w") as f:
    for n in n_neighbors:
        distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
        num = 0
        for a, b in zip(y_test, indices):
            b = [labels[i] for i in b]
            if a in b:
                num += 1
        print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
        print(np.min(distances), np.max(distances), file=f)

# %%
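The `params_dict` rename above is needed because `hybrid.pt` stores the whole `BertForClassificationAndTriplet` wrapper: encoder weights live under a `bert.` key prefix and the classifier head is dropped so the remaining keys line up with a bare `AutoModel`. A small sketch of the same key surgery on a hypothetical state dict (placeholder values, not real tensors):

# Hypothetical wrapper checkpoint keys -> bare encoder keys
state_dict = {
    'bert.embeddings.word_embeddings.weight': '...',
    'bert.encoder.layer.0.attention.self.query.weight': '...',
    'classifier.weight': '...',
    'classifier.bias': '...',
}
params_dict = {name.replace('bert.', ''): param
               for name, param in state_dict.items()
               if 'classifier' not in name}
print(list(params_dict))
# ['embeddings.word_embeddings.weight', 'encoder.layer.0.attention.self.query.weight']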
@@ -0,0 +1,315 @@
# %%
import torch
import json
import random
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
import pandas as pd
import re
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

# %%
SHUFFLES=0
AMPLIFY_FACTOR=0
LEARNING_RATE=1e-5

# %%
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
    # split entity mentions into groups
    # anchor = False: don't add the entity name to each group, simply treat it as a normal mention
    entity_sets = []
    if anchor:
        for id, mentions in entity_id_mentions.items():
            random.shuffle(mentions)
            positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
            anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
            entity_sets.extend(anchor_positive)
    else:
        for id, mentions in entity_id_mentions.items():
            group = list(set([entity_id_name[id]] + mentions))
            random.shuffle(group)
            positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
            entity_sets.extend(positives)
    return entity_sets

def batchGenerator(data, batch_size):
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        x, y = [], []
        for t in batch:
            x.extend(t[0])
            y.extend([t[1]]*len(t[0]))
        yield x, y


with open('../esAppMod/tca_entities.json', 'r') as file:
    entities = json.load(file)
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}

with open('../esAppMod/train.json', 'r') as file:
    train = json.load(file)
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}

# %%
###################################################
# alternate data import strategy
###################################################
# import code
# import training file
data_path = '../esAppMod_data_import/train.csv'
df = pd.read_csv(data_path, skipinitialspace=True)
# rather than use pattern, we use the real thing and property
entity_ids = df['entity_id'].to_list()
target_id_list = sorted(list(set(entity_ids)))

id2label = {}
label2id = {}
for idx, val in enumerate(target_id_list):
    id2label[idx] = val
    label2id[val] = idx

df["training_id"] = df["entity_id"].map(label2id)

# %%
##############################################################
# augmentation code

# basic preprocessing
def preprocess_text(text):
    # 1. Make all lowercase
    text = text.lower()

    # standardize spacing
    text = re.sub(r'\s+', ' ', text).strip()

    return text


def generate_random_shuffles(text, n):
    words = text.split()  # Split the input into words
    shuffled_variations = []

    for _ in range(n):
        shuffled = words[:]  # Copy the word list to avoid in-place modification
        random.shuffle(shuffled)  # Randomly shuffle the words
        shuffled_variations.append(" ".join(shuffled))  # Join the words back into a string

    return shuffled_variations


def shuffle_text(text, n_shuffles=SHUFFLES):
    all_processed = []
    # add the original text
    all_processed.append(text)

    # Generate random shuffles
    shuffled_variations = generate_random_shuffles(text, n_shuffles)
    all_processed.extend(shuffled_variations)

    return all_processed

def corrupt_word(word):
    """Corrupt a single word using random corruption techniques."""
    if len(word) <= 1:  # Skip corruption for single-character words
        return word

    corruption_type = random.choice(["delete", "swap"])

    if corruption_type == "delete":
        # Randomly delete a character
        idx = random.randint(0, len(word) - 1)
        word = word[:idx] + word[idx + 1:]

    elif corruption_type == "swap":
        # Swap two adjacent characters
        if len(word) > 1:
            idx = random.randint(0, len(word) - 2)
            word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])

    return word

def corrupt_string(sentence, corruption_probability=0.01):
    """Corrupt each word in the string with a given probability."""
    words = sentence.split()
    corrupted_words = [
        corrupt_word(word) if random.random() < corruption_probability else word
        for word in words
    ]
    return " ".join(corrupted_words)


def create_example(index, mention, entity_name):
    return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}

# augment whole dataset
def augment_data(df):
    output_list = []

    for idx, row in df.iterrows():
        index = row['entity_id']
        entity_name = row['entity_name']
        parent_desc = row['mention']
        parent_desc = preprocess_text(parent_desc)

        # add basic example
        output_list.append(create_example(index, parent_desc, entity_name))

        # # add shuffled strings
        # processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
        # for desc in processed_descs:
        #     if (desc != parent_desc):
        #         output_list.append(create_example(index, desc, entity_name))

        # # add corrupted strings
        # desc = corrupt_string(parent_desc, corruption_probability=0.01)
        # if (desc != parent_desc):
        #     output_list.append(create_example(index, desc, entity_name))

        # # add example with stripped non-alphanumerics
        # desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
        # if (desc != parent_desc):
        #     output_list.append(create_example(index, desc, entity_name))

        # # short sequence amplifier
        # # short sequences are rare, and we must compensate by including more examples
        # # also, short sequences don't usually get affected by shuffle
        # words = parent_desc.split()
        # word_count = len(words)
        # if word_count <= 2:
        #     for _ in range(AMPLIFY_FACTOR):
        #         output_list.append(create_example(index, desc, entity_name))

    new_df = pd.DataFrame(output_list)
    return new_df


# %%
def make_entity_id_mentions(df):
    entity_id_mentions = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
    return entity_id_mentions

def make_entity_id_name(df):
    entity_id_name = {}
    entity_id_list = list(set(df['entity_id']))
    for entity_id in entity_id_list:
        # entity_id always matches entity_name, so the first value works
        entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
    return entity_id_name


# %%
num_sample_per_class = 10 # samples in each group
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
margin = 2
epochs = 200
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
bert_model = AutoModel.from_pretrained(MODEL_NAME)

class BertForClassificationAndTriplet(nn.Module):
    def __init__(self, bert_model, num_classes):
        super().__init__()
        self.bert = bert_model
        self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token
        logits = self.classifier(cls_embeddings)
        return cls_embeddings, logits

model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

model.to(DEVICE)
model.train()

losses = []

def linear_decay(epoch, max_epochs, initial_lr, final_lr):
    """Calculate the linearly decayed learning rate."""
    return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr)


for epoch in tqdm(range(epochs)):
    total_loss = 0.0
    batch_number = 0

    # lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6)

    # # Update optimizer's learning rate
    # for param_group in optimizer.param_groups:
    #     param_group['lr'] = lr

    augmented_df = augment_data(df)
    train_entity_id_mentions = make_entity_id_mentions(augmented_df)
    train_entity_id_name = make_entity_id_name(augmented_df)

    data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
    random.shuffle(data)
    for x, y in batchGenerator(data, batch_size):
        # print(len(x), len(y), end='-->')
        optimizer.zero_grad()

        inputs = tokenizer(x, padding=True, return_tensors='pt')
        inputs.to(DEVICE)
        cls, logits = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask']
        )
        # map raw entity ids to contiguous training ids for the classifier head
        labels = y
        labels = [label2id[element] for element in labels]
        labels = torch.tensor(labels).to(DEVICE)
        y = torch.tensor(y).to(DEVICE)

        class_loss = F.cross_entropy(logits, labels)

        # for the first half of training, train on easy (batch-all) triplets
        if epoch < epochs / 2:
            triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
        # for the second half of training, train on hard triplets
        else:
            triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False)

        loss = class_loss + triplet_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().item()
        batch_number += 1

        del x, y, cls, logits, loss
        torch.cuda.empty_cache()

    # scheduler.step() # Update the learning rate
    print(f'epoch loss: {total_loss/batch_number}')
    # print(f"Epoch {epoch+1}: lr={lr}")
    if epoch % 5 == 0:
        # torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
        torch.save(model.state_dict(), './checkpoint/hybrid.pt')


# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
torch.save(model.state_dict(), './checkpoint/hybrid.pt')
# %%
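The hybrid objective above simply sums the two terms, `loss = class_loss + triplet_loss`, so the CLS embedding is shaped jointly by the cross-entropy head and the metric-learning constraint before the wrapper is saved to `hybrid.pt`. A minimal sketch of combining two loss terms on shared embeddings (toy tensors, and a plain same-label pairwise distance term standing in for the repo's `batch_all_triplet_loss`):

import torch
import torch.nn as nn
import torch.nn.functional as F

emb = torch.randn(8, 16, requires_grad=True)      # stand-in for CLS embeddings
logits = nn.Linear(16, 4)(emb)                    # stand-in classifier head
labels = torch.randint(0, 4, (8,))

class_loss = F.cross_entropy(logits, labels)
# toy metric term: pull same-label pairs together (not the project's triplet implementation)
dist = torch.cdist(emb, emb)
same = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
metric_loss = (dist * same).sum() / same.sum()

loss = class_loss + metric_loss                   # the hybrid objective is just the sum
loss.backward()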
File diff suppressed because it is too large
@@ -1,5 +1,5 @@
Top-1 accuracy: 0.8257482574825749
Top-3 accuracy: 0.9106191061910619
Top-5 accuracy: 0.9261992619926199
Top-10 accuracy: 0.942189421894219
0.0 0.82121104
Top-1 accuracy: 0.8072980729807298
Top-3 accuracy: 0.8946289462894629
Top-5 accuracy: 0.9040590405904059
Top-10 accuracy: 0.924149241492415
0.0 0.7571934
@ -0,0 +1,270 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
# from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
import loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# %%
|
||||
SHUFFLES=3
|
||||
AMPLIFY_FACTOR=3
|
||||
LEARNING_RATE=1e-5
|
||||
DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
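# Illustrative check with toy data (the ids, names and mentions below are made up):
# with anchor=True each element of the output is ([entity_name, mention, ...], entity_id),
# so the canonical name is grouped with a handful of its mentions.
_toy_mentions = {1: ['postgres', 'pg', 'postgresql db'], 2: ['ms sql', 'sqlserver']}
_toy_names = {1: 'PostgreSQL', 2: 'Microsoft SQL Server'}
print(generate_train_entity_sets(_toy_mentions, _toy_names, group_size=2, anchor=True))
# e.g. [(['PostgreSQL', 'pg', 'postgres'], 1), (['PostgreSQL', 'postgresql db'], 1),
#       (['Microsoft SQL Server', 'ms sql', 'sqlserver'], 2)] -- order varies with the shuffle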
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than using patterns, we use the actual entity ids and their properties
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. make everything lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# add shuffled strings
|
||||
processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
for desc in processed_descs:
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add corrupted strings
|
||||
desc = corrupt_string(parent_desc, corruption_probability=0.01)
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add example with stripped non-alphanumerics
|
||||
desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# short sequence amplifier
|
||||
# short sequences are rare, and we must compensate by including more examples
|
||||
# also, short sequences are not usually affected by shuffling
|
||||
words = parent_desc.split()
|
||||
word_count = len(words)
|
||||
if word_count <= 2:
|
||||
for _ in range(AMPLIFY_FACTOR):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
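# %%
# Minimal sketch of augment_data on a toy frame (column names follow the train.csv
# import above; the rows themselves are made up for illustration).
_toy_df = pd.DataFrame([
    {'entity_id': 1, 'mention': 'IBM  WebSphere!', 'entity_name': 'WebSphere Application Server'},
    {'entity_id': 2, 'mention': 'pg', 'entity_name': 'PostgreSQL'},
])
print(augment_data(_toy_df))
# expect: the lowercased, whitespace-normalised mention, up to SHUFFLES shuffled variants,
# possibly a corrupted variant, a punctuation-stripped variant, and AMPLIFY_FACTOR extra
# copies for mentions of at most two words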
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# entity_id always matches entity_name, so first value would work
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
# %%
|
||||
augmented_df = augment_data(df)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
|
||||
|
||||
# %%
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]])  # one label per group here, unlike the earlier batchGenerator which repeats the label for every mention in the group
|
||||
yield x, y
|
||||
|
||||
|
||||
|
||||
# simulate 1 epoch
|
||||
y_accumulator = []
|
||||
augmented_df = augment_data(df)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
y_accumulator.append(y)
|
||||
|
||||
|
||||
# %%
|
||||
y_accumulator
|
||||
|
||||
# %%
|
|
@@ -0,0 +1,378 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
# from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
import loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# %%
|
||||
SHUFFLES=0
|
||||
AMPLIFY_FACTOR=0
|
||||
LEARNING_RATE=1e-5
|
||||
DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than using patterns, we use the actual entity ids and their properties
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. make everything lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# add shuffled strings
|
||||
processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
for desc in processed_descs:
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add corrupted strings
|
||||
desc = corrupt_string(parent_desc, corruption_probability=0.01)
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add example with stripped non-alphanumerics
|
||||
desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# short sequence amplifier
|
||||
# short sequences are rare, and we must compensate by including more examples
|
||||
# also, short sequences are not usually affected by shuffling
|
||||
words = parent_desc.split()
|
||||
word_count = len(words)
|
||||
if word_count <= 2:
|
||||
for _ in range(AMPLIFY_FACTOR):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# entity_id always matches entity_name, so first value would work
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
# %%
|
||||
augmented_df = augment_data(df)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
|
||||
|
||||
# %%
|
||||
x, y = next(iter(batchGenerator(data, batch_size)))
|
||||
|
||||
# %%
|
||||
inputs = tokenizer(x, padding=True, return_tensors='pt')
|
||||
inputs.to(DEVICE)
|
||||
outputs = model(**inputs)
|
||||
cls = outputs.last_hidden_state[:,0,:]
|
||||
# convert the group labels to a tensor on the device for the loss computation
|
||||
y = torch.tensor(y).to(DEVICE)
|
||||
|
||||
# %%
|
||||
def _pairwise_distances(embeddings, squared=False):
|
||||
"""Compute the 2D matrix of distances between all the embeddings.
|
||||
Args:
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
pairwise_distances: tensor of shape (batch_size, batch_size)
|
||||
"""
|
||||
dot_product = torch.matmul(embeddings, embeddings.t())
|
||||
|
||||
# Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
|
||||
# This also provides more numerical stability (the diagonal of the result will be exactly 0).
|
||||
# shape (batch_size,)
|
||||
square_norm = torch.diag(dot_product)
|
||||
|
||||
# Compute the pairwise distance matrix as we have:
|
||||
# ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
|
||||
# shape (batch_size, batch_size)
|
||||
distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)
|
||||
|
||||
# Apply a lower bound to distances to ensure they are non-negative and avoid tiny negative numbers due to computation errors
|
||||
distances = torch.clamp(distances, min=0.0)
|
||||
|
||||
if not squared:
|
||||
# Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
|
||||
# we need to add a small epsilon where distances == 0.0
|
||||
epsilon = 1e-16
|
||||
mask = (distances < epsilon).float()
|
||||
distances = distances + mask * epsilon
|
||||
|
||||
distances = (1.0 - mask) * torch.sqrt(distances)
|
||||
|
||||
return distances
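# %%
# Sanity check (illustrative only): the expansion ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2
# used above should agree with torch.cdist on random data, up to the epsilon handling
# on the diagonal.
_x = torch.randn(5, 8)
print(torch.allclose(_pairwise_distances(_x, squared=False), torch.cdist(_x, _x, p=2), atol=1e-4))
# expected: True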
|
||||
|
||||
# %%
|
||||
embeddings = cls
|
||||
squared = False
|
||||
|
||||
# %%
|
||||
|
||||
# Get the pairwise distance matrix
|
||||
pairwise_dist = loss._pairwise_distances(embeddings, squared=squared) # 96x96
|
||||
|
||||
anchor_positive_dist = pairwise_dist.unsqueeze(2) # 96x96x1
|
||||
anchor_negative_dist = pairwise_dist.unsqueeze(1) # 96x1x96
|
||||
|
||||
# Compute a 3D tensor of size (batch_size, batch_size, batch_size)
|
||||
# triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k
|
||||
# every (i,j) pairwise distance - every (i,k) pairwise distance
|
||||
# fixing for i, we get (i,j) - (i,k), for every j and k, which is 96x96
|
||||
|
||||
# Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
|
||||
# and the 2nd (batch_size, 1, batch_size)
|
||||
# remember that broadcasting is repeating the other axis n-times
|
||||
# this broadcasting trick is to get every possible triple combination
|
||||
triplet_loss = anchor_positive_dist - anchor_negative_dist + margin
|
||||
|
||||
# triplet_loss 96x96x96
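# %%
# Tiny numeric illustration (3 points instead of 96) of the broadcasting trick above:
# pairwise_dist.unsqueeze(2) - pairwise_dist.unsqueeze(1) has entry [i, j, k] equal to
# d(i, j) - d(i, k), i.e. every anchor/positive/negative combination at once.
_d = torch.tensor([[0., 1., 2.],
                   [1., 0., 3.],
                   [2., 3., 0.]])
_all = _d.unsqueeze(2) - _d.unsqueeze(1)   # shape (3, 3, 3)
print(_all[0, 1, 2])                       # d(0, 1) - d(0, 2) = 1 - 2 = -1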
|
||||
|
||||
# %%
|
||||
labels = y
|
||||
|
||||
# %%
|
||||
|
||||
# Put to zero the invalid triplets
|
||||
# (where label(a) != label(p) or label(n) == label(a) or a == p)
|
||||
mask = loss._get_triplet_mask(labels)
|
||||
triplet_loss = mask.float() * triplet_loss
|
||||
|
||||
# Remove negative losses (i.e. the easy triplets)
|
||||
triplet_loss = F.relu(triplet_loss)
|
||||
|
||||
# Count number of positive triplets (where triplet_loss > 0)
|
||||
valid_triplets = triplet_loss[triplet_loss > 1e-16]
|
||||
num_positive_triplets = valid_triplets.size(0)
|
||||
num_valid_triplets = mask.sum()
|
||||
|
||||
fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)
|
||||
|
||||
# Get final mean triplet loss over the positive valid triplets
|
||||
triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
|
||||
|
||||
# %%
|
||||
# %%
|
||||
loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
|
||||
# %%
|
||||
loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
|
||||
# %%
|
||||
# Check that i, j and k are distinct
|
||||
# create an identity matrix of size 96
|
||||
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
|
||||
|
||||
# %%
|
||||
indices_not_equal = ~indices_equal
|
||||
i_not_equal_j = indices_not_equal.unsqueeze(2) # [96,96,1]
|
||||
i_not_equal_k = indices_not_equal.unsqueeze(1) # [96,1,96]
|
||||
j_not_equal_k = indices_not_equal.unsqueeze(0) # [1,96,96]
|
||||
|
||||
|
||||
# %%
|
||||
# eliminate any combination that uses the diagonal values (aka sharing same values)
|
||||
distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k
|
||||
|
||||
|
||||
# %%
|
||||
label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
|
||||
# label_equal is a 96x96 matrix showing where 2 labels equate
|
||||
|
||||
# perform the same unsqueeze to 1 and 2 axis and broadcast to get all possible combinations
|
||||
# note that we have 96 elements, but we want all (i,j,k) combinations from these 96 elements
|
||||
i_equal_j = label_equal.unsqueeze(2)
|
||||
i_equal_k = label_equal.unsqueeze(1)
|
||||
|
||||
# ~i_equal_k means that it checks for non-equality between i and k
|
||||
# i_equal_j checks for equality between i and j
|
||||
# we want (i,j) to be the same label, (i,k) to be different labels
|
||||
valid_labels = ~i_equal_k & i_equal_j
|
||||
|
||||
# %%
|
||||
final_mask = distinct_indices & valid_labels
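# %%
# Toy check of the mask logic above (hypothetical labels, batch of 4 instead of 96):
# a triplet (a, p, n) is kept only when a, p, n are distinct indices,
# label(a) == label(p) and label(a) != label(n).
_lbl = torch.tensor([0, 0, 1, 1])
_eq = _lbl.unsqueeze(0) == _lbl.unsqueeze(1)
_not_diag = ~torch.eye(4, dtype=torch.bool)
_distinct = _not_diag.unsqueeze(2) & _not_diag.unsqueeze(1) & _not_diag.unsqueeze(0)
_valid = _eq.unsqueeze(2) & ~_eq.unsqueeze(1) & _distinct
print(_valid[0, 1, 2])   # anchor 0, positive 1 (same label), negative 2 (different label) -> True
print(_valid[0, 2, 1])   # "positive" 2 has a different label -> False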
|
||||
|
|
@@ -0,0 +1,4 @@
|
|||
__pycache__
|
||||
checkpoint
|
||||
results
|
||||
top1_curves
|
|
@@ -0,0 +1,577 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import BertTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from sklearn.metrics import accuracy_score
|
||||
from transformers import get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup
|
||||
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
def set_seed(seed):
|
||||
"""
|
||||
Set the random seed for reproducibility.
|
||||
"""
|
||||
random.seed(seed) # Python random module
|
||||
np.random.seed(seed) # NumPy random
|
||||
torch.manual_seed(seed) # PyTorch CPU
|
||||
torch.cuda.manual_seed(seed) # PyTorch GPU
|
||||
torch.cuda.manual_seed_all(seed) # If using multiple GPUs
|
||||
torch.backends.cudnn.deterministic = True # Ensure deterministic behavior
|
||||
torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility
|
||||
|
||||
set_seed(42)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
SHUFFLES=1
|
||||
AMPLIFY_FACTOR=1
|
||||
CORRUPT=0.00
|
||||
LEARNING_RATE=1e-6
|
||||
DEVICE = torch.device('cuda:2') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
# %%
|
||||
EVAL_FILE="top1_curves/batch_output.txt"
|
||||
with open(EVAL_FILE, "w") as f:
|
||||
pass
|
||||
|
||||
EVAL_FILE_KNN="top1_curves/batch_knn.txt"
|
||||
with open(EVAL_FILE_KNN, "w") as f:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than using patterns, we use the actual entity ids and their properties
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. make everything lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# # add shuffled strings
|
||||
# processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
# for desc in processed_descs:
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add corrupted strings
|
||||
desc = corrupt_string(parent_desc, corruption_probability=CORRUPT)
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add example with stripped non-alphanumerics
|
||||
desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # short sequence amplifier
|
||||
# # short sequences are rare, and we must compensate by including more examples
|
||||
# # also, short sequences are not usually affected by shuffling
|
||||
# words = parent_desc.split()
|
||||
# word_count = len(words)
|
||||
# if word_count <= 2:
|
||||
# for _ in range(AMPLIFY_FACTOR):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
|
||||
|
||||
# def sample_from_df(df, sample_size_per_class=5):
|
||||
# sampled_df = (df.groupby("entity_id")[['entity_id', 'mention', 'entity_name']] # explicit give column names
|
||||
# .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x))))
|
||||
# .reset_index(drop=True))
|
||||
#
|
||||
# return sampled_df
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# entity_id always matches entity_name, so first value would work
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
# %%
|
||||
# evaluation
|
||||
def run_evaluation_logit(model, tokenizer):
|
||||
def preprocess_text(text):
|
||||
# 1. make everything lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
eval_entities = json.load(file)
|
||||
eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
eval_train = json.load(file)
|
||||
eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()}
|
||||
eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()}
|
||||
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
eval_test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in eval_test['data'].items()]
|
||||
eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys())
|
||||
eval_train_entities = [preprocess_text(element) for element in eval_train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
pred_labels = []
|
||||
for batch in batches:
|
||||
# Inference in batches
|
||||
inputs, attn_mask = tokenizer.encode(batch)
|
||||
inputs = inputs.to(DEVICE)
|
||||
attn_mask = attn_mask.to(DEVICE)
|
||||
with torch.no_grad():
|
||||
_, logits = model(inputs, attn_mask)
|
||||
predicted_class_ids = logits.argmax(dim=1).to("cpu")
|
||||
pred_labels.extend(predicted_class_ids)
|
||||
|
||||
|
||||
pred_labels = [tensor.item() for tensor in pred_labels]
|
||||
|
||||
# %%
|
||||
labels = [label2id[element] for element in y_test]
|
||||
with open(EVAL_FILE, "a") as f:
|
||||
# only compute top-1
|
||||
accuracy = accuracy_score(labels, pred_labels)
|
||||
print(f'{accuracy}', file=f)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def run_evaluation_knn(model, tokenizer):
|
||||
def preprocess_text(text):
|
||||
# 1. make everything lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
eval_entities = json.load(file)
|
||||
eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
eval_train = json.load(file)
|
||||
eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()}
|
||||
eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()}
|
||||
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
eval_test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in eval_test['data'].items()]
|
||||
eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys())
|
||||
eval_train_entities = [preprocess_text(element) for element in eval_train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(eval_train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs, attn_mask = tokenizer.encode(batch)
|
||||
inputs = inputs.to(DEVICE)
|
||||
attn_mask = attn_mask.to(DEVICE)
|
||||
outputs = model(inputs, attn_mask)
|
||||
output_slice = outputs[:,0,:]
|
||||
output_slice = output_slice.detach().cpu().numpy()
|
||||
embedding_list.append(output_slice)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs, attn_mask = tokenizer.encode(batch)
|
||||
inputs = inputs.to(DEVICE)
|
||||
attn_mask = attn_mask.to(DEVICE)
|
||||
outputs = model(inputs, attn_mask)
|
||||
output_slice = outputs[:,0,:]
|
||||
output_slice = output_slice.detach().cpu().numpy()
|
||||
embedding_list.append(output_slice)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, eval_labels)
|
||||
|
||||
|
||||
with open(EVAL_FILE_KNN, "a") as f:
|
||||
# only compute top-1
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=1)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [eval_labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'{num / len(y_test)}', file=f)
|
||||
|
||||
|
||||
# %%
|
||||
class CharacterTransformer(nn.Module):
|
||||
def __init__(self, num_chars, d_model=256, nhead=4, num_encoder_layers=4):
|
||||
super(CharacterTransformer, self).__init__()
|
||||
self.char_embedding = nn.Embedding(num_chars, d_model)
|
||||
encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
|
||||
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)
|
||||
|
||||
def forward(self, input, attention_mask):
|
||||
# input: (batch_size, seq_len)
|
||||
embeddings = self.char_embedding(input) # (batch_size, seq_len, d_model)
|
||||
# embeddings = embeddings.permute(1, 0, 2) # (seq_len, batch_size, d_model)
|
||||
output = self.transformer_encoder(embeddings, src_key_padding_mask=attention_mask)
|
||||
# output = output.permute(1, 0, 2) # (batch_size, seq_len, d_model)
|
||||
return output
|
||||
|
||||
class ASCIITokenizer:
|
||||
def __init__(self, pad_token='\0'):
|
||||
# Initialize the tokenizer with ASCII characters.
|
||||
# ASCII characters range from 0 to 127.
|
||||
self.char_to_id = {chr(i): i for i in range(128)}
|
||||
self.id_to_char = {i: chr(i) for i in range(128)}
|
||||
self.pad_token = pad_token
|
||||
|
||||
def encode(self, text_list):
|
||||
"""Encode a text string into a list of ASCII IDs and generate attention masks."""
|
||||
output_list = []
|
||||
max_length = 0
|
||||
# First pass to find the maximum length and encode the texts
|
||||
for text in text_list:
|
||||
text = self.pad_token + text # Prepend pad_token to each text
|
||||
output = [self.char_to_id.get(char, self.char_to_id[self.pad_token]) for char in text]  # fall back to the pad-token id for any non-ASCII character
|
||||
output_list.append(output)
|
||||
if len(output) > max_length:
|
||||
max_length = len(output)
|
||||
|
||||
# Second pass to pad the sequences to the maximum length and create masks
|
||||
padded_list = []
|
||||
attention_masks = []
|
||||
for output in output_list:
|
||||
# we cannot mask the first token
|
||||
attention_mask = [0] + [0] * (len(output) - 1) + [1] * (max_length - len(output)) # 0 (False) for real tokens, 1 (True) for padding, matching the src_key_padding_mask convention
|
||||
output = self.pad(output, max_length)
|
||||
padded_list.append(output)
|
||||
attention_masks.append(attention_mask)
|
||||
|
||||
return torch.tensor(padded_list, dtype=torch.long), torch.tensor(attention_masks, dtype=torch.bool)
|
||||
|
||||
|
||||
def decode(self, ids_list):
|
||||
"""Decode a list of ASCII IDs back into a text string."""
|
||||
output_list = []
|
||||
for ids in ids_list:
|
||||
output = ''.join(self.id_to_char.get(id, '') for id in ids if id in self.id_to_char)
|
||||
output_list.append(output)
|
||||
return output_list
|
||||
|
||||
def pad(self, output, max_length):
|
||||
"""Pad the output list with ASCII ID for space or another padding character to the maximum length."""
|
||||
return output + [self.char_to_id.get(self.pad_token)] * (max_length - len(output))
|
||||
# %%
|
||||
tokenizer = ASCIITokenizer()
|
||||
# # Example text
|
||||
# text = ["Hello, world! This is cool", "Hello, world!"]
|
||||
# # Encode the text
|
||||
# encoded = tokenizer.encode(text)
|
||||
# print("Encoded:", encoded)
|
||||
#
|
||||
# # Decode the encoded IDs
|
||||
# decoded = tokenizer.decode(encoded.numpy())
|
||||
# print("Decoded:", decoded)
|
||||
|
||||
# %%
|
||||
# Example usage
|
||||
bert_model = CharacterTransformer(num_chars=128) # Assuming ASCII characters
|
||||
|
||||
class BertForClassificationAndTriplet(nn.Module):
|
||||
def __init__(self, bert_model, num_classes):
|
||||
super().__init__()
|
||||
self.bert = bert_model
|
||||
self.classifier = nn.Linear(bert_model.char_embedding.embedding_dim, num_classes)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None):
|
||||
outputs = self.bert(input_ids, attention_mask)
|
||||
cls_embeddings = outputs[:, 0, :] # CLS token
|
||||
logits = self.classifier(cls_embeddings)
|
||||
return cls_embeddings, logits
|
||||
|
||||
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
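# %%
# Quick shape check (illustrative only; toy mentions, run on CPU before the checkpoint
# is loaded). Note the ASCII tokenizer returns a boolean mask in src_key_padding_mask
# convention: True marks padded positions the transformer should ignore.
_demo_ids, _demo_mask = tokenizer.encode(['websphere mq', 'db2'])
print(_demo_ids.shape, _demo_mask.shape)      # both (2, max_len)
with torch.no_grad():
    _demo_cls, _demo_logits = model(_demo_ids, _demo_mask)
print(_demo_cls.shape, _demo_logits.shape)    # (2, d_model) and (2, num_classes)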
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
|
||||
# model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
||||
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
||||
# tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# num_warmup_steps=100
|
||||
# total_steps = epochs * (1126/64)
|
||||
# scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps, total_steps, lr_end=5e-6)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.9, cooldown=5, verbose=True)
|
||||
|
||||
|
||||
# %%
|
||||
state_dict = torch.load('./checkpoint/pretrained_character_bert.pt')
|
||||
state_dict = {key.replace('_orig_mod.', ''): value for key, value in state_dict.items()}
|
||||
model.load_state_dict(state_dict)
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
batch_number = 0
|
||||
|
||||
if epoch % 10 == 0:
|
||||
augmented_df = augment_data(df)
|
||||
# sampled_df = sample_from_df(augmented_df, sample_size_per_class=num_sample_per_class)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs, attn_mask = tokenizer.encode(x)
|
||||
inputs = inputs.to(DEVICE)
|
||||
attn_mask = attn_mask.to(DEVICE)
|
||||
cls, logits = model(inputs, attn_mask)
|
||||
|
||||
# labels = y
|
||||
# labels = [label2id[element] for element in labels]
|
||||
# labels = torch.tensor(labels).to(DEVICE)
|
||||
|
||||
# loss = F.cross_entropy(logits, labels)
|
||||
|
||||
|
||||
# for training less than half the time, train on easy
|
||||
y = torch.tensor(y).to(DEVICE)
|
||||
if epoch < epochs / 2:
|
||||
loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
# for training after half the time, train on hard
|
||||
else:
|
||||
loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
loss.backward()
|
||||
# scheduler.step()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
# del x, y, outputs, cls, loss
|
||||
# torch.cuda.empty_cache()
|
||||
epoch_loss = total_loss/batch_number
|
||||
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
print(f'epoch loss: {epoch_loss}')
|
||||
if (epoch % 1 == 0):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
run_evaluation_logit(model=model, tokenizer=tokenizer)
|
||||
run_evaluation_knn(model=model.bert, tokenizer=tokenizer)
|
||||
# run evaluation on test data
|
||||
model.train()
|
||||
|
||||
|
||||
# print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
|
||||
if (epoch % 100 == 0) and (epoch > 100):
|
||||
torch.save(model.state_dict(), './checkpoint/character_bert.pt')
|
||||
|
||||
|
||||
|
||||
torch.save(model.state_dict(), './checkpoint/character_bert_final.pt')
|
||||
# %%
|
|
@@ -0,0 +1,288 @@
|
|||
# standard functions for computing triplet loss; code borrowed from
|
||||
# https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
def _pairwise_distances(embeddings, squared=False):
|
||||
"""Compute the 2D matrix of distances between all the embeddings.
|
||||
Args:
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
pairwise_distances: tensor of shape (batch_size, batch_size)
|
||||
"""
|
||||
dot_product = torch.matmul(embeddings, embeddings.t())
|
||||
|
||||
# Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
|
||||
# This also provides more numerical stability (the diagonal of the result will be exactly 0).
|
||||
# shape (batch_size,)
|
||||
square_norm = torch.diag(dot_product)
|
||||
|
||||
# Compute the pairwise distance matrix as we have:
|
||||
# ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
|
||||
# shape (batch_size, batch_size)
|
||||
distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)
|
||||
|
||||
# Because of computation errors, some distances might be negative so we put everything >= 0.0
|
||||
distances[distances < 0] = 0
|
||||
|
||||
if not squared:
|
||||
# Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
|
||||
# we need to add a small epsilon where distances == 0.0
|
||||
mask = distances.eq(0).float()
|
||||
distances = distances + mask * 1e-16
|
||||
|
||||
distances = (1.0 - mask) * torch.sqrt(distances)
|
||||
|
||||
return distances
|
||||
|
||||
# def _pairwise_distances(embeddings, squared=False):
|
||||
# embeddings = F.normalize(embeddings, p=2, dim=1)
|
||||
# dot_product = torch.matmul(embeddings, embeddings.t())
|
||||
# cosine_distance = 1 - dot_product
|
||||
# return cosine_distance
|
||||
|
||||
|
||||
|
||||
def _get_triplet_mask(labels):
|
||||
"""Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
|
||||
A triplet (i, j, k) is valid if:
|
||||
- i, j, k are distinct
|
||||
- labels[i] == labels[j] and labels[i] != labels[k]
|
||||
Args:
|
||||
labels: tf.int32 `Tensor` with shape [batch_size]
|
||||
"""
|
||||
# Check that i, j and k are distinct
|
||||
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
|
||||
indices_not_equal = ~indices_equal
|
||||
i_not_equal_j = indices_not_equal.unsqueeze(2)
|
||||
i_not_equal_k = indices_not_equal.unsqueeze(1)
|
||||
j_not_equal_k = indices_not_equal.unsqueeze(0)
|
||||
|
||||
distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k
|
||||
|
||||
|
||||
label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
|
||||
i_equal_j = label_equal.unsqueeze(2)
|
||||
i_equal_k = label_equal.unsqueeze(1)
|
||||
|
||||
valid_labels = ~i_equal_k & i_equal_j
|
||||
|
||||
return valid_labels & distinct_indices
|
||||
|
||||
|
||||
def _get_anchor_positive_triplet_mask(labels):
|
||||
"""Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label.
|
||||
Args:
|
||||
labels: tf.int32 `Tensor` with shape [batch_size]
|
||||
Returns:
|
||||
mask: tf.bool `Tensor` with shape [batch_size, batch_size]
|
||||
"""
|
||||
# Check that i and j are distinct
|
||||
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
|
||||
indices_not_equal = ~indices_equal
|
||||
|
||||
# Check if labels[i] == labels[j]
|
||||
# Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
|
||||
labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
|
||||
|
||||
return labels_equal & indices_not_equal
|
||||
|
||||
|
||||
def _get_anchor_negative_triplet_mask(labels):
|
||||
"""Return a 2D mask where mask[a, n] is True iff a and n have distinct labels.
|
||||
Args:
|
||||
labels: tf.int32 `Tensor` with shape [batch_size]
|
||||
Returns:
|
||||
mask: tf.bool `Tensor` with shape [batch_size, batch_size]
|
||||
"""
|
||||
# Check if labels[i] != labels[k]
|
||||
# Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
|
||||
|
||||
return ~(labels.unsqueeze(0) == labels.unsqueeze(1))
|
||||
|
||||
|
||||
# Cell
|
||||
def batch_hard_triplet_loss(labels, embeddings, margin, squared=False):
|
||||
"""Build the triplet loss over a batch of embeddings.
|
||||
For each anchor, we get the hardest positive and hardest negative to form a triplet.
|
||||
Args:
|
||||
labels: labels of the batch, of size (batch_size,)
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
margin: margin for triplet loss
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
triplet_loss: scalar tensor containing the triplet loss
|
||||
"""
|
||||
# Get the pairwise distance matrix
|
||||
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
|
||||
|
||||
# For each anchor, get the hardest positive
|
||||
# First, we need to get a mask for every valid positive (they should have same label)
|
||||
mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()
|
||||
|
||||
# We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
|
||||
anchor_positive_dist = mask_anchor_positive * pairwise_dist
|
||||
|
||||
# shape (batch_size, 1)
|
||||
hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)
|
||||
|
||||
# For each anchor, get the hardest negative
|
||||
# First, we need to get a mask for every valid negative (they should have different labels)
|
||||
mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()
|
||||
|
||||
# We add the maximum value in each row to the invalid negatives (label(a) == label(n))
|
||||
max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
|
||||
anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
|
||||
|
||||
# shape (batch_size,)
|
||||
hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)
|
||||
|
||||
# Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
|
||||
tl = hardest_positive_dist - hardest_negative_dist + margin
|
||||
tl = F.relu(tl)
|
||||
triplet_loss = tl.mean()
|
||||
|
||||
return triplet_loss
|
||||
|
||||
# Cell
|
||||
def batch_all_triplet_loss(labels, embeddings, margin, squared=False):
|
||||
"""Build the triplet loss over a batch of embeddings.
|
||||
We generate all the valid triplets and average the loss over the positive ones.
|
||||
Args:
|
||||
labels: labels of the batch, of size (batch_size,)
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
margin: margin for triplet loss
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
triplet_loss: scalar tensor containing the triplet loss
|
||||
"""
|
||||
# Get the pairwise distance matrix
|
||||
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
|
||||
|
||||
anchor_positive_dist = pairwise_dist.unsqueeze(2)
|
||||
anchor_negative_dist = pairwise_dist.unsqueeze(1)
|
||||
|
||||
# Compute a 3D tensor of size (batch_size, batch_size, batch_size)
|
||||
# triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k
|
||||
# Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
|
||||
# and the 2nd (batch_size, 1, batch_size)
|
||||
triplet_loss = anchor_positive_dist - anchor_negative_dist + margin
|
||||
|
||||
|
||||
|
||||
# Put to zero the invalid triplets
|
||||
# (where label(a) != label(p) or label(n) == label(a) or a == p)
|
||||
mask = _get_triplet_mask(labels)
|
||||
triplet_loss = mask.float() * triplet_loss
|
||||
|
||||
# Remove negative losses (i.e. the easy triplets)
|
||||
triplet_loss = F.relu(triplet_loss)
|
||||
|
||||
# Count number of positive triplets (where triplet_loss > 0)
|
||||
valid_triplets = triplet_loss[triplet_loss > 1e-16]
|
||||
num_positive_triplets = valid_triplets.size(0)
|
||||
num_valid_triplets = mask.sum()
|
||||
|
||||
fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)
|
||||
|
||||
# Get final mean triplet loss over the positive valid triplets
|
||||
triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
|
||||
|
||||
return triplet_loss, fraction_positive_triplets
|
||||
|
||||
def batch_all_soft_margin_triplet_loss(labels, embeddings, squared=False):
|
||||
"""Build the triplet loss over a batch of embeddings.
|
||||
We generate all the valid triplets and average the loss over the positive ones.
|
||||
Args:
|
||||
labels: labels of the batch, of size (batch_size,)
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
triplet_loss: scalar tensor containing the triplet loss
|
||||
"""
|
||||
# Get the pairwise distance matrix
|
||||
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
|
||||
|
||||
anchor_positive_dist = pairwise_dist.unsqueeze(2)
|
||||
anchor_negative_dist = pairwise_dist.unsqueeze(1)
|
||||
|
||||
# Compute a 3D tensor of size (batch_size, batch_size, batch_size)
|
||||
# triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k
|
||||
# Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
|
||||
# and the 2nd (batch_size, 1, batch_size)
|
||||
triplet_loss = anchor_positive_dist - anchor_negative_dist
|
||||
|
||||
# Apply exponential and log
|
||||
triplet_loss = torch.log(1 + torch.exp(triplet_loss))
|
||||
|
||||
# Put to zero the invalid triplets
|
||||
# (where label(a) != label(p) or label(n) == label(a) or a == p)
|
||||
mask = _get_triplet_mask(labels)
|
||||
triplet_loss = mask.float() * triplet_loss
|
||||
|
||||
# Remove negative losses (i.e. the easy triplets)
|
||||
# triplet_loss = F.relu(triplet_loss)
|
||||
|
||||
# Count number of positive triplets (where triplet_loss > 0)
|
||||
valid_triplets = triplet_loss[triplet_loss > 1e-16]
|
||||
num_positive_triplets = valid_triplets.size(0)
|
||||
num_valid_triplets = mask.sum()
|
||||
|
||||
fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)
|
||||
|
||||
# Get final mean triplet loss over the positive valid triplets
|
||||
triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
|
||||
|
||||
return triplet_loss, fraction_positive_triplets
|
||||
|
||||
|
||||
|
||||
def batch_hard_soft_margin_triplet_loss(labels, embeddings, squared=False):
|
||||
"""Build the triplet loss over a batch of embeddings.
|
||||
For each anchor, we get the hardest positive and hardest negative to form a triplet.
|
||||
Args:
|
||||
labels: labels of the batch, of size (batch_size,)
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
triplet_loss: scalar tensor containing the triplet loss
|
||||
"""
|
||||
# Get the pairwise distance matrix
|
||||
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
|
||||
|
||||
# For each anchor, get the hardest positive
|
||||
# First, we need to get a mask for every valid positive (they should have same label)
|
||||
mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()
|
||||
|
||||
# We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
|
||||
anchor_positive_dist = mask_anchor_positive * pairwise_dist
|
||||
|
||||
# shape (batch_size, 1)
|
||||
hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)
|
||||
|
||||
# For each anchor, get the hardest negative
|
||||
# First, we need to get a mask for every valid negative (they should have different labels)
|
||||
mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()
|
||||
|
||||
# We add the maximum value in each row to the invalid negatives (label(a) == label(n))
|
||||
max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
|
||||
anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
|
||||
|
||||
# shape (batch_size,)
|
||||
hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)
|
||||
|
||||
# Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
|
||||
tl = hardest_positive_dist - hardest_negative_dist
|
||||
    # Apply the soft margin: softplus(x) = log(1 + exp(x))
|
||||
triplet_loss = torch.log(1 + torch.exp(tl))
|
||||
|
||||
triplet_loss = triplet_loss.mean()
|
||||
|
||||
return triplet_loss
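# Hedged usage sketch (illustration only): both soft-margin variants are drop-in
# replacements for the margin-based losses above, e.g.
#   loss, frac_positive = batch_all_soft_margin_triplet_loss(labels, cls_embeddings)
#   loss = batch_hard_soft_margin_triplet_loss(labels, cls_embeddings)
# where `labels` is a 1-D tensor of entity ids and `cls_embeddings` has shape
# (batch_size, embed_dim); no margin hyperparameter is needed.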
|
|
@@ -0,0 +1,574 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import BertTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from sklearn.metrics import accuracy_score
|
||||
from transformers import get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup
|
||||
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
def set_seed(seed):
|
||||
"""
|
||||
Set the random seed for reproducibility.
|
||||
"""
|
||||
random.seed(seed) # Python random module
|
||||
np.random.seed(seed) # NumPy random
|
||||
torch.manual_seed(seed) # PyTorch CPU
|
||||
torch.cuda.manual_seed(seed) # PyTorch GPU
|
||||
torch.cuda.manual_seed_all(seed) # If using multiple GPUs
|
||||
torch.backends.cudnn.deterministic = True # Ensure deterministic behavior
|
||||
torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility
|
||||
|
||||
set_seed(42)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
SHUFFLES=1
|
||||
AMPLIFY_FACTOR=1
|
||||
CORRUPT=0.1
|
||||
LEARNING_RATE=1e-5
|
||||
DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
# %%
|
||||
EVAL_FILE="top1_curves/character_output.txt"
|
||||
with open(EVAL_FILE, "w") as f:
|
||||
pass
|
||||
|
||||
EVAL_FILE_KNN="top1_curves/character_knn.txt"
|
||||
with open(EVAL_FILE_KNN, "w") as f:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
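# %%
# Hedged illustration (made-up entity id and mentions, not from the dataset):
# with anchor=True and group_size=2,
#   generate_train_entity_sets({42: ['mssql', 'sql server', 'ms sql']}, {42: 'Microsoft SQL Server'}, 2)
# yields groups such as (['Microsoft SQL Server', 'mssql', 'sql server'], 42), and
# batchGenerator then flattens each batch of groups into parallel lists
# x = [mention, ...] and y = [entity_id, ...] that feed the triplet loss.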
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than use pattern, we use the real thing and property
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
    # 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# # add shuffled strings
|
||||
# processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
# for desc in processed_descs:
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # add corrupted strings
|
||||
# desc = corrupt_string(parent_desc, corruption_probability=CORRUPT)
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # add example with stripped non-alphanumerics
|
||||
# desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # short sequence amplifier
|
||||
# # short sequences are rare, and we must compensate by including more examples
|
||||
# # also, short sequence don't usually get affected by shuffle
|
||||
# words = parent_desc.split()
|
||||
# word_count = len(words)
|
||||
# if word_count <= 2:
|
||||
# for _ in range(AMPLIFY_FACTOR):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
|
||||
|
||||
# def sample_from_df(df, sample_size_per_class=5):
|
||||
# sampled_df = (df.groupby("entity_id")[['entity_id', 'mention', 'entity_name']] # explicit give column names
|
||||
# .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x))))
|
||||
# .reset_index(drop=True))
|
||||
#
|
||||
# return sampled_df
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
        # each entity_id maps to a single entity_name, so taking the first value suffices
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
# %%
|
||||
# evaluation
|
||||
def run_evaluation_logit(model, tokenizer):
|
||||
def preprocess_text(text):
|
||||
        # 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
eval_entities = json.load(file)
|
||||
eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
eval_train = json.load(file)
|
||||
eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()}
|
||||
eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()}
|
||||
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
eval_test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in eval_test['data'].items()]
|
||||
eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys())
|
||||
eval_train_entities = [preprocess_text(element) for element in eval_train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
pred_labels = []
|
||||
for batch in batches:
|
||||
# Inference in batches
|
||||
inputs, attn_mask = tokenizer.encode(batch)
|
||||
inputs = inputs.to(DEVICE)
|
||||
attn_mask = attn_mask.to(DEVICE)
|
||||
with torch.no_grad():
|
||||
_, logits = model(inputs, attn_mask)
|
||||
predicted_class_ids = logits.argmax(dim=1).to("cpu")
|
||||
pred_labels.extend(predicted_class_ids)
|
||||
|
||||
|
||||
pred_labels = [tensor.item() for tensor in pred_labels]
|
||||
|
||||
# %%
|
||||
labels = [label2id[element] for element in y_test]
|
||||
with open(EVAL_FILE, "a") as f:
|
||||
# only compute top-1
|
||||
accuracy = accuracy_score(labels, pred_labels)
|
||||
print(f'{accuracy}', file=f)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def run_evaluation_knn(model, tokenizer):
|
||||
def preprocess_text(text):
|
||||
        # 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
eval_entities = json.load(file)
|
||||
eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
eval_train = json.load(file)
|
||||
eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()}
|
||||
eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()}
|
||||
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
eval_test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in eval_test['data'].items()]
|
||||
eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys())
|
||||
eval_train_entities = [preprocess_text(element) for element in eval_train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(eval_train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs, attn_mask = tokenizer.encode(batch)
|
||||
inputs = inputs.to(DEVICE)
|
||||
attn_mask = attn_mask.to(DEVICE)
|
||||
outputs = model(inputs, attn_mask)
|
||||
output_slice = outputs[:,0,:]
|
||||
output_slice = output_slice.detach().cpu().numpy()
|
||||
embedding_list.append(output_slice)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs, attn_mask = tokenizer.encode(batch)
|
||||
inputs = inputs.to(DEVICE)
|
||||
attn_mask = attn_mask.to(DEVICE)
|
||||
outputs = model(inputs, attn_mask)
|
||||
output_slice = outputs[:,0,:]
|
||||
output_slice = output_slice.detach().cpu().numpy()
|
||||
embedding_list.append(output_slice)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, eval_labels)
|
||||
|
||||
|
||||
with open(EVAL_FILE_KNN, "a") as f:
|
||||
# only compute top-1
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=1)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [eval_labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'{num / len(y_test)}', file=f)
|
||||
|
||||
|
||||
# %%
|
||||
class CharacterTransformer(nn.Module):
|
||||
def __init__(self, num_chars, d_model=128, nhead=4, num_encoder_layers=2):
|
||||
super(CharacterTransformer, self).__init__()
|
||||
self.char_embedding = nn.Embedding(num_chars, d_model)
|
||||
encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
|
||||
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)
|
||||
|
||||
def forward(self, input, attention_mask):
|
||||
# input: (batch_size, seq_len)
|
||||
embeddings = self.char_embedding(input) # (batch_size, seq_len, d_model)
|
||||
# embeddings = embeddings.permute(1, 0, 2) # (seq_len, batch_size, d_model)
|
||||
output = self.transformer_encoder(embeddings, src_key_padding_mask=attention_mask)
|
||||
# output = output.permute(1, 0, 2) # (batch_size, seq_len, d_model)
|
||||
return output
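# Hedged note (added comment): with batch_first=True the encoder output has shape
# (batch_size, seq_len, d_model); downstream code treats position 0 (the pad
# character prepended by ASCIITokenizer below) as a CLS-like summary embedding.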
|
||||
|
||||
# %%
|
||||
|
||||
class ASCIITokenizer:
|
||||
def __init__(self, pad_token='\0'):
|
||||
# Initialize the tokenizer with ASCII characters.
|
||||
# ASCII characters range from 0 to 127.
|
||||
self.char_to_id = {chr(i): i for i in range(128)}
|
||||
self.id_to_char = {i: chr(i) for i in range(128)}
|
||||
self.pad_token = pad_token
|
||||
|
||||
|
||||
def encode(self, text_list):
|
||||
"""Encode a text string into a list of ASCII IDs and generate attention masks."""
|
||||
output_list = []
|
||||
max_length = 0
|
||||
# First pass to find the maximum length and encode the texts
|
||||
for text in text_list:
|
||||
text = self.pad_token + text # Prepend pad_token to each text
|
||||
            output = [self.char_to_id.get(char, self.char_to_id[self.pad_token]) for char in text]  # fall back to the pad id for characters outside the ASCII range
|
||||
output_list.append(output)
|
||||
if len(output) > max_length:
|
||||
max_length = len(output)
|
||||
|
||||
# Second pass to pad the sequences to the maximum length and create masks
|
||||
padded_list = []
|
||||
attention_masks = []
|
||||
for output in output_list:
|
||||
# first element is not masked
|
||||
            attention_mask = [0] + [0] * (len(output) - 1) + [1] * (max_length - len(output))  # 0 (False) for real tokens, 1 (True) for padding, matching src_key_padding_mask semantics
|
||||
output = self.pad(output, max_length)
|
||||
padded_list.append(output)
|
||||
attention_masks.append(attention_mask)
|
||||
|
||||
return torch.tensor(padded_list, dtype=torch.long), torch.tensor(attention_masks, dtype=torch.bool)
|
||||
|
||||
def decode(self, ids_list):
|
||||
"""Decode a list of ASCII IDs back into a text string."""
|
||||
output_list = []
|
||||
for ids in ids_list:
|
||||
output = ''.join(self.id_to_char.get(id, '') for id in ids if id in self.id_to_char)
|
||||
output_list.append(output)
|
||||
return output_list
|
||||
|
||||
def pad(self, output, max_length):
|
||||
"""Pad the output list with ASCII ID for space or another padding character to the maximum length."""
|
||||
return output + [self.char_to_id.get(self.pad_token)] * (max_length - len(output))
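# Hedged example (illustration only): tokenizer.encode(['ab', 'a']) prepends the pad
# character and pads to the longest sequence, returning roughly
#   ids  -> [[0, 97, 98], [0, 97, 0]]
#   mask -> [[False, False, False], [False, False, True]]
# where True marks padding, matching nn.TransformerEncoder's src_key_padding_mask.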
|
||||
# %%
|
||||
tokenizer = ASCIITokenizer()
|
||||
# # Example text
|
||||
# text = ["Hello, world! This is cool", "Hello, world!"]
|
||||
# # Encode the text
|
||||
# encoded = tokenizer.encode(text)
|
||||
# print("Encoded:", encoded)
|
||||
#
|
||||
# # Decode the encoded IDs
|
||||
# decoded = tokenizer.decode(encoded.numpy())
|
||||
# print("Decoded:", decoded)
|
||||
|
||||
# %%
|
||||
# Example usage
|
||||
bert_model = CharacterTransformer(num_chars=128) # Assuming ASCII characters
|
||||
|
||||
class BertForClassificationAndTriplet(nn.Module):
|
||||
def __init__(self, bert_model, num_classes):
|
||||
super().__init__()
|
||||
self.bert = bert_model
|
||||
self.classifier = nn.Linear(bert_model.char_embedding.embedding_dim, num_classes)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None):
|
||||
outputs = self.bert(input_ids, attention_mask)
|
||||
        cls_embeddings = outputs[:, 0, :]  # position 0 (the prepended pad character) acts as a CLS-like token
|
||||
logits = self.classifier(cls_embeddings)
|
||||
return cls_embeddings, logits
|
||||
|
||||
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 32 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 5000
|
||||
|
||||
# model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
||||
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
||||
# tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# num_warmup_steps=100
|
||||
# total_steps = epochs * (1126/64)
|
||||
# scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps, total_steps, lr_end=5e-6)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.9, cooldown=5, verbose=True)
|
||||
|
||||
|
||||
# %%
|
||||
state_dict = torch.load('./checkpoint/pretrained_character_bert_final.pt')
|
||||
state_dict = {key.replace('_orig_mod.', ''): value for key, value in state_dict.items()}
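# Hedged note (added comment): the '_orig_mod.' prefix appears in state_dicts saved
# from a torch.compile-wrapped model; stripping it lets the checkpoint load into the
# plain, uncompiled module here.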
|
||||
model.load_state_dict(state_dict)
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
batch_number = 0
|
||||
|
||||
    if epoch % 1 == 0:  # always true: the data is re-augmented every epoch; raise the modulus to augment less often
|
||||
augmented_df = augment_data(df)
|
||||
# sampled_df = sample_from_df(augmented_df, sample_size_per_class=num_sample_per_class)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs, attn_mask = tokenizer.encode(x)
|
||||
inputs = inputs.to(DEVICE)
|
||||
attn_mask = attn_mask.to(DEVICE)
|
||||
cls, logits = model(inputs, attn_mask)
|
||||
# labels = y
|
||||
# labels = [label2id[element] for element in labels]
|
||||
# labels = torch.tensor(labels).to(DEVICE)
|
||||
|
||||
# loss = F.cross_entropy(logits, labels)
|
||||
|
||||
|
||||
        # hard-mined triplets are used throughout this run; the batch-all variant is left commented out below
|
||||
y = torch.tensor(y).to(DEVICE)
|
||||
# loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
|
||||
loss.backward()
|
||||
# scheduler.step()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
# del x, y, outputs, cls, loss
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
epoch_loss = total_loss/batch_number
|
||||
print(f'epoch loss: {epoch_loss}')
|
||||
if (epoch % 10 == 0):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# run_evaluation_logit(model=model, tokenizer=tokenizer)
|
||||
run_evaluation_knn(model=model.bert, tokenizer=tokenizer)
|
||||
# run evaluation on test data
|
||||
model.train()
|
||||
|
||||
|
||||
# print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
|
||||
if (epoch % 100 == 0) and (epoch > 100):
|
||||
torch.save(model.state_dict(), './checkpoint/pretrained_character_bert.pt')
|
||||
|
||||
|
||||
|
||||
torch.save(model.state_dict(), './checkpoint/pretrained_character_bert_final.pt')
|
||||
# %%
|
|
@@ -0,0 +1,4 @@
|
|||
__pycache__
|
||||
checkpoint
|
||||
results
|
||||
top1_curves
|
|
@@ -0,0 +1,131 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
import gc
|
||||
|
||||
# %%
|
||||
# Step 2: Load the state dictionary
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
# state_dict = torch.load('./checkpoint/siamese.pt')
|
||||
# state_dict = torch.load('./checkpoint/siamese_simple.pt')
|
||||
# state_dict = torch.load('./checkpoint/classification.pt')
|
||||
state_dict = torch.load('./checkpoint/baseline.pt')
|
||||
# params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name}
|
||||
|
||||
# %%
|
||||
# Step 3: Apply the state dictionary to the model
|
||||
model.load_state_dict(state_dict)
|
||||
model.to(DEVICE)
|
||||
model.eval()
|
||||
|
||||
# %%
|
||||
def preprocess_text(text):
|
||||
    # 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in test['data'].items()]
|
||||
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
|
||||
train_entities = [preprocess_text(element) for element in train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
# %%
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# %%
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
|
||||
# %%
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
|
||||
n_neighbors = [1, 3, 5, 10]
|
||||
|
||||
|
||||
with open("results/output.txt", "w") as f:
|
||||
for n in n_neighbors:
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
|
||||
print(np.min(distances), np.max(distances), file=f)
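# Hedged alternative (illustration only): the same top-n check can be vectorized, e.g.
#   pred_ids = np.array(labels)[indices]          # map neighbor indices to entity ids
#   top_n_acc = np.mean([a in row for a, row in zip(y_test, pred_ids)])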
|
||||
|
||||
# %%
|
||||
with open("results/predictions.txt", "w") as f:
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=1)
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [labels[i] for i in b]
|
||||
print(f'{a}, {b[0]}', file=f)
|
||||
|
|
@@ -0,0 +1,400 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
def set_seed(seed):
|
||||
"""
|
||||
Set the random seed for reproducibility.
|
||||
"""
|
||||
random.seed(seed) # Python random module
|
||||
np.random.seed(seed) # NumPy random
|
||||
torch.manual_seed(seed) # PyTorch CPU
|
||||
torch.cuda.manual_seed(seed) # PyTorch GPU
|
||||
torch.cuda.manual_seed_all(seed) # If using multiple GPUs
|
||||
torch.backends.cudnn.deterministic = True # Ensure deterministic behavior
|
||||
torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility
|
||||
|
||||
set_seed(42)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
SHUFFLES=1
|
||||
AMPLIFY_FACTOR=1
|
||||
CORRUPT=0.1
|
||||
LEARNING_RATE=1e-5
|
||||
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
# %%
|
||||
with open("top1_curves/baseline_output.txt", "w") as f:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than use pattern, we use the real thing and property
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
    # 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# # add shuffled strings
|
||||
# processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
# for desc in processed_descs:
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add corrupted strings
|
||||
desc = corrupt_string(parent_desc, corruption_probability=CORRUPT)
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add example with stripped non-alphanumerics
|
||||
desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # short sequence amplifier
|
||||
# # short sequences are rare, and we must compensate by including more examples
|
||||
# # also, short sequence don't usually get affected by shuffle
|
||||
# words = parent_desc.split()
|
||||
# word_count = len(words)
|
||||
# if word_count <= 2:
|
||||
# for _ in range(AMPLIFY_FACTOR):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
|
||||
|
||||
# def sample_from_df(df, sample_size_per_class=5):
|
||||
# sampled_df = (df.groupby("entity_id")[['entity_id', 'mention', 'entity_name']] # explicit give column names
|
||||
# .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x))))
|
||||
# .reset_index(drop=True))
|
||||
#
|
||||
# return sampled_df
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
        # each entity_id maps to a single entity_name, so taking the first value suffices
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
# %%
|
||||
# evaluation
|
||||
def run_evaluation(model, tokenizer):
|
||||
def preprocess_text(text):
|
||||
        # 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
eval_entities = json.load(file)
|
||||
eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
eval_train = json.load(file)
|
||||
eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()}
|
||||
eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()}
|
||||
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
eval_test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in eval_test['data'].items()]
|
||||
eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys())
|
||||
eval_train_entities = [preprocess_text(element) for element in eval_train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(eval_train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='euclidean').fit(cls, eval_labels)
|
||||
|
||||
|
||||
with open("top1_curves/baseline_output.txt", "a") as f:
|
||||
# only compute top-1
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=1)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [eval_labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'{num / len(y_test)}', file=f)
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 64 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.9, cooldown=5, verbose=True)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
batch_number = 0
|
||||
|
||||
    if epoch % 1 == 0:  # always true: the data is re-augmented every epoch; raise the modulus to augment less often
|
||||
augmented_df = augment_data(df)
|
||||
# sampled_df = sample_from_df(augmented_df, sample_size_per_class=num_sample_per_class)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = tokenizer(x, padding=True, return_tensors='pt')
|
||||
inputs.to(DEVICE)
|
||||
outputs = model(**inputs)
|
||||
cls = outputs.last_hidden_state[:,0,:]
|
||||
# for training less than half the time, train on easy
|
||||
y = torch.tensor(y).to(DEVICE)
|
||||
if epoch < epochs / 2:
|
||||
loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
# for training after half the time, train on hard
|
||||
else:
|
||||
loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
# del x, y, outputs, cls, loss
|
||||
# torch.cuda.empty_cache()
|
||||
epoch_loss = total_loss/batch_number
|
||||
# scheduler.step(epoch_loss)
|
||||
|
||||
# run evaluation on test data
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
run_evaluation(model=model, tokenizer=tokenizer)
|
||||
|
||||
model.train()
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
print(f'epoch loss: {epoch_loss}')
|
||||
# print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
|
||||
if epoch == 125:
|
||||
torch.save(model.state_dict(), './checkpoint/baseline.pt')
|
||||
|
||||
|
||||
# torch.save(model.state_dict(), './checkpoint/baseline.pt')
|
||||
# %%
|
|
@@ -0,0 +1,424 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import (
|
||||
batch_all_triplet_loss,
|
||||
batch_hard_triplet_loss,
|
||||
batch_all_soft_margin_triplet_loss,
|
||||
batch_hard_soft_margin_triplet_loss)
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
def set_seed(seed):
|
||||
"""
|
||||
Set the random seed for reproducibility.
|
||||
"""
|
||||
random.seed(seed) # Python random module
|
||||
np.random.seed(seed) # NumPy random
|
||||
torch.manual_seed(seed) # PyTorch CPU
|
||||
torch.cuda.manual_seed(seed) # PyTorch GPU
|
||||
torch.cuda.manual_seed_all(seed) # If using multiple GPUs
|
||||
torch.backends.cudnn.deterministic = True # Ensure deterministic behavior
|
||||
torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility
|
||||
|
||||
set_seed(42)
|
||||
|
||||
|
||||
# %%
|
||||
SHUFFLES=1
|
||||
AMPLIFY_FACTOR=1
|
||||
LEARNING_RATE=1e-5
|
||||
DEVICE = torch.device('cuda:3') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
|
||||
# %%
|
||||
EVAL_FILE="top1_curves/batch_all_output.txt"
|
||||
with open(EVAL_FILE, "w") as f:
|
||||
pass
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than use pattern, we use the real thing and property
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
    # 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# add shuffled strings
|
||||
processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
for desc in processed_descs:
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add corrupted strings
|
||||
desc = corrupt_string(parent_desc, corruption_probability=0.01)
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add example with stripped non-alphanumerics
|
||||
desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# short sequence amplifier
|
||||
# short sequences are rare, and we must compensate by including more examples
|
||||
        # also, short sequences are not usually affected by shuffling
|
||||
words = parent_desc.split()
|
||||
word_count = len(words)
|
||||
if word_count <= 2:
|
||||
for _ in range(AMPLIFY_FACTOR):
|
||||
                output_list.append(create_example(index, desc, entity_name))  # note: `desc` here is the last augmented variant from above, not the original parent_desc
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
        # each entity_id maps to a single entity_name, so taking the first value suffices
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
|
||||
# evaluation
|
||||
def run_evaluation(model, tokenizer):
|
||||
def preprocess_text(text):
|
||||
        # 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in test['data'].items()]
|
||||
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
|
||||
train_entities = [preprocess_text(element) for element in train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='euclidean').fit(cls, labels)
|
||||
|
||||
|
||||
with open(EVAL_FILE, "a") as f:
|
||||
# only compute top-1
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=1)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'{num / len(y_test)}', file=f)
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
bert_model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
class BertForClassificationAndTriplet(nn.Module):
|
||||
def __init__(self, bert_model, num_classes):
|
||||
super().__init__()
|
||||
self.bert = bert_model
|
||||
self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None):
|
||||
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
||||
cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token
|
||||
logits = self.classifier(cls_embeddings)
|
||||
return cls_embeddings, logits
|
||||
|
||||
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
def linear_decay(epoch, max_epochs, initial_lr, final_lr):
|
||||
""" Calculate the linearly decayed learning rate. """
|
||||
return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr)
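# Hedged usage sketch: the decay is applied only in the commented-out block inside
# the training loop below, along the lines of
#   lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6)
#   for param_group in optimizer.param_groups:
#       param_group['lr'] = lr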
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
total_cross = 0.0
|
||||
total_triplet = 0.0
|
||||
batch_number = 0
|
||||
|
||||
# lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6)
|
||||
|
||||
# # Update optimizer's learning rate
|
||||
# for param_group in optimizer.param_groups:
|
||||
# param_group['lr'] = lr
|
||||
|
||||
augmented_df = augment_data(df)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = tokenizer(x, padding=True, return_tensors='pt')
|
||||
inputs.to(DEVICE)
|
||||
cls, logits = model(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask']
|
||||
)
|
||||
# easy/hard curriculum is disabled below; batch-all soft-margin triplets are used throughout
|
||||
y = torch.tensor(y).to(DEVICE)
|
||||
|
||||
|
||||
|
||||
# if epoch < epochs / 2:
|
||||
loss, _ = batch_all_soft_margin_triplet_loss(y, cls, squared=False)
|
||||
# second half (disabled): hard (batch-hard) triplets
|
||||
# else:
|
||||
# triplet_loss = batch_hard_soft_margin_triplet_loss(y, cls, squared=False)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
# total_cross += class_loss.detach().item()
|
||||
# total_triplet += triplet_loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
# run evaluation on test data
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
run_evaluation(model=model.bert, tokenizer=tokenizer)
|
||||
|
||||
model.train()
|
||||
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
# print(f'epoch loss: {total_loss/batch_number}, cross loss: {total_cross/batch_number}, triplet loss: {total_triplet/batch_number}')
|
||||
print(f'epoch loss: {total_loss/batch_number}')
|
||||
# print(f"Epoch {epoch+1}: lr={lr}")
|
||||
if epoch % 5 == 0:
|
||||
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
torch.save(model.state_dict(), './checkpoint/batch_all.pt')
|
||||
|
||||
|
||||
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
torch.save(model.state_dict(), './checkpoint/batch_all.pt')
|
||||
# %%
|
|
@ -0,0 +1,561 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import BertTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from sklearn.metrics import accuracy_score
|
||||
from transformers import get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup
|
||||
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
def set_seed(seed):
|
||||
"""
|
||||
Set the random seed for reproducibility.
|
||||
"""
|
||||
random.seed(seed) # Python random module
|
||||
np.random.seed(seed) # NumPy random
|
||||
torch.manual_seed(seed) # PyTorch CPU
|
||||
torch.cuda.manual_seed(seed) # PyTorch GPU
|
||||
torch.cuda.manual_seed_all(seed) # If using multiple GPUs
|
||||
torch.backends.cudnn.deterministic = True # Ensure deterministic behavior
|
||||
torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility
|
||||
|
||||
set_seed(42)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
SHUFFLES=1
|
||||
AMPLIFY_FACTOR=1
|
||||
CORRUPT=0.1
|
||||
LEARNING_RATE=1e-5
|
||||
DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
# %%
|
||||
EVAL_FILE="top1_curves/character_output.txt"
|
||||
with open(EVAL_FILE, "w") as f:
|
||||
pass
|
||||
|
||||
EVAL_FILE_KNN="top1_curves/character_knn.txt"
|
||||
with open(EVAL_FILE_KNN, "w") as f:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
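# %%
# Minimal sketch (toy, hypothetical data) of how the two helpers above compose:
# mentions are chunked per entity, each chunk is prefixed with the canonical
# entity name as an anchor, and batchGenerator flattens chunks into parallel
# lists of mention strings and entity ids. Left commented so it does not touch
# the RNG state of the real run.
# _toy_mentions = {1: ['postgres', 'pgsql', 'postgre sql'], 2: ['mssql', 'sql server']}
# _toy_names = {1: 'PostgreSQL', 2: 'Microsoft SQL Server'}
# _toy_sets = generate_train_entity_sets(_toy_mentions, _toy_names, group_size=2, anchor=True)
# for _x, _y in batchGenerator(_toy_sets, batch_size=2):
#     print(len(_x), len(_y))   # each flat batch pairs mentions with their entity ids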
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# build the label set directly from the actual entity_id values in the training file
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
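# %%
# Quick illustration of the corruption augmentation (commented out so it does not
# consume RNG state; output varies with the seed). With a high probability most
# words get one character deleted or two adjacent characters swapped,
# e.g. 'websphere' -> 'websphre' or 'webpshere'.
# print(corrupt_string('websphere application server', corruption_probability=0.9))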
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# # add shuffled strings
|
||||
# processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
# for desc in processed_descs:
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add corrupted strings
|
||||
desc = corrupt_string(parent_desc, corruption_probability=CORRUPT)
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add example with stripped non-alphanumerics
|
||||
desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # short sequence amplifier
|
||||
# # short sequences are rare, and we must compensate by including more examples
|
||||
# # also, short sequences aren't usually affected by shuffling
|
||||
# words = parent_desc.split()
|
||||
# word_count = len(words)
|
||||
# if word_count <= 2:
|
||||
# for _ in range(AMPLIFY_FACTOR):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
|
||||
|
||||
# def sample_from_df(df, sample_size_per_class=5):
|
||||
# sampled_df = (df.groupby("entity_id")[['entity_id', 'mention', 'entity_name']] # explicit give column names
|
||||
# .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x))))
|
||||
# .reset_index(drop=True))
|
||||
#
|
||||
# return sampled_df
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# each entity_id maps to a single entity_name, so the first value suffices
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
# %%
|
||||
# evaluation
|
||||
def run_evaluation_logit(model, tokenizer):
|
||||
def preprocess_text(text):
|
||||
# 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
eval_entities = json.load(file)
|
||||
eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
eval_train = json.load(file)
|
||||
eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()}
|
||||
eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()}
|
||||
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
eval_test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in eval_test['data'].items()]
|
||||
eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys())
|
||||
eval_train_entities = [preprocess_text(element) for element in eval_train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
pred_labels = []
|
||||
for batch in batches:
|
||||
# Inference in batches
|
||||
inputs = tokenizer.encode(batch)
|
||||
inputs = inputs.to(DEVICE)
|
||||
with torch.no_grad():
|
||||
_, logits = model(inputs)
|
||||
predicted_class_ids = logits.argmax(dim=1).to("cpu")
|
||||
pred_labels.extend(predicted_class_ids)
|
||||
|
||||
|
||||
pred_labels = [tensor.item() for tensor in pred_labels]
|
||||
|
||||
# %%
|
||||
labels = [label2id[element] for element in y_test]
|
||||
with open(EVAL_FILE, "a") as f:
|
||||
# only compute top-1
|
||||
accuracy = accuracy_score(labels, pred_labels)
|
||||
print(f'{accuracy}', file=f)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def run_evaluation_knn(model, tokenizer):
|
||||
def preprocess_text(text):
|
||||
# 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
eval_entities = json.load(file)
|
||||
eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
eval_train = json.load(file)
|
||||
eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()}
|
||||
eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()}
|
||||
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
eval_test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in eval_test['data'].items()]
|
||||
eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys())
|
||||
eval_train_entities = [preprocess_text(element) for element in eval_train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(eval_train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer.encode(batch)
|
||||
inputs = inputs.to(DEVICE)
|
||||
outputs = model(inputs)
|
||||
output_slice = outputs[:,0,:]
|
||||
output_slice = output_slice.detach().cpu().numpy()
|
||||
embedding_list.append(output_slice)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer.encode(batch)
|
||||
inputs = inputs.to(DEVICE)
|
||||
outputs = model(inputs)
|
||||
output_slice = outputs[:,0,:]
|
||||
output_slice = output_slice.detach().cpu().numpy()
|
||||
embedding_list.append(output_slice)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, eval_labels)
|
||||
|
||||
|
||||
with open(EVAL_FILE_KNN, "a") as f:
|
||||
# only compute top-1
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=1)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [eval_labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'{num / len(y_test)}', file=f)
|
||||
|
||||
|
||||
# %%
|
||||
class CharacterTransformer(nn.Module):
|
||||
def __init__(self, num_chars, d_model=512, nhead=8, num_encoder_layers=6):
|
||||
super(CharacterTransformer, self).__init__()
|
||||
self.char_embedding = nn.Embedding(num_chars, d_model)
|
||||
encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
|
||||
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)
|
||||
|
||||
def forward(self, input):
|
||||
# input: (batch_size, seq_len)
|
||||
embeddings = self.char_embedding(input) # (batch_size, seq_len, d_model)
|
||||
# embeddings = embeddings.permute(1, 0, 2) # (seq_len, batch_size, d_model)
|
||||
output = self.transformer_encoder(embeddings)
|
||||
# output = output.permute(1, 0, 2) # (batch_size, seq_len, d_model)
|
||||
return output
|
||||
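# %%
# Minimal shape check (illustrative only, commented out to avoid building an
# extra model in the training run): random ASCII ids in, per-character hidden
# states of size d_model out.
# _enc = CharacterTransformer(num_chars=128, d_model=64, nhead=4, num_encoder_layers=2)
# _ids = torch.randint(0, 128, (2, 16))        # (batch_size, seq_len)
# assert _enc(_ids).shape == (2, 16, 64)       # (batch_size, seq_len, d_model)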
|
||||
class ASCIITokenizer:
|
||||
def __init__(self, pad_token='\0'):
|
||||
# Initialize the tokenizer with ASCII characters.
|
||||
# ASCII characters range from 0 to 127.
|
||||
self.char_to_id = {chr(i): i for i in range(128)}
|
||||
self.id_to_char = {i: chr(i) for i in range(128)}
|
||||
self.pad_token = pad_token
|
||||
|
||||
def encode(self, text_list):
|
||||
"""Encode a text string into a list of ASCII IDs."""
|
||||
output_list = []
|
||||
max_length = 0
|
||||
for text in text_list:
|
||||
text = self.pad_token + text
|
||||
output = [self.char_to_id[char] for char in text if char in self.char_to_id]
|
||||
output_list.append(output)
|
||||
if len(output) > max_length:
|
||||
max_length = len(output)
|
||||
padded_list = [self.pad(output, max_length) for output in output_list]
|
||||
# Convert the list of lists into a tensor
|
||||
return torch.tensor(padded_list, dtype=torch.long)
|
||||
|
||||
def decode(self, ids_list):
|
||||
"""Decode a list of ASCII IDs back into a text string."""
|
||||
output_list = []
|
||||
for ids in ids_list:
|
||||
output = ''.join(self.id_to_char.get(id, '') for id in ids if id in self.id_to_char)
|
||||
output_list.append(output)
|
||||
return output_list
|
||||
|
||||
def pad(self, output, max_length):
|
||||
"""Pad the output list with ASCII ID for space or another padding character to the maximum length."""
|
||||
return output + [self.char_to_id.get(self.pad_token)] * (max_length - len(output))
|
||||
# %%
|
||||
tokenizer = ASCIITokenizer()
|
||||
# # Example text
|
||||
# text = ["Hello, world! This is cool", "Hello, world!"]
|
||||
# # Encode the text
|
||||
# encoded = tokenizer.encode(text)
|
||||
# print("Encoded:", encoded)
|
||||
#
|
||||
# # Decode the encoded IDs
|
||||
# decoded = tokenizer.decode(encoded.numpy())
|
||||
# print("Decoded:", decoded)
|
||||
|
||||
# %%
|
||||
# Example usage
|
||||
bert_model = CharacterTransformer(num_chars=128) # Assuming ASCII characters
|
||||
|
||||
class BertForClassificationAndTriplet(nn.Module):
|
||||
def __init__(self, bert_model, num_classes):
|
||||
super().__init__()
|
||||
self.bert = bert_model
|
||||
self.classifier = nn.Linear(bert_model.char_embedding.embedding_dim, num_classes)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None):
|
||||
outputs = self.bert(input_ids)
|
||||
cls_embeddings = outputs[:, 0, :] # first token; the prepended pad char acts as a [CLS]-style summary token
|
||||
logits = self.classifier(cls_embeddings)
|
||||
return cls_embeddings, logits
|
||||
|
||||
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 64 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 5000
|
||||
|
||||
# model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
||||
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
||||
# tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
num_warmup_steps=100
|
||||
total_steps = int(epochs * (1126 / 64))  # rough estimate of total optimiser steps across all epochs
|
||||
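# warm up for num_warmup_steps, then decay the learning rate polynomially
# (linear by default) from LEARNING_RATE down to lr_end over total_steps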
scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps, total_steps, lr_end=5e-6)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.9, cooldown=5, verbose=True)
|
||||
|
||||
|
||||
# %%
|
||||
# state_dict = torch.load('./checkpoint/character_bert.pt')
|
||||
# state_dict = {key.replace('_orig_mod.', ''): value for key, value in state_dict.items()}
|
||||
# model.load_state_dict(state_dict)
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
batch_number = 0
|
||||
|
||||
if epoch % 1 == 0:  # always true: re-augment every epoch (modulus kept to make the interval easy to tune)
|
||||
augmented_df = augment_data(df)
|
||||
# sampled_df = sample_from_df(augmented_df, sample_size_per_class=num_sample_per_class)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = tokenizer.encode(x)
|
||||
inputs = inputs.to(DEVICE)
|
||||
cls, logits = model(inputs)
|
||||
labels = y
|
||||
labels = [label2id[element] for element in labels]
|
||||
labels = torch.tensor(labels).to(DEVICE)
|
||||
|
||||
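# the character-level model is trained with plain cross-entropy on the classifier
# head; the triplet-loss curriculum below is left commented out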
loss = F.cross_entropy(logits, labels)
|
||||
|
||||
|
||||
# triplet curriculum (disabled below): first half on easy (batch-all) triplets
|
||||
# y = torch.tensor(y).to(DEVICE)
|
||||
# if epoch < epochs / 2:
|
||||
# loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
# # second half (disabled): hard (batch-hard) triplets
|
||||
# else:
|
||||
# loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()  # step the LR schedule after the optimiser update
|
||||
total_loss += loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
# del x, y, outputs, cls, loss
|
||||
# torch.cuda.empty_cache()
|
||||
epoch_loss = total_loss/batch_number
|
||||
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
print(f'epoch loss: {epoch_loss}')
|
||||
if (epoch % 10 == 0):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
run_evaluation_logit(model=model, tokenizer=tokenizer)
|
||||
run_evaluation_knn(model=model.bert, tokenizer=tokenizer)
|
||||
# run evaluation on test data
|
||||
model.train()
|
||||
|
||||
|
||||
# print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
|
||||
if (epoch % 100 == 0) and (epoch > 100):
|
||||
torch.save(model.state_dict(), './checkpoint/character_bert.pt')
|
||||
|
||||
|
||||
|
||||
torch.save(model.state_dict(), './checkpoint/character_bert_final.pt')
|
||||
# %%
|
|
@ -0,0 +1,124 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
import gc
|
||||
|
||||
# %%
|
||||
# Step 2: Load the state dictionary
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
# state_dict = torch.load('./checkpoint/siamese.pt')
|
||||
# state_dict = torch.load('./checkpoint/siamese_simple.pt')
|
||||
state_dict = torch.load('./checkpoint/classification.pt')
|
||||
params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name}
|
||||
|
||||
# %%
|
||||
# Step 3: Apply the state dictionary to the model
|
||||
model.load_state_dict(params_dict)
|
||||
model.to(DEVICE)
|
||||
model.eval()
|
||||
|
||||
# %%
|
||||
def preprocess_text(text):
|
||||
# 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in test['data'].items()]
|
||||
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
|
||||
train_entities = [preprocess_text(element) for element in train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
# %%
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# %%
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
|
||||
# %%
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
|
||||
n_neighbors = [1, 3, 5, 10]
|
||||
|
||||
|
||||
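# kneighbors returns positions into the training embeddings, so labels[i]
# maps each neighbour back to its entity id; counting hits gives top-n accuracy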
with open("results/output.txt", "w") as f:
|
||||
for n in n_neighbors:
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
|
||||
print(np.min(distances), np.max(distances), file=f)
|
||||
|
||||
# %%
|
|
@ -0,0 +1,258 @@
|
|||
# %%
|
||||
|
||||
# from datasets import load_from_disk
|
||||
import os
|
||||
import glob
|
||||
|
||||
os.environ['NCCL_P2P_DISABLE'] = '1'
|
||||
os.environ['NCCL_IB_DISABLE'] = '1'
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
|
||||
|
||||
import re
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModel,
|
||||
DataCollatorWithPadding,
|
||||
)
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
# import matplotlib.pyplot as plt
|
||||
from datasets import Dataset, DatasetDict
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
|
||||
BATCH_SIZE = 256
|
||||
|
||||
# %%
|
||||
# construct the target id list
|
||||
# data_path = '../../../esAppMod_data_import/train.csv'
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
train_df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# build the label set directly from the actual entity_id values in the training file
|
||||
entity_ids = train_df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
|
||||
# %%
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
|
||||
# introduce pre-processing functions
|
||||
def preprocess_text(text):
|
||||
# 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# Substitute digits with '#'
|
||||
# text = re.sub(r'\d+', '#', text)
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
|
||||
|
||||
# outputs a list of dictionaries
|
||||
# processes dataframe into lists of dictionaries
|
||||
# each element maps input to output
|
||||
# input: mention text
|
||||
# output: class label
|
||||
def process_df_to_dict(df):
|
||||
output_list = []
|
||||
for _, row in df.iterrows():
|
||||
desc = row['mention']
|
||||
desc = preprocess_text(desc)
|
||||
index = row['entity_id']
|
||||
element = {
|
||||
'text' : desc,
|
||||
'label': label2id[index], # ensure labels starts from 0
|
||||
}
|
||||
output_list.append(element)
|
||||
|
||||
return output_list
|
||||
|
||||
|
||||
def create_dataset():
|
||||
# test split
|
||||
data_path = '../esAppMod_data_import/test.csv'
|
||||
test_df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
|
||||
|
||||
# combined_data = DatasetDict({
|
||||
# 'train': Dataset.from_list(process_df_to_dict(train_df)),
|
||||
# })
|
||||
return Dataset.from_list(process_df_to_dict(test_df))
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
def test():
|
||||
|
||||
test_dataset = create_dataset()
|
||||
|
||||
# prepare tokenizer
|
||||
|
||||
# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
# MODEL_NAME = 'distilbert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small'
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, return_tensors="pt", clean_up_tokenization_spaces=True)
|
||||
# Define additional special tokens
|
||||
# additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
|
||||
# Add the additional special tokens to the tokenizer
|
||||
# tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
|
||||
|
||||
# %%
|
||||
# compute max token length
|
||||
max_length = 0
|
||||
for sample in test_dataset['text']:
|
||||
# Tokenize the sample and get the length
|
||||
input_ids = tokenizer(sample, truncation=False, add_special_tokens=True)["input_ids"]
|
||||
length = len(input_ids)
|
||||
|
||||
# Update max_length if this sample is longer
|
||||
if length > max_length:
|
||||
max_length = length
|
||||
|
||||
print(max_length)
|
||||
|
||||
# %%
|
||||
|
||||
max_length = 128
|
||||
|
||||
# given a dataset entry, run it through the tokenizer
|
||||
def preprocess_function(example):
|
||||
input = example['text']
|
||||
# tokenize only the mention text; the integer class label is already stored
|
||||
# in the 'label' column, so no separate target encoding is needed
|
||||
model_inputs = tokenizer(
|
||||
input,
|
||||
# max_length=max_length,
|
||||
# padding='max_length'
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
# map maps function to each "row" in the dataset
|
||||
# aka the data in the immediate nesting
|
||||
datasets = test_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=8,
|
||||
remove_columns="text",
|
||||
)
|
||||
|
||||
|
||||
datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
|
||||
|
||||
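# DataCollatorWithPadding pads each batch dynamically to its longest sequence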
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
bert_model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
class BertForClassificationAndTriplet(nn.Module):
|
||||
def __init__(self, bert_model, num_classes):
|
||||
super().__init__()
|
||||
self.bert = bert_model
|
||||
self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None):
|
||||
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
||||
cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token
|
||||
logits = self.classifier(cls_embeddings)
|
||||
return cls_embeddings, logits
|
||||
|
||||
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
|
||||
state_dict = torch.load('./checkpoint/classification.pt')
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
model = model.eval()
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
model.to(device)
|
||||
|
||||
pred_labels = []
|
||||
actual_labels = []
|
||||
|
||||
|
||||
dataloader = DataLoader(datasets, batch_size=BATCH_SIZE, shuffle=False, collate_fn=data_collator)
|
||||
for batch in tqdm(dataloader):
|
||||
# Inference in batches
|
||||
input_ids = batch['input_ids']
|
||||
attention_mask = batch['attention_mask']
|
||||
# save labels too
|
||||
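# note: DataCollatorWithPadding renames the dataset's 'label' column to 'labels'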
actual_labels.extend(batch['labels'])
|
||||
|
||||
|
||||
# Move to GPU if available
|
||||
input_ids = input_ids.to(device)
|
||||
attention_mask = attention_mask.to(device)
|
||||
|
||||
# Perform inference
|
||||
with torch.no_grad():
|
||||
cls, logits = model(
|
||||
input_ids,
|
||||
attention_mask)
|
||||
predicted_class_ids = logits.argmax(dim=1).to("cpu")
|
||||
pred_labels.extend(predicted_class_ids)
|
||||
|
||||
pred_labels = [tensor.item() for tensor in pred_labels]
|
||||
|
||||
|
||||
# %%
|
||||
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
|
||||
y_true = actual_labels
|
||||
y_pred = pred_labels
|
||||
|
||||
# Compute metrics
|
||||
accuracy = accuracy_score(y_true, y_pred)
|
||||
average_parameter = 'weighted'
|
||||
zero_division_parameter = 0
|
||||
f1 = f1_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter)
|
||||
precision = precision_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter)
|
||||
recall = recall_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter)
|
||||
|
||||
with open("results/output.txt", "a") as f:
|
||||
|
||||
print('*' * 80, file=f)
|
||||
# Print the results
|
||||
print(f'Accuracy: {accuracy:.5f}', file=f)
|
||||
print(f'F1 Score: {f1:.5f}', file=f)
|
||||
print(f'Precision: {precision:.5f}', file=f)
|
||||
print(f'Recall: {recall:.5f}', file=f)
|
||||
|
||||
# export result
|
||||
label_list = [id2label[id] for id in pred_labels]
|
||||
df = pd.DataFrame({
|
||||
'class_prediction': pd.Series(label_list)
|
||||
})
|
||||
|
||||
# save the predicted entity ids here
|
||||
df.to_csv(f"results/classify.csv", index=False)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
# reset file before writing to it
|
||||
with open("results/output.txt", "w") as f:
|
||||
print('', file=f)
|
||||
test()
|
|
@ -0,0 +1,315 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# %%
|
||||
SHUFFLES=1
|
||||
AMPLIFY_FACTOR=1
|
||||
LEARNING_RATE=1e-5
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# build the label set directly from the actual entity_id values in the training file
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# add shuffled strings
|
||||
processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
for desc in processed_descs:
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add corrupted strings
|
||||
desc = corrupt_string(parent_desc, corruption_probability=0.01)
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add example with stripped non-alphanumerics
|
||||
desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# short sequence amplifier
|
||||
# short sequences are rare, and we must compensate by including more examples
|
||||
# also, short sequences aren't usually affected by shuffling
|
||||
words = parent_desc.split()
|
||||
word_count = len(words)
|
||||
if word_count <= 2:
|
||||
for _ in range(AMPLIFY_FACTOR):
|
||||
output_list.append(create_example(index, parent_desc, entity_name)) # amplify the original preprocessed mention
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# each entity_id maps to a single entity_name, so the first value suffices
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
bert_model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
class BertForClassificationAndTriplet(nn.Module):
|
||||
def __init__(self, bert_model, num_classes):
|
||||
super().__init__()
|
||||
self.bert = bert_model
|
||||
self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None):
|
||||
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
||||
cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token
|
||||
logits = self.classifier(cls_embeddings)
|
||||
return cls_embeddings, logits
|
||||
|
||||
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
def linear_decay(epoch, max_epochs, initial_lr, final_lr):
|
||||
""" Calculate the linearly decayed learning rate. """
|
||||
return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr)
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
batch_number = 0
|
||||
|
||||
# lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6)
|
||||
|
||||
# # Update optimizer's learning rate
|
||||
# for param_group in optimizer.param_groups:
|
||||
# param_group['lr'] = lr
|
||||
|
||||
augmented_df = augment_data(df)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = tokenizer(x, padding=True, return_tensors='pt')
|
||||
inputs.to(DEVICE)
|
||||
cls, logits = model(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask']
|
||||
)
|
||||
# triplet curriculum is disabled below; only the classification loss is used
|
||||
labels = y
|
||||
labels = [label2id[element] for element in labels]
|
||||
labels = torch.tensor(labels).to(DEVICE)
|
||||
# y = torch.tensor(y).to(DEVICE)
|
||||
|
||||
|
||||
class_loss = F.cross_entropy(logits, labels)
|
||||
|
||||
# if epoch < epochs / 2:
|
||||
# triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
# # for training after half the time, train on hard
|
||||
# else:
|
||||
# triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
|
||||
loss = class_loss # + triplet_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
del x, y, cls, logits, loss
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
print(f'epoch loss: {total_loss/batch_number}')
|
||||
# print(f"Epoch {epoch+1}: lr={lr}")
|
||||
if epoch % 5 == 0:
|
||||
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
torch.save(model.state_dict(), './checkpoint/classification.pt')
|
||||
|
||||
|
||||
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
torch.save(model.state_dict(), './checkpoint/classification.pt')
|
||||
# %%
|
|
@ -0,0 +1,277 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
|
||||
|
||||
# %%
|
||||
SHUFFLES=0
|
||||
AMPLIFY_FACTOR=0
|
||||
LEARNING_RATE=1e-5
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# build the label set directly from the actual entity_id values in the training file
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# all augmentations disabled
|
||||
# # add shuffled strings
|
||||
# processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
# for desc in processed_descs:
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # add corrupted strings
|
||||
# desc = corrupt_string(parent_desc, corruption_probability=0.01)
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # add example with stripped non-alphanumerics
|
||||
# desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # short sequence amplifier
|
||||
# # short sequences are rare, and we must compensate by including more examples
|
||||
# # also, short sequences aren't usually affected by shuffling
|
||||
# words = parent_desc.split()
|
||||
# word_count = len(words)
|
||||
# if word_count <= 2:
|
||||
# for _ in range(AMPLIFY_FACTOR):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# each entity_id maps to a single entity_name, so the first value suffices
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
batch_number = 0
|
||||
|
||||
augmented_df = augment_data(df)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = tokenizer(x, padding=True, return_tensors='pt')
|
||||
inputs.to(DEVICE)
|
||||
outputs = model(**inputs)
|
||||
cls = outputs.last_hidden_state[:,0,:]
|
||||
# for the first half of training, use the batch-all (easy) triplet loss
|
||||
y = torch.tensor(y).to(DEVICE)
|
||||
if epoch < epochs / 2:
|
||||
loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
# for the second half of training, use the batch-hard triplet loss
|
||||
else:
|
||||
loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
del x, y, outputs, cls, loss
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
print(f'epoch loss: {total_loss/batch_number}')
|
||||
# print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
|
||||
if epoch % 5 == 0:
|
||||
torch.save(model.state_dict(), './checkpoint/baseline.pt')
|
||||
|
||||
|
||||
torch.save(model.state_dict(), './checkpoint/baseline.pt')
|
||||
# %%
|
|
@ -0,0 +1,315 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# %%
|
||||
SHUFFLES=0
|
||||
AMPLIFY_FACTOR=2
|
||||
LEARNING_RATE=1e-5
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(group[i:i + group_size], id) for i in range(0, len(group), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
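# %%
# Hedged toy walk-through of the two helpers above, using made-up entity ids and
# mentions (not from the real esAppMod data). With anchor=True each group starts
# with the canonical entity name followed by up to group_size mentions, and
# batchGenerator flattens the groups into parallel mention / label lists.
_toy_mentions = {1: ['postgres', 'pgsql', 'postgresql db'], 2: ['redis cache']}
_toy_names = {1: 'PostgreSQL', 2: 'Redis'}
_toy_sets = generate_train_entity_sets(_toy_mentions, _toy_names, group_size=2, anchor=True)
for _x, _y in batchGenerator(_toy_sets, batch_size=2):
    print(_x, _y)  # e.g. ['PostgreSQL', 'pgsql', 'postgres', 'PostgreSQL', 'postgresql db'] [1, 1, 1, 1, 1]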
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than use pattern, we use the real thing and property
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. Make everything lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
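# %%
# Hedged illustration: shuffle_text returns the original string first, followed
# by n_shuffles random word-order permutations (duplicates are possible for very
# short mentions).
print(shuffle_text('ibm websphere application server', n_shuffles=2))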
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# add shuffled strings
|
||||
processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
for desc in processed_descs:
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add corrupted strings
|
||||
desc = corrupt_string(parent_desc, corruption_probability=0.01)
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add example with stripped non-alphanumerics
|
||||
desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# short sequence amplifier
|
||||
# short sequences are rare, and we must compensate by including more examples
|
||||
# also, short sequences aren't usually affected by shuffling
|
||||
words = parent_desc.split()
|
||||
word_count = len(words)
|
||||
if word_count <= 2:
|
||||
for _ in range(AMPLIFY_FACTOR):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
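# %%
# Hedged sketch: run the augmenter on a tiny hand-made frame (toy values only)
# to see what gets added: the preprocessed original, any shuffled variants
# (none while SHUFFLES=0), an occasional corrupted variant, a
# punctuation-stripped variant, and AMPLIFY_FACTOR extra copies for mentions of
# two words or fewer.
_toy_df = pd.DataFrame([
    {'entity_id': 1, 'mention': 'MS-SQL  Server', 'entity_name': 'Microsoft SQL Server'},
    {'entity_id': 2, 'mention': 'Redis', 'entity_name': 'Redis'},
])
print(augment_data(_toy_df))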
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# each entity_id maps to a single entity_name, so the first value suffices
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
bert_model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
class BertForClassificationAndTriplet(nn.Module):
|
||||
def __init__(self, bert_model, num_classes):
|
||||
super().__init__()
|
||||
self.bert = bert_model
|
||||
self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None):
|
||||
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
||||
cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token
|
||||
logits = self.classifier(cls_embeddings)
|
||||
return cls_embeddings, logits
|
||||
|
||||
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
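# hedged shape check on a dummy batch before any training; exact sizes depend
# on MODEL_NAME (hidden_size) and on len(label2id)
with torch.no_grad():
    _demo = tokenizer(['websphere', 'redis cache'], padding=True, return_tensors='pt')
    _cls, _logits = model(input_ids=_demo['input_ids'], attention_mask=_demo['attention_mask'])
print(_cls.shape, _logits.shape)  # (2, hidden_size) and (2, num_classes)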
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
def linear_decay(epoch, max_epochs, initial_lr, final_lr):
|
||||
""" Calculate the linearly decayed learning rate. """
|
||||
return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr)
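# quick check of the schedule as used below (initial_lr=1e-5, final_lr=5e-6,
# max_epochs=200); the last step uses epoch=199, so final_lr is never quite reached:
#   linear_decay(0,   200, 1e-5, 5e-6) -> 1e-05
#   linear_decay(100, 200, 1e-5, 5e-6) -> 7.5e-06
#   linear_decay(199, 200, 1e-5, 5e-6) -> 5.025e-06 (approximately)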
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
batch_number = 0
|
||||
|
||||
lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6)
|
||||
|
||||
# Update optimizer's learning rate
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
augmented_df = augment_data(df)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = tokenizer(x, padding=True, return_tensors='pt')
|
||||
inputs.to(DEVICE)
|
||||
cls, logits = model(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask']
|
||||
)
|
||||
# for the first half of training, use the batch-all (easy) triplet loss
|
||||
labels = y
|
||||
labels = [label2id[element] for element in labels]
|
||||
labels = torch.tensor(labels).to(DEVICE)
|
||||
y = torch.tensor(y).to(DEVICE)
|
||||
|
||||
|
||||
class_loss = F.cross_entropy(logits, labels)
|
||||
|
||||
if epoch < epochs / 2:
|
||||
triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
# for the second half of training, use the batch-hard triplet loss
|
||||
else:
|
||||
triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
|
||||
loss = class_loss + triplet_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
del x, y, cls, logits, loss
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
print(f'epoch loss: {total_loss/batch_number}')
|
||||
print(f"Epoch {epoch+1}: lr={lr}")
|
||||
if epoch % 5 == 0:
|
||||
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
torch.save(model.state_dict(), './checkpoint/classification.pt')
|
||||
|
||||
|
||||
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
torch.save(model.state_dict(), './checkpoint/classification.pt')
|
||||
# %%
|
|
@ -0,0 +1,124 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
import gc
|
||||
|
||||
# %%
|
||||
# Step 2: Load the state dictionary
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
# state_dict = torch.load('./checkpoint/siamese.pt')
|
||||
# state_dict = torch.load('./checkpoint/siamese_simple.pt')
|
||||
state_dict = torch.load('./checkpoint/hybrid.pt')
|
||||
params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name}
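# assumption: hybrid.pt was saved from the BertForClassificationAndTriplet
# wrapper, i.e. keys look like 'bert.embeddings...' plus a 'classifier' head;
# stripping the 'bert.' prefix and dropping the classifier leaves a state dict
# that the bare AutoModel encoder can load.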
|
||||
|
||||
# %%
|
||||
# Step 3: Apply the state dictionary to the model
|
||||
model.load_state_dict(params_dict)
|
||||
model.to(DEVICE)
|
||||
model.eval()
|
||||
|
||||
# %%
|
||||
def preprocess_text(text):
|
||||
# 1. Make everything lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in test['data'].items()]
|
||||
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
|
||||
train_entities = [preprocess_text(element) for element in train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
# %%
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# %%
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
|
||||
# %%
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
|
||||
n_neighbors = [1, 3, 5, 10]
|
||||
|
||||
|
||||
with open("results/output.txt", "w") as f:
|
||||
for n in n_neighbors:
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
|
||||
print(np.min(distances), np.max(distances), file=f)
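# %%
# Hedged cross-check: top-1 accuracy straight from knn.predict should agree with
# the Top-1 line written to results/output.txt above.
print('top-1 via predict():', np.mean(knn.predict(cls_test) == np.array(y_test)))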
|
||||
|
||||
# %%
|
|
@ -0,0 +1,433 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import (
|
||||
batch_all_triplet_loss,
|
||||
batch_hard_triplet_loss,
|
||||
batch_all_soft_margin_triplet_loss,
|
||||
batch_hard_soft_margin_triplet_loss)
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
def set_seed(seed):
|
||||
"""
|
||||
Set the random seed for reproducibility.
|
||||
"""
|
||||
random.seed(seed) # Python random module
|
||||
np.random.seed(seed) # NumPy random
|
||||
torch.manual_seed(seed) # PyTorch CPU
|
||||
torch.cuda.manual_seed(seed) # PyTorch GPU
|
||||
torch.cuda.manual_seed_all(seed) # If using multiple GPUs
|
||||
torch.backends.cudnn.deterministic = True # Ensure deterministic behavior
|
||||
torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility
|
||||
|
||||
set_seed(42)
|
||||
|
||||
|
||||
# %%
|
||||
SHUFFLES=1
|
||||
AMPLIFY_FACTOR=1
|
||||
LEARNING_RATE=1e-4
|
||||
DEVICE = torch.device('cuda:2') if torch.cuda.is_available() else torch.device('cpu')
|
||||
|
||||
|
||||
# %%
|
||||
EVAL_FILE="top1_curves/hybrid_output.txt"
|
||||
with open(EVAL_FILE, "w") as f:
|
||||
pass
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(group[i:i + group_size], id) for i in range(0, len(group), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than use pattern, we use the real thing and property
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. Make everything lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', '_', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# # add shuffled strings
|
||||
# processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
# for desc in processed_descs:
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add corrupted strings
|
||||
desc = corrupt_string(parent_desc, corruption_probability=0.01)
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add example with stripped non-alphanumerics
|
||||
desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# short sequence amplifier
|
||||
# short sequences are rare, and we must compensate by including more examples
|
||||
# also, short sequences aren't usually affected by shuffling
|
||||
words = parent_desc.split()
|
||||
word_count = len(words)
|
||||
if word_count <= 2:
|
||||
for _ in range(AMPLIFY_FACTOR):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# each entity_id maps to a single entity_name, so the first value suffices
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
|
||||
# evaluation
|
||||
def run_evaluation(model, tokenizer):
|
||||
def preprocess_text(text):
|
||||
# 1. Make everything lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in test['data'].items()]
|
||||
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
|
||||
train_entities = [preprocess_text(element) for element in train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='euclidean').fit(cls, labels)
|
||||
|
||||
|
||||
with open(EVAL_FILE, "a") as f:
|
||||
# only compute top-1
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=1)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'{num / len(y_test)}', file=f)
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 64 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
bert_model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
class BertForClassificationAndTriplet(nn.Module):
|
||||
def __init__(self, bert_model, num_classes):
|
||||
super().__init__()
|
||||
self.bert = bert_model
|
||||
self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None):
|
||||
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
||||
cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token
|
||||
logits = self.classifier(cls_embeddings)
|
||||
return cls_embeddings, logits
|
||||
|
||||
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
def linear_decay(epoch, max_epochs, initial_lr, final_lr):
|
||||
""" Calculate the linearly decayed learning rate. """
|
||||
return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr)
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
total_cross = 0.0
|
||||
total_triplet = 0.0
|
||||
batch_number = 0
|
||||
|
||||
# lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6)
|
||||
|
||||
# # Update optimizer's learning rate
|
||||
# for param_group in optimizer.param_groups:
|
||||
# param_group['lr'] = lr
|
||||
if epoch % 10 == 0:
|
||||
augmented_df = augment_data(df)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = tokenizer(x, padding=True, return_tensors='pt')
|
||||
inputs.to(DEVICE)
|
||||
cls, logits = model(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask']
|
||||
)
|
||||
# first half of training: classification loss only; second half: hard soft-margin triplet loss
|
||||
labels = y
|
||||
labels = [label2id[element] for element in labels]
|
||||
labels = torch.tensor(labels).to(DEVICE)
|
||||
y = torch.tensor(y).to(DEVICE)
|
||||
|
||||
|
||||
class_loss = F.cross_entropy(logits, labels)
|
||||
|
||||
if epoch < epochs / 2:
|
||||
# triplet_loss, _ = batch_all_soft_margin_triplet_loss(y, cls, squared=False)
|
||||
# loss = class_loss + triplet_loss
|
||||
# loss,_ = batch_all_soft_margin_triplet_loss(y, cls, squared=False)
|
||||
loss = class_loss
|
||||
# for training after half the time, train on hard
|
||||
# else:
|
||||
# triplet_loss = batch_hard_soft_margin_triplet_loss(y, cls, squared=False)
|
||||
# loss = triplet_loss
|
||||
else:
|
||||
loss = batch_hard_soft_margin_triplet_loss(y, cls, squared=False)
|
||||
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
# total_cross += class_loss.detach().item()
|
||||
# total_triplet += triplet_loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
# run evaluation on test data
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
run_evaluation(model=model.bert, tokenizer=tokenizer)
|
||||
|
||||
model.train()
|
||||
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
# print(f'epoch loss: {total_loss/batch_number}, cross loss: {total_cross/batch_number}, triplet loss: {total_triplet/batch_number}')
|
||||
print(f'epoch loss: {total_loss/batch_number}')
|
||||
# print(f"Epoch {epoch+1}: lr={lr}")
|
||||
# if epoch % 5 == 0:
|
||||
# # torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
# torch.save(model.state_dict(), './checkpoint/hybrid.pt')
|
||||
|
||||
|
||||
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
# torch.save(model.state_dict(), './checkpoint/hybrid.pt')
|
||||
# %%
|
|
@ -0,0 +1,288 @@
|
|||
# standard functionality for computing triplet loss, code borrowed from
|
||||
# https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
def _pairwise_distances(embeddings, squared=False):
|
||||
"""Compute the 2D matrix of distances between all the embeddings.
|
||||
Args:
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
pairwise_distances: tensor of shape (batch_size, batch_size)
|
||||
"""
|
||||
dot_product = torch.matmul(embeddings, embeddings.t())
|
||||
|
||||
# Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
|
||||
# This also provides more numerical stability (the diagonal of the result will be exactly 0).
|
||||
# shape (batch_size,)
|
||||
square_norm = torch.diag(dot_product)
|
||||
|
||||
# Compute the pairwise distance matrix as we have:
|
||||
# ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
|
||||
# shape (batch_size, batch_size)
|
||||
distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)
|
||||
|
||||
# Because of computation errors, some distances might be negative so we put everything >= 0.0
|
||||
distances[distances < 0] = 0
|
||||
|
||||
if not squared:
|
||||
# Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
|
||||
# we need to add a small epsilon where distances == 0.0
|
||||
mask = distances.eq(0).float()
|
||||
distances = distances + mask * 1e-16
|
||||
|
||||
distances = (1.0 - mask) * torch.sqrt(distances)
|
||||
|
||||
return distances
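# Hedged sanity check, not part of the original reference implementation: for
# squared=False the matrix should match torch.cdist up to the small epsilon used
# on the zero diagonal. Guarded so importing this module stays silent.
if __name__ == '__main__':
    _e = torch.randn(16, 32)
    print(torch.allclose(_pairwise_distances(_e, squared=False),
                         torch.cdist(_e, _e), atol=1e-4))  # expected: True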
|
||||
|
||||
# def _pairwise_distances(embeddings, squared=False):
|
||||
# embeddings = F.normalize(embeddings, p=2, dim=1)
|
||||
# dot_product = torch.matmul(embeddings, embeddings.t())
|
||||
# cosine_distance = 1 - dot_product
|
||||
# return cosine_distance
|
||||
|
||||
|
||||
|
||||
def _get_triplet_mask(labels):
|
||||
"""Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
|
||||
A triplet (i, j, k) is valid if:
|
||||
- i, j, k are distinct
|
||||
- labels[i] == labels[j] and labels[i] != labels[k]
|
||||
Args:
|
||||
labels: integer `Tensor` with shape [batch_size]
|
||||
"""
|
||||
# Check that i, j and k are distinct
|
||||
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
|
||||
indices_not_equal = ~indices_equal
|
||||
i_not_equal_j = indices_not_equal.unsqueeze(2)
|
||||
i_not_equal_k = indices_not_equal.unsqueeze(1)
|
||||
j_not_equal_k = indices_not_equal.unsqueeze(0)
|
||||
|
||||
distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k
|
||||
|
||||
|
||||
label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
|
||||
i_equal_j = label_equal.unsqueeze(2)
|
||||
i_equal_k = label_equal.unsqueeze(1)
|
||||
|
||||
valid_labels = ~i_equal_k & i_equal_j
|
||||
|
||||
return valid_labels & distinct_indices
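# Hedged toy example: for labels [0, 0, 1] the only valid triplets are
# (a=0, p=1, n=2) and (a=1, p=0, n=2), so
#   _get_triplet_mask(torch.tensor([0, 0, 1])).sum()  ->  tensor(2)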
|
||||
|
||||
|
||||
def _get_anchor_positive_triplet_mask(labels):
|
||||
"""Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label.
|
||||
Args:
|
||||
labels: integer `Tensor` with shape [batch_size]
|
||||
Returns:
|
||||
mask: torch.bool `Tensor` with shape [batch_size, batch_size]
|
||||
"""
|
||||
# Check that i and j are distinct
|
||||
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
|
||||
indices_not_equal = ~indices_equal
|
||||
|
||||
# Check if labels[i] == labels[j]
|
||||
# Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
|
||||
labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
|
||||
|
||||
return labels_equal & indices_not_equal
|
||||
|
||||
|
||||
def _get_anchor_negative_triplet_mask(labels):
|
||||
"""Return a 2D mask where mask[a, n] is True iff a and n have distinct labels.
|
||||
Args:
|
||||
labels: integer `Tensor` with shape [batch_size]
|
||||
Returns:
|
||||
mask: torch.bool `Tensor` with shape [batch_size, batch_size]
|
||||
"""
|
||||
# Check if labels[i] != labels[k]
|
||||
# Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
|
||||
|
||||
return ~(labels.unsqueeze(0) == labels.unsqueeze(1))
|
||||
|
||||
|
||||
# Cell
|
||||
def batch_hard_triplet_loss(labels, embeddings, margin, squared=False):
|
||||
"""Build the triplet loss over a batch of embeddings.
|
||||
For each anchor, we get the hardest positive and hardest negative to form a triplet.
|
||||
Args:
|
||||
labels: labels of the batch, of size (batch_size,)
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
margin: margin for triplet loss
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
triplet_loss: scalar tensor containing the triplet loss
|
||||
"""
|
||||
# Get the pairwise distance matrix
|
||||
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
|
||||
|
||||
# For each anchor, get the hardest positive
|
||||
# First, we need to get a mask for every valid positive (they should have same label)
|
||||
mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()
|
||||
|
||||
# We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
|
||||
anchor_positive_dist = mask_anchor_positive * pairwise_dist
|
||||
|
||||
# shape (batch_size, 1)
|
||||
hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)
|
||||
|
||||
# For each anchor, get the hardest negative
|
||||
# First, we need to get a mask for every valid negative (they should have different labels)
|
||||
mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()
|
||||
|
||||
# We add the maximum value in each row to the invalid negatives (label(a) == label(n))
|
||||
max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
|
||||
anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
|
||||
|
||||
# shape (batch_size,)
|
||||
hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)
|
||||
|
||||
# Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
|
||||
tl = hardest_positive_dist - hardest_negative_dist + margin
|
||||
tl = F.relu(tl)
|
||||
triplet_loss = tl.mean()
|
||||
|
||||
return triplet_loss
|
||||
|
||||
# Cell
|
||||
def batch_all_triplet_loss(labels, embeddings, margin, squared=False):
|
||||
"""Build the triplet loss over a batch of embeddings.
|
||||
We generate all the valid triplets and average the loss over the positive ones.
|
||||
Args:
|
||||
labels: labels of the batch, of size (batch_size,)
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
margin: margin for triplet loss
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
triplet_loss: scalar tensor containing the triplet loss
|
||||
"""
|
||||
# Get the pairwise distance matrix
|
||||
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
|
||||
|
||||
anchor_positive_dist = pairwise_dist.unsqueeze(2)
|
||||
anchor_negative_dist = pairwise_dist.unsqueeze(1)
|
||||
|
||||
# Compute a 3D tensor of size (batch_size, batch_size, batch_size)
|
||||
# triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k
|
||||
# Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
|
||||
# and the 2nd (batch_size, 1, batch_size)
|
||||
triplet_loss = anchor_positive_dist - anchor_negative_dist + margin
|
||||
|
||||
|
||||
|
||||
# Put to zero the invalid triplets
|
||||
# (where label(a) != label(p) or label(n) == label(a) or a == p)
|
||||
mask = _get_triplet_mask(labels)
|
||||
triplet_loss = mask.float() * triplet_loss
|
||||
|
||||
# Remove negative losses (i.e. the easy triplets)
|
||||
triplet_loss = F.relu(triplet_loss)
|
||||
|
||||
# Count number of positive triplets (where triplet_loss > 0)
|
||||
valid_triplets = triplet_loss[triplet_loss > 1e-16]
|
||||
num_positive_triplets = valid_triplets.size(0)
|
||||
num_valid_triplets = mask.sum()
|
||||
|
||||
fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)
|
||||
|
||||
# Get final mean triplet loss over the positive valid triplets
|
||||
triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
|
||||
|
||||
return triplet_loss, fraction_positive_triplets
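# Hedged cross-check of the vectorised batch-all loss above (illustration only,
# not part of the original reference implementation): brute-force enumeration of
# all valid (anchor, positive, negative) triplets should give the same
# mean-over-positive-triplets value.
def _brute_force_batch_all(labels, embeddings, margin):
    dists = _pairwise_distances(embeddings)
    n = labels.size(0)
    losses = []
    for a in range(n):
        for p in range(n):
            for neg in range(n):
                if a != p and labels[a] == labels[p] and labels[a] != labels[neg]:
                    losses.append(F.relu(dists[a, p] - dists[a, neg] + margin))
    positives = [l for l in losses if l > 1e-16]
    return sum(losses) / (len(positives) + 1e-16)

if __name__ == '__main__':
    _labels = torch.tensor([0, 0, 1, 1, 2])
    _emb = torch.randn(5, 8)
    _fast, _ = batch_all_triplet_loss(_labels, _emb, margin=2.0)
    print(torch.allclose(_fast, _brute_force_batch_all(_labels, _emb, 2.0)))  # expected: True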
|
||||
|
||||
def batch_all_soft_margin_triplet_loss(labels, embeddings, squared=False):
|
||||
"""Build the triplet loss over a batch of embeddings.
|
||||
We generate all the valid triplets and average the loss over the positive ones.
|
||||
Args:
|
||||
labels: labels of the batch, of size (batch_size,)
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
margin: margin for triplet loss
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
triplet_loss: scalar tensor containing the triplet loss
|
||||
"""
|
||||
# Get the pairwise distance matrix
|
||||
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
|
||||
|
||||
anchor_positive_dist = pairwise_dist.unsqueeze(2)
|
||||
anchor_negative_dist = pairwise_dist.unsqueeze(1)
|
||||
|
||||
# Compute a 3D tensor of size (batch_size, batch_size, batch_size)
|
||||
# triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k
|
||||
# Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
|
||||
# and the 2nd (batch_size, 1, batch_size)
|
||||
triplet_loss = anchor_positive_dist - anchor_negative_dist
|
||||
|
||||
# soft margin: apply the softplus, log(1 + exp(x)), instead of a hinge with a fixed margin
|
||||
triplet_loss = torch.log(1 + torch.exp(triplet_loss))
|
||||
|
||||
# Put to zero the invalid triplets
|
||||
# (where label(a) != label(p) or label(n) == label(a) or a == p)
|
||||
mask = _get_triplet_mask(labels)
|
||||
triplet_loss = mask.float() * triplet_loss
|
||||
|
||||
# Remove negative losses (i.e. the easy triplets)
|
||||
# triplet_loss = F.relu(triplet_loss)
|
||||
|
||||
# Count number of positive triplets (where triplet_loss > 0)
|
||||
valid_triplets = triplet_loss[triplet_loss > 1e-16]
|
||||
num_positive_triplets = valid_triplets.size(0)
|
||||
num_valid_triplets = mask.sum()
|
||||
|
||||
fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)
|
||||
|
||||
# Get final mean triplet loss over the positive valid triplets
|
||||
triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
|
||||
|
||||
return triplet_loss, fraction_positive_triplets
|
||||
|
||||
|
||||
|
||||
def batch_hard_soft_margin_triplet_loss(labels, embeddings, squared=False):
|
||||
"""Build the triplet loss over a batch of embeddings.
|
||||
For each anchor, we get the hardest positive and hardest negative to form a triplet.
|
||||
Args:
|
||||
labels: labels of the batch, of size (batch_size,)
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
margin: margin for triplet loss
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
triplet_loss: scalar tensor containing the triplet loss
|
||||
"""
|
||||
# Get the pairwise distance matrix
|
||||
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
|
||||
|
||||
# For each anchor, get the hardest positive
|
||||
# First, we need to get a mask for every valid positive (they should have same label)
|
||||
mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()
|
||||
|
||||
# We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
|
||||
anchor_positive_dist = mask_anchor_positive * pairwise_dist
|
||||
|
||||
# shape (batch_size, 1)
|
||||
hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)
|
||||
|
||||
# For each anchor, get the hardest negative
|
||||
# First, we need to get a mask for every valid negative (they should have different labels)
|
||||
mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()
|
||||
|
||||
# We add the maximum value in each row to the invalid negatives (label(a) == label(n))
|
||||
max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
|
||||
anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
|
||||
|
||||
# shape (batch_size,)
|
||||
hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)
|
||||
|
||||
# Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
|
||||
tl = hardest_positive_dist - hardest_negative_dist
|
||||
# soft margin: apply the softplus, log(1 + exp(x)), instead of a hinge with a fixed margin
|
||||
triplet_loss = torch.log(1 + torch.exp(tl))
|
||||
|
||||
triplet_loss = triplet_loss.mean()
|
||||
|
||||
return triplet_loss
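# Hedged note: torch.log(1 + torch.exp(x)) can overflow for large positive x;
# F.softplus(x) computes the same quantity in a numerically stable way and could
# be substituted in the two soft-margin losses above.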
|
|
@ -0,0 +1,4 @@
|
|||
__pycache__
|
||||
checkpoint
|
||||
results
|
||||
top1_curves
|
|
@ -0,0 +1,132 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
import gc
|
||||
|
||||
# %%
|
||||
# Step 2: Load the state dictionary
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
# state_dict = torch.load('./checkpoint/siamese.pt')
|
||||
# state_dict = torch.load('./checkpoint/siamese_simple.pt')
|
||||
# state_dict = torch.load('./checkpoint/classification.pt')
|
||||
state_dict = torch.load('./checkpoint/baseline.pt')
|
||||
# params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name}
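# baseline.pt was saved from the bare AutoModel (no classification wrapper), so
# the state dict loads directly; the commented-out prefix stripping above is
# only needed for checkpoints saved from the wrapped model.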
|
||||
|
||||
# %%
|
||||
# Step 3: Apply the state dictionary to the model
|
||||
model.load_state_dict(state_dict)
|
||||
model.to(DEVICE)
|
||||
model.eval()
|
||||
|
||||
# %%
|
||||
def preprocess_text(text):
|
||||
# 1. Make everything lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in test['data'].items()]
|
||||
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
|
||||
train_entities = [preprocess_text(element) for element in train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
# %%
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# %%
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
|
||||
# %%
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
|
||||
n_neighbors = [1, 3, 5, 10]
|
||||
|
||||
|
||||
with open("results/output.txt", "w") as f:
|
||||
for n in n_neighbors:
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
|
||||
print(np.min(distances), np.max(distances), file=f)
|
||||
|
||||
with open("results/predictions.txt", "w") as f:
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=1)
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [labels[i] for i in b]
|
||||
print(f'{a}, {b[0]}', file=f)
|
||||
|
||||
|
||||
# %%
|
|
@ -0,0 +1,382 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
def set_seed(seed):
|
||||
"""
|
||||
Set the random seed for reproducibility.
|
||||
"""
|
||||
random.seed(seed) # Python random module
|
||||
np.random.seed(seed) # NumPy random
|
||||
torch.manual_seed(seed) # PyTorch CPU
|
||||
torch.cuda.manual_seed(seed) # PyTorch GPU
|
||||
torch.cuda.manual_seed_all(seed) # If using multiple GPUs
|
||||
torch.backends.cudnn.deterministic = True # Ensure deterministic behavior
|
||||
torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility
|
||||
|
||||
set_seed(42)
|
||||
|
||||
|
||||
# %%
|
||||
SHUFFLES=0
|
||||
AMPLIFY_FACTOR=0
|
||||
LEARNING_RATE=1e-5
|
||||
DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
# %%
|
||||
EVAL_FILE="top1_curves/baseline_output.txt"
|
||||
with open(EVAL_FILE, "w") as f:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(group[i:i + group_size], id) for i in range(0, len(group), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than use pattern, we use the real thing and property
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. Make everything lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# all augmentations disabled
|
||||
# # add shuffled strings
|
||||
# processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
# for desc in processed_descs:
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # add corrupted strings
|
||||
# desc = corrupt_string(parent_desc, corruption_probability=0.01)
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # add example with stripped non-alphanumerics
|
||||
# desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # short sequence amplifier
|
||||
# # short sequences are rare, and we must compensate by including more examples
|
||||
# # also, short sequences are not really affected by shuffling
|
||||
# words = parent_desc.split()
|
||||
# word_count = len(words)
|
||||
# if word_count <= 2:
|
||||
# for _ in range(AMPLIFY_FACTOR):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# entity_id always matches entity_name, so first value would work
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
# evaluation
|
||||
def run_evaluation(model, tokenizer):
|
||||
def preprocess_text(text):
|
||||
# 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in test['data'].items()]
|
||||
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
|
||||
train_entities = [preprocess_text(element) for element in train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='euclidean').fit(cls, labels)
|
||||
|
||||
|
||||
with open(EVAL_FILE, "a") as f:
|
||||
# only compute top-1
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=1)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'{num / len(y_test)}', file=f)
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
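# Worked example of the batch_size comment: each group holds up to num_sample_per_class
# strings (1 anchor name + group_size = 9 mentions), so one optimizer step sees at most
# batch_size * num_sample_per_class = 16 * 10 = 160 strings; the last group of an entity
# can be smaller, so 160 is an upper bound rather than an exact count.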
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
batch_number = 0
|
||||
|
||||
augmented_df = augment_data(df)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = tokenizer(x, padding=True, return_tensors='pt')
|
||||
inputs.to(DEVICE)
|
||||
outputs = model(**inputs)
|
||||
cls = outputs.last_hidden_state[:,0,:]
|
||||
# first half of training: batch-all (easy) triplet loss
|
||||
y = torch.tensor(y).to(DEVICE)
|
||||
if epoch < epochs / 2:
|
||||
loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
# second half of training: batch-hard triplet loss
|
||||
else:
|
||||
loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
# run evaluation on test data
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
run_evaluation(model=model, tokenizer=tokenizer)
|
||||
|
||||
model.train()
|
||||
|
||||
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
print(f'epoch loss: {total_loss/batch_number}')
|
||||
# print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
|
||||
if epoch == 175:
|
||||
torch.save(model.state_dict(), './checkpoint/baseline.pt')
|
||||
|
||||
|
||||
# torch.save(model.state_dict(), './checkpoint/baseline.pt')
|
||||
# %%
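# %%
# Sketch of the batch-hard idea used in the second half of training, assuming the usual
# semantics for the imported batch_hard_triplet_loss (loss.py is not reproduced here):
# for every anchor, keep only its hardest positive and hardest negative within the batch.
#
#   dist = torch.cdist(cls, cls)                                  # pairwise distances
#   same = y.unsqueeze(0) == y.unsqueeze(1)                       # label-equality mask
#   hardest_pos = (dist * same).max(dim=1).values                 # furthest same-label sample
#   hardest_neg = dist.masked_fill(same, float('inf')).min(dim=1).values
#   loss = torch.relu(hardest_pos - hardest_neg + margin).mean()
#
# batch_all_triplet_loss instead averages the hinge over every valid (anchor, positive,
# negative) triplet in the batch, which gives an easier signal early in training.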
|
|
@ -0,0 +1,124 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
import gc
|
||||
|
||||
# %%
|
||||
# Step 2: Load the state dictionary
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
# state_dict = torch.load('./checkpoint/siamese.pt')
|
||||
# state_dict = torch.load('./checkpoint/siamese_simple.pt')
|
||||
state_dict = torch.load('./checkpoint/classification.pt')
|
||||
params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name}
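# The checkpoint was saved from BertForClassificationAndTriplet, so its keys look like
# 'bert.encoder.layer.0. ...' plus 'classifier.weight' / 'classifier.bias'; the rename
# above maps them back to plain AutoModel keys ('encoder.layer.0. ...') and drops the
# classifier head, which the bare encoder does not have.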
|
||||
|
||||
# %%
|
||||
# Step 3: Apply the state dictionary to the model
|
||||
model.load_state_dict(params_dict)
|
||||
model.to(DEVICE)
|
||||
model.eval()
|
||||
|
||||
# %%
|
||||
def preprocess_text(text):
|
||||
# 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in test['data'].items()]
|
||||
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
|
||||
train_entities = [preprocess_text(element) for element in train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
# %%
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# %%
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
|
||||
# %%
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
|
||||
n_neighbors = [1, 3, 5, 10]
|
||||
|
||||
|
||||
with open("results/output.txt", "w") as f:
|
||||
for n in n_neighbors:
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
|
||||
print(np.min(distances), np.max(distances), file=f)
|
||||
|
||||
# %%
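# %%
# Note on the loop above: the classifier is fitted with n_neighbors=1, but
# kneighbors(..., n_neighbors=n) can request any n at query time, so "Top-n accuracy"
# means the true entity id appears among the n nearest training-entity embeddings
# under the cosine metric.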
|
|
@ -0,0 +1,258 @@
|
|||
# %%
|
||||
|
||||
# from datasets import load_from_disk
|
||||
import os
|
||||
import glob
|
||||
|
||||
os.environ['NCCL_P2P_DISABLE'] = '1'
|
||||
os.environ['NCCL_IB_DISABLE'] = '1'
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
|
||||
|
||||
import re
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModel,
|
||||
DataCollatorWithPadding,
|
||||
)
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
# import matplotlib.pyplot as plt
|
||||
from datasets import Dataset, DatasetDict
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
|
||||
BATCH_SIZE = 256
|
||||
|
||||
# %%
|
||||
# construct the target id list
|
||||
# data_path = '../../../esAppMod_data_import/train.csv'
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
train_df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# build the label mapping directly from the real entity ids in the training data
|
||||
entity_ids = train_df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
|
||||
# %%
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
|
||||
# introduce pre-processing functions
|
||||
def preprocess_text(text):
|
||||
# 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# Substitute digits with '#'
|
||||
# text = re.sub(r'\d+', '#', text)
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
|
||||
|
||||
# outputs a list of dictionaries
|
||||
# processes dataframe into lists of dictionaries
|
||||
# each element maps input to output
|
||||
# input: mention text
|
||||
# output: class label
|
||||
def process_df_to_dict(df):
|
||||
output_list = []
|
||||
for _, row in df.iterrows():
|
||||
desc = row['mention']
|
||||
desc = preprocess_text(desc)
|
||||
index = row['entity_id']
|
||||
element = {
|
||||
'text' : desc,
|
||||
'label': label2id[index], # ensure labels start from 0
|
||||
}
|
||||
output_list.append(element)
|
||||
|
||||
return output_list
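# One resulting element looks like the following (the mention text is illustrative):
#   {'text': 'ibm websphere mq', 'label': 42}
# where 42 is the contiguous training id from label2id, not the raw entity_id.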
|
||||
|
||||
|
||||
def create_dataset():
|
||||
# test
|
||||
data_path = '../esAppMod_data_import/test.csv'
|
||||
test_df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
|
||||
|
||||
# combined_data = DatasetDict({
|
||||
# 'train': Dataset.from_list(process_df_to_dict(train_df)),
|
||||
# })
|
||||
return Dataset.from_list(process_df_to_dict(test_df))
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
def test():
|
||||
|
||||
test_dataset = create_dataset()
|
||||
|
||||
# prepare tokenizer
|
||||
|
||||
# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
# MODEL_NAME = 'distilbert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small'
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, return_tensors="pt", clean_up_tokenization_spaces=True)
|
||||
# Define additional special tokens
|
||||
# additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
|
||||
# Add the additional special tokens to the tokenizer
|
||||
# tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
|
||||
|
||||
# %%
|
||||
# compute max token length
|
||||
max_length = 0
|
||||
for sample in test_dataset['text']:
|
||||
# Tokenize the sample and get the length
|
||||
input_ids = tokenizer(sample, truncation=False, add_special_tokens=True)["input_ids"]
|
||||
length = len(input_ids)
|
||||
|
||||
# Update max_length if this sample is longer
|
||||
if length > max_length:
|
||||
max_length = length
|
||||
|
||||
print(max_length)
|
||||
|
||||
# %%
|
||||
|
||||
max_length = 128
|
||||
|
||||
# given a dataset entry, run it through the tokenizer
|
||||
def preprocess_function(example):
|
||||
input = example['text']
|
||||
# tokenize the mention text; the integer 'label' column is carried through unchanged
|
||||
model_inputs = tokenizer(
|
||||
input,
|
||||
# max_length=max_length,
|
||||
# padding='max_length'
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
# map applies preprocess_function to every row of the dataset (in batches here)
|
||||
datasets = test_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=8,
|
||||
remove_columns="text",
|
||||
)
|
||||
|
||||
|
||||
datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
|
||||
|
||||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
bert_model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
class BertForClassificationAndTriplet(nn.Module):
|
||||
def __init__(self, bert_model, num_classes):
|
||||
super().__init__()
|
||||
self.bert = bert_model
|
||||
self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None):
|
||||
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
||||
cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token
|
||||
logits = self.classifier(cls_embeddings)
|
||||
return cls_embeddings, logits
|
||||
|
||||
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
|
||||
state_dict = torch.load('./checkpoint/classification.pt')
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
model = model.eval()
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
model.to(device)
|
||||
|
||||
pred_labels = []
|
||||
actual_labels = []
|
||||
|
||||
|
||||
dataloader = DataLoader(datasets, batch_size=BATCH_SIZE, shuffle=False, collate_fn=data_collator)
|
||||
for batch in tqdm(dataloader):
|
||||
# Inference in batches
|
||||
input_ids = batch['input_ids']
|
||||
attention_mask = batch['attention_mask']
|
||||
# save labels too
|
||||
actual_labels.extend(batch['labels'])
|
||||
|
||||
|
||||
# Move to GPU if available
|
||||
input_ids = input_ids.to(device)
|
||||
attention_mask = attention_mask.to(device)
|
||||
|
||||
# Perform inference
|
||||
with torch.no_grad():
|
||||
cls, logits = model(
|
||||
input_ids,
|
||||
attention_mask)
|
||||
predicted_class_ids = logits.argmax(dim=1).to("cpu")
|
||||
pred_labels.extend(predicted_class_ids)
|
||||
|
||||
pred_labels = [tensor.item() for tensor in pred_labels]
|
||||
|
||||
|
||||
# %%
|
||||
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
|
||||
y_true = actual_labels
|
||||
y_pred = pred_labels
|
||||
|
||||
# Compute metrics
|
||||
accuracy = accuracy_score(y_true, y_pred)
|
||||
average_parameter = 'weighted'
|
||||
zero_division_parameter = 0
|
||||
f1 = f1_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter)
|
||||
precision = precision_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter)
|
||||
recall = recall_score(y_true, y_pred, average=average_parameter, zero_division=zero_division_parameter)
|
||||
|
||||
with open("results/output.txt", "a") as f:
|
||||
|
||||
print('*' * 80, file=f)
|
||||
# Print the results
|
||||
print(f'Accuracy: {accuracy:.5f}', file=f)
|
||||
print(f'F1 Score: {f1:.5f}', file=f)
|
||||
print(f'Precision: {precision:.5f}', file=f)
|
||||
print(f'Recall: {recall:.5f}', file=f)
|
||||
|
||||
# export result
|
||||
label_list = [id2label[id] for id in pred_labels]
|
||||
df = pd.DataFrame({
|
||||
'class_prediction': pd.Series(label_list)
|
||||
})
|
||||
|
||||
# save the predicted class ids alongside the metrics
df.to_csv("results/classify.csv", index=False)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
# reset file before writing to it
|
||||
with open("results/output.txt", "w") as f:
|
||||
print('', file=f)
|
||||
test()
|
|
@ -0,0 +1,316 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# %%
|
||||
SHUFFLES=0
|
||||
AMPLIFY_FACTOR=0
|
||||
LEARNING_RATE=1e-5
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# build the label mapping directly from the real entity ids in the training data
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# disable augmentations
|
||||
# # add shuffled strings
|
||||
# processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
# for desc in processed_descs:
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # add corrupted strings
|
||||
# desc = corrupt_string(parent_desc, corruption_probability=0.01)
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # add example with stripped non-alphanumerics
|
||||
# desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # short sequence amplifier
|
||||
# # short sequences are rare, and we must compensate by including more examples
|
||||
# # also, short sequences are not really affected by shuffling
|
||||
# words = parent_desc.split()
|
||||
# word_count = len(words)
|
||||
# if word_count <= 2:
|
||||
# for _ in range(AMPLIFY_FACTOR):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# entity_id always matches entity_name, so first value would work
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
bert_model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
class BertForClassificationAndTriplet(nn.Module):
|
||||
def __init__(self, bert_model, num_classes):
|
||||
super().__init__()
|
||||
self.bert = bert_model
|
||||
self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None):
|
||||
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
||||
cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token
|
||||
logits = self.classifier(cls_embeddings)
|
||||
return cls_embeddings, logits
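# Shape sketch for a batch of B tokenized strings (hidden size H, C = len(label2id)):
#   cls_embeddings: [B, H]  - the [CLS] vector, used by the triplet term when enabled
#                             and for KNN retrieval at evaluation time
#   logits:         [B, C]  - fed to F.cross_entropy against the contiguous labels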
|
||||
|
||||
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
def linear_decay(epoch, max_epochs, initial_lr, final_lr):
|
||||
""" Calculate the linearly decayed learning rate. """
|
||||
return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr)
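# e.g. linear_decay(epoch=100, max_epochs=200, initial_lr=1e-5, final_lr=5e-6)
#      = 1e-5 - 0.5 * (1e-5 - 5e-6) = 7.5e-6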
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
batch_number = 0
|
||||
|
||||
# lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6)
|
||||
|
||||
# # Update optimizer's learning rate
|
||||
# for param_group in optimizer.param_groups:
|
||||
# param_group['lr'] = lr
|
||||
|
||||
augmented_df = augment_data(df)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = tokenizer(x, padding=True, return_tensors='pt')
|
||||
inputs.to(DEVICE)
|
||||
cls, logits = model(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask']
|
||||
)
|
||||
# convert entity ids to contiguous class labels for the classification head
|
||||
labels = y
|
||||
labels = [label2id[element] for element in labels]
|
||||
labels = torch.tensor(labels).to(DEVICE)
|
||||
# y = torch.tensor(y).to(DEVICE)
|
||||
|
||||
|
||||
class_loss = F.cross_entropy(logits, labels)
|
||||
|
||||
# if epoch < epochs / 2:
|
||||
# triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
# # for training after half the time, train on hard
|
||||
# else:
|
||||
# triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
|
||||
loss = class_loss # + triplet_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
del x, y, cls, logits, loss
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
print(f'epoch loss: {total_loss/batch_number}')
|
||||
# print(f"Epoch {epoch+1}: lr={lr}")
|
||||
if epoch % 5 == 0:
|
||||
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
torch.save(model.state_dict(), './checkpoint/classification.pt')
|
||||
|
||||
|
||||
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
torch.save(model.state_dict(), './checkpoint/classification.pt')
|
||||
# %%
|
|
@ -0,0 +1,277 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
|
||||
|
||||
# %%
|
||||
SHUFFLES=0
|
||||
AMPLIFY_FACTOR=0
|
||||
LEARNING_RATE=1e-5
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# build the label mapping directly from the real entity ids in the training data
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# all augmentations disabled
|
||||
# # add shuffled strings
|
||||
# processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
# for desc in processed_descs:
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # add corrupted strings
|
||||
# desc = corrupt_string(parent_desc, corruption_probability=0.01)
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # add example with stripped non-alphanumerics
|
||||
# desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # short sequence amplifier
|
||||
# # short sequences are rare, and we must compensate by including more examples
|
||||
# # also, short sequences are not really affected by shuffling
|
||||
# words = parent_desc.split()
|
||||
# word_count = len(words)
|
||||
# if word_count <= 2:
|
||||
# for _ in range(AMPLIFY_FACTOR):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# entity_id always matches entity_name, so first value would work
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
DEVICE = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
batch_number = 0
|
||||
|
||||
augmented_df = augment_data(df)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = tokenizer(x, padding=True, return_tensors='pt')
|
||||
inputs.to(DEVICE)
|
||||
outputs = model(**inputs)
|
||||
cls = outputs.last_hidden_state[:,0,:]
|
||||
# first half of training: batch-all (easy) triplet loss
|
||||
y = torch.tensor(y).to(DEVICE)
|
||||
if epoch < epochs / 2:
|
||||
loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
# second half of training: batch-hard triplet loss
|
||||
else:
|
||||
loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
del x, y, outputs, cls, loss
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
print(f'epoch loss: {total_loss/batch_number}')
|
||||
# print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
|
||||
if epoch % 5 == 0:
|
||||
torch.save(model.state_dict(), './checkpoint/baseline.pt')
|
||||
|
||||
|
||||
torch.save(model.state_dict(), './checkpoint/baseline.pt')
|
||||
# %%
|
|
@ -0,0 +1,315 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# %%
|
||||
SHUFFLES=0
|
||||
AMPLIFY_FACTOR=2
|
||||
LEARNING_RATE=1e-5
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# build the label mapping directly from the real entity ids in the training data
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. Make all lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# add shuffled strings
|
||||
processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
for desc in processed_descs:
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add corrupted strings
|
||||
desc = corrupt_string(parent_desc, corruption_probability=0.01)
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add example with stripped non-alphanumerics
|
||||
desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# short sequence amplifier
|
||||
# short sequences are rare, and we must compensate by including more examples
|
||||
# also, short sequences are not really affected by shuffling
|
||||
words = parent_desc.split()
|
||||
word_count = len(words)
|
||||
if word_count <= 2:
|
||||
for _ in range(AMPLIFY_FACTOR):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
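# With SHUFFLES=0 and AMPLIFY_FACTOR=2 (set at the top of this file), each row yields the
# cleaned mention, possibly one corrupted variant, possibly one punctuation-stripped
# variant, and, for mentions of two words or fewer, two extra copies of whatever `desc`
# holds at that point (the punctuation-stripped text), since the amplifier loop reuses
# the last `desc` rather than parent_desc.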
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# entity_id always matches entity_name, so first value would work
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
bert_model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
class BertForClassificationAndTriplet(nn.Module):
|
||||
def __init__(self, bert_model, num_classes):
|
||||
super().__init__()
|
||||
self.bert = bert_model
|
||||
self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None):
|
||||
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
||||
cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token
|
||||
logits = self.classifier(cls_embeddings)
|
||||
return cls_embeddings, logits
|
||||
|
||||
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
def linear_decay(epoch, max_epochs, initial_lr, final_lr):
|
||||
""" Calculate the linearly decayed learning rate. """
|
||||
return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr)
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
batch_number = 0
|
||||
|
||||
lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6)
|
||||
|
||||
# Update optimizer's learning rate
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
augmented_df = augment_data(df)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = tokenizer(x, padding=True, return_tensors='pt')
|
||||
inputs.to(DEVICE)
|
||||
cls, logits = model(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask']
|
||||
)
|
||||
# convert entity ids to contiguous class labels; keep y (raw entity ids) for the triplet loss
|
||||
labels = y
|
||||
labels = [label2id[element] for element in labels]
|
||||
labels = torch.tensor(labels).to(DEVICE)
|
||||
y = torch.tensor(y).to(DEVICE)
|
||||
|
||||
|
||||
class_loss = F.cross_entropy(logits, labels)
|
||||
|
||||
if epoch < epochs / 2:
|
||||
triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
# second half of training: batch-hard triplet loss
|
||||
else:
|
||||
triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
|
||||
loss = class_loss + triplet_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
del x, y, cls, logits, loss
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
print(f'epoch loss: {total_loss/batch_number}')
|
||||
print(f"Epoch {epoch+1}: lr={lr}")
|
||||
if epoch % 5 == 0:
|
||||
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
torch.save(model.state_dict(), './checkpoint/classification.pt')
|
||||
|
||||
|
||||
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
torch.save(model.state_dict(), './checkpoint/classification.pt')
|
||||
# %%
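# %%
# The objective above is an unweighted sum, loss = class_loss + triplet_loss; if one term
# dominates in practice, a weighted variant is a natural tweak (alpha is hypothetical and
# not used in this run):
#
#   alpha = 0.5
#   loss = class_loss + alpha * triplet_loss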
|
|
@ -0,0 +1,124 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
import gc
|
||||
|
||||
# %%
|
||||
# Step 2: Load the state dictionary
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
# MODEL_NAME = 'bert-base-cased' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
# state_dict = torch.load('./checkpoint/siamese.pt')
|
||||
# state_dict = torch.load('./checkpoint/siamese_simple.pt')
|
||||
state_dict = torch.load('./checkpoint/hybrid.pt')
|
||||
params_dict = {name.replace('bert.', ''): param for name, param in state_dict.items() if 'classifier' not in name}
|
||||
|
||||
# %%
|
||||
# Step 3: Apply the state dictionary to the model
|
||||
model.load_state_dict(params_dict)
|
||||
model.to(DEVICE)
|
||||
model.eval()
|
||||
|
||||
# %%
|
||||
def preprocess_text(text):
|
||||
# 1. make everything lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
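# quick sanity example (hypothetical input, not from the dataset):
# preprocess_text('  Foo   BAR\t baz ') -> 'foo bar baz'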
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in test['data'].items()]
|
||||
train_entities, labels = list(train_entity_id_name.values()), list(train_entity_id_name.keys())
|
||||
train_entities = [preprocess_text(element) for element in train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
# %%
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# %%
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
|
||||
# %%
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
|
||||
n_neighbors = [1, 3, 5, 10]
|
||||
|
||||
|
||||
with open("results/output.txt", "w") as f:
|
||||
for n in n_neighbors:
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
|
||||
print(np.min(distances), np.max(distances), file=f)
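# note on the metric above: a test mention counts as a top-n hit when its true entity_id
# appears among the n nearest entity-name embeddings under cosine distance (the KNN is fit
# on one embedding per training entity, not on mention embeddings).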
|
||||
|
||||
# %%
|
|
@ -0,0 +1,315 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# %%
|
||||
SHUFFLES=0
|
||||
AMPLIFY_FACTOR=0
|
||||
LEARNING_RATE=1e-5
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
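# illustrative walk-through (toy data, not the real files): with group_size=2 and anchor=True,
# generate_train_entity_sets({1: ['a', 'b', 'c']}, {1: 'Entity One'}, 2, anchor=True) yields
# groups like (['Entity One', 'a', 'b'], 1) and (['Entity One', 'c'], 1); batchGenerator then
# flattens batch_size such groups into parallel lists x (mentions) and y (repeated entity ids).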
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than use pattern, we use the real thing and property
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. make everything lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# # add shuffled strings
|
||||
# processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
# for desc in processed_descs:
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # add corrupted strings
|
||||
# desc = corrupt_string(parent_desc, corruption_probability=0.01)
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # add example with stripped non-alphanumerics
|
||||
# desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # short sequence amplifier
|
||||
# # short sequences are rare, and we must compensate by including more examples
|
||||
# # also, short sequence don't usually get affected by shuffle
|
||||
# words = parent_desc.split()
|
||||
# word_count = len(words)
|
||||
# if word_count <= 2:
|
||||
# for _ in range(AMPLIFY_FACTOR):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
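# in this script only the preprocessed mention is kept per row (the shuffle/corruption/stripping
# augmentations above are commented out), so augment_data effectively just lowercases and
# re-spaces the training mentions on every epoch.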
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# each entity_id maps to a single entity_name, so the first value suffices
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
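# toy illustration (hypothetical rows): a frame with rows (7, 'webspere', 'WebSphere') and
# (7, 'was', 'WebSphere') gives make_entity_id_mentions(df) -> {7: ['webspere', 'was']} and
# make_entity_id_name(df) -> {7: 'WebSphere'}.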
|
||||
|
||||
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 16 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
bert_model = AutoModel.from_pretrained(MODEL_NAME)
|
||||
|
||||
class BertForClassificationAndTriplet(nn.Module):
|
||||
def __init__(self, bert_model, num_classes):
|
||||
super().__init__()
|
||||
self.bert = bert_model
|
||||
self.classifier = nn.Linear(bert_model.config.hidden_size, num_classes)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None):
|
||||
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
||||
cls_embeddings = outputs.last_hidden_state[:, 0, :] # CLS token
|
||||
logits = self.classifier(cls_embeddings)
|
||||
return cls_embeddings, logits
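# shape sketch: for a batch of B tokenized mentions, cls_embeddings is (B, hidden_size) and feeds
# the triplet loss, while logits is (B, num_classes) and feeds the cross-entropy term; both heads
# share the same BERT encoder.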
|
||||
|
||||
model = BertForClassificationAndTriplet(bert_model, num_classes=len(label2id))
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
def linear_decay(epoch, max_epochs, initial_lr, final_lr):
|
||||
""" Calculate the linearly decayed learning rate. """
|
||||
return initial_lr - (epoch / max_epochs) * (initial_lr - final_lr)
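# worked example: linear_decay(0, 200, 1e-5, 5e-6) = 1e-5, linear_decay(100, 200, 1e-5, 5e-6) = 7.5e-6,
# linear_decay(200, 200, 1e-5, 5e-6) = 5e-6; the decay is linear in the epoch index.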
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
batch_number = 0
|
||||
|
||||
# lr = linear_decay(epoch, epochs, initial_lr=1e-5, final_lr=5e-6)
|
||||
|
||||
# # Update optimizer's learning rate
|
||||
# for param_group in optimizer.param_groups:
|
||||
# param_group['lr'] = lr
|
||||
|
||||
augmented_df = augment_data(df)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = tokenizer(x, padding=True, return_tensors='pt')
|
||||
inputs.to(DEVICE)
|
||||
cls, logits = model(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask']
|
||||
)
|
||||
# for the first half of training, mine all (easy) triplets; see the if/else below
|
||||
labels = y
|
||||
labels = [label2id[element] for element in labels]
|
||||
labels = torch.tensor(labels).to(DEVICE)
|
||||
y = torch.tensor(y).to(DEVICE)
|
||||
|
||||
|
||||
class_loss = F.cross_entropy(logits, labels)
|
||||
|
||||
if epoch < epochs / 2:
|
||||
triplet_loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
# for the second half of training, mine only the hardest triplets
|
||||
else:
|
||||
triplet_loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
|
||||
loss = class_loss + triplet_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
del x, y, cls, logits, loss
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
print(f'epoch loss: {total_loss/batch_number}')
|
||||
# print(f"Epoch {epoch+1}: lr={lr}")
|
||||
if epoch % 5 == 0:
|
||||
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
torch.save(model.state_dict(), './checkpoint/hybrid.pt')
|
||||
|
||||
|
||||
# torch.save(model.bert.state_dict(), './checkpoint/classification.pt')
|
||||
torch.save(model.state_dict(), './checkpoint/hybrid.pt')
|
||||
# %%
|
|
@ -0,0 +1,193 @@
|
|||
# standard functions for computing triplet loss; code borrowed from
|
||||
# https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def _pairwise_distances(embeddings, squared=False):
|
||||
"""Compute the 2D matrix of distances between all the embeddings.
|
||||
Args:
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
pairwise_distances: tensor of shape (batch_size, batch_size)
|
||||
"""
|
||||
dot_product = torch.matmul(embeddings, embeddings.t())
|
||||
|
||||
# Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
|
||||
# This also provides more numerical stability (the diagonal of the result will be exactly 0).
|
||||
# shape (batch_size,)
|
||||
square_norm = torch.diag(dot_product)
|
||||
|
||||
# Compute the pairwise distance matrix as we have:
|
||||
# ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
|
||||
# shape (batch_size, batch_size)
|
||||
distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)
|
||||
|
||||
# Apply a lower bound to distances to ensure they are non-negative and avoid tiny negative numbers due to computation errors
|
||||
distances = torch.clamp(distances, min=0.0)
|
||||
|
||||
if not squared:
|
||||
# Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
|
||||
# we need to add a small epsilon where distances == 0.0
|
||||
epsilon = 1e-16
|
||||
mask = (distances < epsilon).float()
|
||||
distances = distances + mask * epsilon
|
||||
|
||||
distances = (1.0 - mask) * torch.sqrt(distances)
|
||||
|
||||
|
||||
return distances
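# quick check of the identity used above, ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2:
# for a = [1, 0] and b = [0, 1], ||a - b||^2 = 2 and 1 - 2*0 + 1 = 2, so the two sides agree.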
|
||||
|
||||
def _get_triplet_mask(labels):
|
||||
"""Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
|
||||
A triplet (i, j, k) is valid if:
|
||||
- i, j, k are distinct
|
||||
- labels[i] == labels[j] and labels[i] != labels[k]
|
||||
Args:
|
||||
labels: tf.int32 `Tensor` with shape [batch_size]
|
||||
"""
|
||||
# Check that i, j and k are distinct
|
||||
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
|
||||
indices_not_equal = ~indices_equal
|
||||
i_not_equal_j = indices_not_equal.unsqueeze(2)
|
||||
i_not_equal_k = indices_not_equal.unsqueeze(1)
|
||||
j_not_equal_k = indices_not_equal.unsqueeze(0)
|
||||
|
||||
# keep only triplets whose three indices are pairwise distinct (drop any triplet that reuses an index)
|
||||
distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k
|
||||
|
||||
|
||||
label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
|
||||
i_equal_j = label_equal.unsqueeze(2)
|
||||
i_equal_k = label_equal.unsqueeze(1)
|
||||
|
||||
# valid triplets are (i,j) sharing same label and
|
||||
# (i,k) having different labels
|
||||
valid_labels = ~i_equal_k & i_equal_j
|
||||
|
||||
return valid_labels & distinct_indices
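# toy example: labels = [0, 0, 1] admits exactly two valid triplets, (a, p, n) = (0, 1, 2) and
# (1, 0, 2); the sample with label 1 has no positive partner, so it can only act as a negative.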
|
||||
|
||||
|
||||
def _get_anchor_positive_triplet_mask(labels):
|
||||
"""Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label.
|
||||
Args:
|
||||
labels: tf.int32 `Tensor` with shape [batch_size]
|
||||
Returns:
|
||||
mask: tf.bool `Tensor` with shape [batch_size, batch_size]
|
||||
"""
|
||||
# Check that i and j are distinct
|
||||
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
|
||||
indices_not_equal = ~indices_equal
|
||||
|
||||
# Check if labels[i] == labels[j]
|
||||
# Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
|
||||
labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
|
||||
|
||||
return labels_equal & indices_not_equal
|
||||
|
||||
|
||||
def _get_anchor_negative_triplet_mask(labels):
|
||||
"""Return a 2D mask where mask[a, n] is True iff a and n have distinct labels.
|
||||
Args:
|
||||
labels: tf.int32 `Tensor` with shape [batch_size]
|
||||
Returns:
|
||||
mask: tf.bool `Tensor` with shape [batch_size, batch_size]
|
||||
"""
|
||||
# Check if labels[i] != labels[k]
|
||||
# Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
|
||||
|
||||
return ~(labels.unsqueeze(0) == labels.unsqueeze(1))
|
||||
|
||||
|
||||
# Cell
|
||||
def batch_hard_triplet_loss(labels, embeddings, margin, squared=False):
|
||||
"""Build the triplet loss over a batch of embeddings.
|
||||
For each anchor, we get the hardest positive and hardest negative to form a triplet.
|
||||
Args:
|
||||
labels: labels of the batch, of size (batch_size,)
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
margin: margin for triplet loss
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
triplet_loss: scalar tensor containing the triplet loss
|
||||
"""
|
||||
# Get the pairwise distance matrix
|
||||
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
|
||||
|
||||
# For each anchor, get the hardest positive
|
||||
# First, we need to get a mask for every valid positive (they should have same label)
|
||||
mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()
|
||||
|
||||
# We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
|
||||
anchor_positive_dist = mask_anchor_positive * pairwise_dist
|
||||
|
||||
# shape (batch_size, 1)
|
||||
hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)
|
||||
|
||||
# For each anchor, get the hardest negative
|
||||
# First, we need to get a mask for every valid negative (they should have different labels)
|
||||
mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()
|
||||
|
||||
# We add the maximum value in each row to the invalid negatives (label(a) == label(n))
|
||||
max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
|
||||
anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
|
||||
|
||||
# shape (batch_size,)
|
||||
hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)
|
||||
|
||||
# Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
|
||||
tl = hardest_positive_dist - hardest_negative_dist + margin
|
||||
tl = F.relu(tl)
|
||||
triplet_loss = tl.mean()
|
||||
|
||||
return triplet_loss
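# minimal usage sketch (shapes assumed, not tied to the training scripts above):
# labels = torch.tensor([0, 0, 1, 1]); embeddings = torch.randn(4, 8)
# loss = batch_hard_triplet_loss(labels, embeddings, margin=2.0)  # scalar tensor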
|
||||
|
||||
# Cell
|
||||
def batch_all_triplet_loss(labels, embeddings, margin, squared=False):
|
||||
"""Build the triplet loss over a batch of embeddings.
|
||||
We generate all the valid triplets and average the loss over the positive ones.
|
||||
Args:
|
||||
labels: labels of the batch, of size (batch_size,)
|
||||
embeddings: tensor of shape (batch_size, embed_dim)
|
||||
margin: margin for triplet loss
|
||||
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
|
||||
If false, output is the pairwise euclidean distance matrix.
|
||||
Returns:
|
||||
triplet_loss: scalar tensor containing the triplet loss
|
||||
"""
|
||||
# Get the pairwise distance matrix
|
||||
pairwise_dist = _pairwise_distances(embeddings, squared=squared)
|
||||
|
||||
anchor_positive_dist = pairwise_dist.unsqueeze(2)
|
||||
anchor_negative_dist = pairwise_dist.unsqueeze(1)
|
||||
|
||||
# Compute a 3D tensor of size (batch_size, batch_size, batch_size)
|
||||
# triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k
|
||||
# Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
|
||||
# and the 2nd (batch_size, 1, batch_size)
|
||||
triplet_loss = anchor_positive_dist - anchor_negative_dist + margin
|
||||
|
||||
|
||||
|
||||
# Put to zero the invalid triplets
|
||||
# (where label(a) != label(p) or label(n) == label(a) or a == p)
|
||||
mask = _get_triplet_mask(labels)
|
||||
triplet_loss = mask.float() * triplet_loss
|
||||
|
||||
# Remove negative losses (i.e. the easy triplets)
|
||||
triplet_loss = F.relu(triplet_loss)
|
||||
|
||||
# Count number of positive triplets (where triplet_loss > 0)
|
||||
valid_triplets = triplet_loss[triplet_loss > 1e-16]
|
||||
num_positive_triplets = valid_triplets.size(0)
|
||||
num_valid_triplets = mask.sum()
|
||||
|
||||
fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)
|
||||
|
||||
# Get final mean triplet loss over the positive valid triplets
|
||||
triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
|
||||
|
||||
return triplet_loss, fraction_positive_triplets
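# unlike batch_hard_triplet_loss, this variant also returns the fraction of mined triplets that
# are still active (loss > 0); the training scripts discard it, but it can be logged as a rough
# measure of how hard the current batch is:
# loss, frac_active = batch_all_triplet_loss(labels, embeddings, margin=2.0)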
|
|
@ -0,0 +1,460 @@
|
|||
# %%
|
||||
import torch
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from transformers import BertTokenizer
|
||||
from transformers import AutoModel
|
||||
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import re
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
def set_seed(seed):
|
||||
"""
|
||||
Set the random seed for reproducibility.
|
||||
"""
|
||||
random.seed(seed) # Python random module
|
||||
np.random.seed(seed) # NumPy random
|
||||
torch.manual_seed(seed) # PyTorch CPU
|
||||
torch.cuda.manual_seed(seed) # PyTorch GPU
|
||||
torch.cuda.manual_seed_all(seed) # If using multiple GPUs
|
||||
torch.backends.cudnn.deterministic = True # Ensure deterministic behavior
|
||||
torch.backends.cudnn.benchmark = False # Disable optimization for reproducibility
|
||||
|
||||
set_seed(42)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
SHUFFLES=1
|
||||
AMPLIFY_FACTOR=1
|
||||
CORRUPT=0.1
|
||||
LEARNING_RATE=1e-5
|
||||
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
||||
# MODEL_NAME = 'distilbert-base-cased' #'prajjwal1/bert-small' #'bert-base-cased'
|
||||
# MODEL_NAME = 'prajjwal1/bert-small' # 'prajjwal1/bert-small' 'bert-base-cased' 'distilbert-base-cased'
|
||||
MODEL_NAME = 'helboukkouri/character-bert'
|
||||
|
||||
# %%
|
||||
with open("top1_curves/character_output.txt", "w") as f:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def generate_train_entity_sets(entity_id_mentions, entity_id_name, group_size, anchor=True):
|
||||
# split entity mentions into groups
|
||||
# anchor = False, don't add entity name to each group, simply treat it as a normal mention
|
||||
entity_sets = []
|
||||
if anchor:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
random.shuffle(mentions)
|
||||
positives = [mentions[i:i + group_size] for i in range(0, len(mentions), group_size)]
|
||||
anchor_positive = [([entity_id_name[id]]+p, id) for p in positives]
|
||||
entity_sets.extend(anchor_positive)
|
||||
else:
|
||||
for id, mentions in entity_id_mentions.items():
|
||||
group = list(set([entity_id_name[id]] + mentions))
|
||||
random.shuffle(group)
|
||||
positives = [(mentions[i:i + group_size], id) for i in range(0, len(mentions), group_size)]
|
||||
entity_sets.extend(positives)
|
||||
return entity_sets
|
||||
|
||||
def batchGenerator(data, batch_size):
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
x, y = [], []
|
||||
for t in batch:
|
||||
x.extend(t[0])
|
||||
y.extend([t[1]]*len(t[0]))
|
||||
yield x, y
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
entities = json.load(file)
|
||||
all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
train = json.load(file)
|
||||
train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in train['data'].items()}
|
||||
train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in train['data'].items()}
|
||||
|
||||
# %%
|
||||
###############
|
||||
# alternate data import strategy
|
||||
###################################################
|
||||
# import code
|
||||
# import training file
|
||||
data_path = '../esAppMod_data_import/train.csv'
|
||||
df = pd.read_csv(data_path, skipinitialspace=True)
|
||||
# rather than use pattern, we use the real thing and property
|
||||
entity_ids = df['entity_id'].to_list()
|
||||
target_id_list = sorted(list(set(entity_ids)))
|
||||
|
||||
id2label = {}
|
||||
label2id = {}
|
||||
for idx, val in enumerate(target_id_list):
|
||||
id2label[idx] = val
|
||||
label2id[val] = idx
|
||||
|
||||
df["training_id"] = df["entity_id"].map(label2id)
|
||||
|
||||
# %%
|
||||
##############################################################
|
||||
# augmentation code
|
||||
|
||||
# basic preprocessing
|
||||
def preprocess_text(text):
|
||||
# 1. make everything lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def generate_random_shuffles(text, n):
|
||||
words = text.split() # Split the input into words
|
||||
shuffled_variations = []
|
||||
|
||||
for _ in range(n):
|
||||
shuffled = words[:] # Copy the word list to avoid in-place modification
|
||||
random.shuffle(shuffled) # Randomly shuffle the words
|
||||
shuffled_variations.append(" ".join(shuffled)) # Join the words back into a string
|
||||
|
||||
return shuffled_variations
|
||||
|
||||
|
||||
def shuffle_text(text, n_shuffles=SHUFFLES):
|
||||
all_processed = []
|
||||
# add the original text
|
||||
all_processed.append(text)
|
||||
|
||||
# Generate random shuffles
|
||||
shuffled_variations = generate_random_shuffles(text, n_shuffles)
|
||||
all_processed.extend(shuffled_variations)
|
||||
|
||||
return all_processed
|
||||
|
||||
def corrupt_word(word):
|
||||
"""Corrupt a single word using random corruption techniques."""
|
||||
if len(word) <= 1: # Skip corruption for single-character words
|
||||
return word
|
||||
|
||||
corruption_type = random.choice(["delete", "swap"])
|
||||
|
||||
if corruption_type == "delete":
|
||||
# Randomly delete a character
|
||||
idx = random.randint(0, len(word) - 1)
|
||||
word = word[:idx] + word[idx + 1:]
|
||||
|
||||
elif corruption_type == "swap":
|
||||
# Swap two adjacent characters
|
||||
if len(word) > 1:
|
||||
idx = random.randint(0, len(word) - 2)
|
||||
word = (word[:idx] + word[idx + 1] + word[idx] + word[idx + 2:])
|
||||
|
||||
|
||||
return word
|
||||
|
||||
def corrupt_string(sentence, corruption_probability=0.01):
|
||||
"""Corrupt each word in the string with a given probability."""
|
||||
words = sentence.split()
|
||||
corrupted_words = [
|
||||
corrupt_word(word) if random.random() < corruption_probability else word
|
||||
for word in words
|
||||
]
|
||||
return " ".join(corrupted_words)
|
||||
|
||||
|
||||
def create_example(index, mention, entity_name):
|
||||
return {'entity_id': index, 'mention': mention, 'entity_name': entity_name}
|
||||
|
||||
# augment whole dataset
|
||||
def augment_data(df):
|
||||
output_list = []
|
||||
|
||||
for idx,row in df.iterrows():
|
||||
index = row['entity_id']
|
||||
entity_name = row['entity_name']
|
||||
parent_desc = row['mention']
|
||||
parent_desc = preprocess_text(parent_desc)
|
||||
|
||||
# add basic example
|
||||
output_list.append(create_example(index, parent_desc, entity_name))
|
||||
|
||||
# # add shuffled strings
|
||||
# processed_descs = shuffle_text(parent_desc, n_shuffles=SHUFFLES)
|
||||
# for desc in processed_descs:
|
||||
# if (desc != parent_desc):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add corrupted strings
|
||||
desc = corrupt_string(parent_desc, corruption_probability=CORRUPT)
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# add example with stripped non-alphanumerics
|
||||
desc = re.sub(r'[^\w\s]', ' ', parent_desc) # Retains only alphanumeric and spaces
|
||||
if (desc != parent_desc):
|
||||
output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
# # short sequence amplifier
|
||||
# # short sequences are rare, and we must compensate by including more examples
|
||||
# # also, short sequence don't usually get affected by shuffle
|
||||
# words = parent_desc.split()
|
||||
# word_count = len(words)
|
||||
# if word_count <= 2:
|
||||
# for _ in range(AMPLIFY_FACTOR):
|
||||
# output_list.append(create_example(index, desc, entity_name))
|
||||
|
||||
new_df = pd.DataFrame(output_list)
|
||||
return new_df
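# sketch of one augmented row (hypothetical mention): 'MS-SQL Server' becomes the preprocessed
# 'ms-sql server', possibly a randomly corrupted variant of it, and the punctuation-stripped
# 'ms sql server', all sharing the same entity_id and entity_name.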
|
||||
|
||||
# def sample_from_df(df, sample_size_per_class=5):
|
||||
# sampled_df = (df.groupby("entity_id")[['entity_id', 'mention', 'entity_name']] # explicit give column names
|
||||
# .apply(lambda x: x.sample(n=min(sample_size_per_class, len(x))))
|
||||
# .reset_index(drop=True))
|
||||
#
|
||||
# return sampled_df
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
def make_entity_id_mentions(df):
|
||||
entity_id_mentions = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
entity_id_mentions[entity_id] = df[df['entity_id']==entity_id]['mention'].to_list()
|
||||
return entity_id_mentions
|
||||
|
||||
def make_entity_id_name(df):
|
||||
entity_id_name = {}
|
||||
entity_id_list = list(set(df['entity_id']))
|
||||
for entity_id in entity_id_list:
|
||||
# entity_id always matches entity_name, so first value would work
|
||||
entity_id_name[entity_id] = df[df['entity_id']==entity_id]['entity_name'].to_list()[0]
|
||||
return entity_id_name
|
||||
|
||||
# %%
|
||||
# evaluation
|
||||
def run_evaluation(model, tokenizer):
|
||||
def preprocess_text(text):
|
||||
# 1. make everything lowercase
|
||||
text = text.lower()
|
||||
|
||||
# standardize spacing
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
with open('../esAppMod/tca_entities.json', 'r') as file:
|
||||
eval_entities = json.load(file)
|
||||
eval_all_entity_id_name = {entity['entity_id']: entity['entity_name'] for _, entity in eval_entities['data'].items()}
|
||||
|
||||
with open('../esAppMod/train.json', 'r') as file:
|
||||
eval_train = json.load(file)
|
||||
eval_train_entity_id_mentions = {data['entity_id']: data['mentions'] for _, data in eval_train['data'].items()}
|
||||
eval_train_entity_id_name = {data['entity_id']: all_entity_id_name[data['entity_id']] for _, data in eval_train['data'].items()}
|
||||
|
||||
with open('../esAppMod/infer.json', 'r') as file:
|
||||
eval_test = json.load(file)
|
||||
x_test = [preprocess_text(d['mention']) for _, d in eval_test['data'].items()]
|
||||
y_test = [d['entity_id'] for _, d in eval_test['data'].items()]
|
||||
eval_train_entities, eval_labels = list(eval_train_entity_id_name.values()), list(eval_train_entity_id_name.keys())
|
||||
eval_train_entities = [preprocess_text(element) for element in eval_train_entities]
|
||||
|
||||
def batch_list(data, batch_size):
|
||||
"""Yield successive n-sized chunks from data."""
|
||||
for i in range(0, len(data), batch_size):
|
||||
yield data[i:i + batch_size]
|
||||
|
||||
batches = batch_list(eval_train_entities, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls = np.concatenate(embedding_list)
|
||||
|
||||
batches = batch_list(x_test, 64)
|
||||
|
||||
embedding_list = []
|
||||
for batch in batches:
|
||||
inputs = tokenizer(batch, padding=True, return_tensors='pt')
|
||||
outputs = model(
|
||||
input_ids=inputs['input_ids'].to(DEVICE),
|
||||
attention_mask=inputs['attention_mask'].to(DEVICE)
|
||||
)
|
||||
output = outputs.last_hidden_state[:,0,:]
|
||||
output = output.detach().cpu().numpy()
|
||||
embedding_list.append(output)
|
||||
|
||||
cls_test = np.concatenate(embedding_list)
|
||||
|
||||
knn = KNeighborsClassifier(n_neighbors=1, metric='euclidean').fit(cls, eval_labels)
|
||||
|
||||
|
||||
with open("top1_curves/baseline_output.txt", "a") as f:
|
||||
# only compute top-1
|
||||
distances, indices = knn.kneighbors(cls_test, n_neighbors=1)
|
||||
num = 0
|
||||
for a,b in zip(y_test, indices):
|
||||
b = [eval_labels[i] for i in b]
|
||||
if a in b:
|
||||
num += 1
|
||||
print(f'{num / len(y_test)}', file=f)
|
||||
|
||||
|
||||
# %%
|
||||
class CharacterTransformer(nn.Module):
|
||||
def __init__(self, num_chars, d_model=512, nhead=8, num_encoder_layers=6):
|
||||
super(CharacterTransformer, self).__init__()
|
||||
self.char_embedding = nn.Embedding(num_chars, d_model)
|
||||
encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
|
||||
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)
|
||||
|
||||
def forward(self, input):
|
||||
# input: (batch_size, seq_len)
|
||||
embeddings = self.char_embedding(input) # (batch_size, seq_len, d_model)
|
||||
# embeddings = embeddings.permute(1, 0, 2) # (seq_len, batch_size, d_model)
|
||||
output = self.transformer_encoder(embeddings)
|
||||
# output = output.permute(1, 0, 2) # (batch_size, seq_len, d_model)
|
||||
return output
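# shape note: with batch_first=True the module maps (batch, seq_len) character ids to
# (batch, seq_len, d_model); it defines no pooling, so any sentence-level embedding
# (e.g. taking output[:, 0, :]) would be an additional choice made by the caller.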
|
||||
|
||||
class ASCIITokenizer:
|
||||
def __init__(self):
|
||||
# Initialize the tokenizer with ASCII characters.
|
||||
# ASCII characters range from 0 to 127.
|
||||
self.char_to_id = {chr(i): i for i in range(128)}
|
||||
self.id_to_char = {i: chr(i) for i in range(128)}
|
||||
|
||||
def encode(self, text_list):
|
||||
"""Encode a text string into a list of ASCII IDs."""
|
||||
output_list = []
|
||||
for text in text_list:
|
||||
output = [self.char_to_id.get(char, None) for char in text if char in self.char_to_id]
|
||||
output_list.append(output)
|
||||
return output_list
|
||||
|
||||
def decode(self, ids_list):
|
||||
"""Decode a list of ASCII IDs back into a text string."""
|
||||
output_list = []
|
||||
for ids in ids_list:
|
||||
output = ''.join(self.id_to_char.get(id, '') for id in ids if id in self.id_to_char)
|
||||
output_list.append(output)
|
||||
return output_list
|
||||
|
||||
# %%
|
||||
tokenizer = ASCIITokenizer()
|
||||
# Example text
|
||||
text = ["Hello, world!", "Hello, world!"]
|
||||
# Encode the text
|
||||
encoded = tokenizer.encode(text)
|
||||
print("Encoded:", encoded)
|
||||
|
||||
# Decode the encoded IDs
|
||||
decoded = tokenizer.decode(encoded)
|
||||
print("Decoded:", decoded)
|
||||
|
||||
# %%
|
||||
# Example usage
|
||||
model = CharacterTransformer(num_chars=128) # Assuming ASCII characters
|
||||
input = torch.randint(0, 128, (10, 50)) # Example input tensor 10 sequences of 50 characters
|
||||
output = model(input)
|
||||
# %%
|
||||
num_sample_per_class = 10 # samples in each group
|
||||
batch_size = 64 # number of groups, effective batch_size for computing triplet loss = batch_size * num_sample_per_class
|
||||
margin = 2
|
||||
epochs = 200
|
||||
|
||||
# model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
||||
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
||||
# tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)  # the loop below calls optimizer.zero_grad()/step(), so an optimizer must be defined
|
||||
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.9, cooldown=5, verbose=True)
|
||||
|
||||
|
||||
|
||||
model.to(DEVICE)
|
||||
model.train()
|
||||
|
||||
losses = []
|
||||
|
||||
|
||||
|
||||
for epoch in tqdm(range(epochs)):
|
||||
total_loss = 0.0
|
||||
batch_number = 0
|
||||
|
||||
if epoch % 1 == 0:
|
||||
augmented_df = augment_data(df)
|
||||
# sampled_df = sample_from_df(augmented_df, sample_size_per_class=num_sample_per_class)
|
||||
train_entity_id_mentions = make_entity_id_mentions(augmented_df)
|
||||
train_entity_id_name = make_entity_id_name(augmented_df)
|
||||
|
||||
data = generate_train_entity_sets(train_entity_id_mentions, train_entity_id_name, num_sample_per_class-1, anchor=True)
|
||||
random.shuffle(data)
|
||||
for x,y in batchGenerator(data, batch_size):
|
||||
|
||||
# print(len(x), len(y), end='-->')
|
||||
optimizer.zero_grad()
|
||||
|
||||
inputs = tokenizer(x, padding=True, return_tensors='pt')
|
||||
inputs.to(DEVICE)
|
||||
outputs = model(**inputs)
|
||||
cls = outputs.last_hidden_state[:,0,:]
|
||||
# for the first half of training, mine all (easy) triplets
|
||||
y = torch.tensor(y).to(DEVICE)
|
||||
if epoch < epochs / 2:
|
||||
loss, _ = batch_all_triplet_loss(y, cls, margin, squared=False)
|
||||
# for the second half of training, mine only the hardest triplets
|
||||
else:
|
||||
loss = batch_hard_triplet_loss(y, cls, margin, squared=False)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.detach().item()
|
||||
batch_number += 1
|
||||
|
||||
# del x, y, outputs, cls, loss
|
||||
# torch.cuda.empty_cache()
|
||||
epoch_loss = total_loss/batch_number
|
||||
# scheduler.step(epoch_loss)
|
||||
|
||||
# run evaluation on test data
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
run_evaluation(model=model, tokenizer=tokenizer)
|
||||
|
||||
model.train()
|
||||
|
||||
# scheduler.step() # Update the learning rate
|
||||
print(f'epoch loss: {epoch_loss}')
|
||||
# print(f"Epoch {epoch+1}: lr={scheduler.get_last_lr()[0]}")
|
||||
# if epoch == 125:
|
||||
# torch.save(model.state_dict(), './checkpoint/baseline.pt')
|
||||
|
||||
|
||||
# torch.save(model.state_dict(), './checkpoint/baseline.pt')
|
||||
# %%
|