125 lines
3.8 KiB
Python
125 lines
3.8 KiB
Python
# %%
|
|
import torch
|
|
import json
|
|
import random
|
|
import numpy as np
|
|
from transformers import AutoTokenizer
|
|
from transformers import AutoModel
|
|
from loss import batch_all_triplet_loss, batch_hard_triplet_loss
|
|
from sklearn.neighbors import KNeighborsClassifier
|
|
from tqdm import tqdm
|
|
import re
|
|
import gc
|
|
|
|
# %%
|
|
# Step 2: Load the state dictionary
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Backbone used for both the tokenizer and the encoder.
# Alternatives listed by the original author:
#   'bert-base-cased', 'distilbert-base-cased'
MODEL_NAME = 'prajjwal1/bert-small'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

# Fine-tuned checkpoint to evaluate. Other checkpoints listed by the author:
#   './checkpoint/siamese.pt', './checkpoint/siamese_simple.pt'
# map_location remaps the stored tensors onto DEVICE, so a checkpoint that
# contains CUDA tensors can still be deserialized when DEVICE falls back to CPU.
state_dict = torch.load('./checkpoint/classification.pt', map_location=DEVICE)

# Drop every entry whose name mentions 'classifier' and strip the 'bert.'
# substring from the remaining names (presumably so the keys line up with the
# bare AutoModel the dictionary is loaded into later -- verify against the
# training script that produced the checkpoint).
params_dict = {
    name.replace('bert.', ''): param
    for name, param in state_dict.items()
    if 'classifier' not in name
}
|
|
|
|
# %%
|
|
# Step 3: Apply the state dictionary to the model
# Copy the remapped checkpoint weights into the freshly instantiated encoder.
model.load_state_dict(params_dict)

# Place the encoder on the selected device and switch it to evaluation mode;
# the rest of the script only runs forward passes.
model.to(DEVICE).eval()
|
|
|
|
# %%
|
|
def preprocess_text(text):
    """Normalise a string before it is tokenised.

    The text is lower-cased, every run of whitespace is collapsed into a
    single space, and leading/trailing whitespace is removed.

    Args:
        text: The raw string to normalise.

    Returns:
        The normalised string.
    """
    # Lower-case first, then squeeze whitespace runs down to one space.
    squeezed = re.sub(r'\s+', ' ', text.lower())
    # A leading/trailing run has become a single space by now; trim it.
    return squeezed.strip()
|
|
|
|
|
|
|
|
# %%
|
|
# Entity catalogue: entity_id -> entity_name for every record under 'data'.
# The encoding is pinned to UTF-8 so the result does not depend on the
# platform's default locale encoding.
with open('../esAppMod/tca_entities.json', 'r', encoding='utf-8') as file:
    entities = json.load(file)
all_entity_id_name = {
    entity['entity_id']: entity['entity_name']
    for entity in entities['data'].values()
}

# Training split: for every record under 'data', keep its mentions and look
# up its name in the catalogue (a KeyError here means the training file
# references an entity id that the catalogue does not contain).
with open('../esAppMod/train.json', 'r', encoding='utf-8') as file:
    train = json.load(file)
train_entity_id_mentions = {
    data['entity_id']: data['mentions']
    for data in train['data'].values()
}
train_entity_id_name = {
    data['entity_id']: all_entity_id_name[data['entity_id']]
    for data in train['data'].values()
}
|
|
|
|
|
|
|
|
# %%
|
|
# Evaluation split. The encoding is pinned to UTF-8 so the result does not
# depend on the platform's default locale encoding.
with open('../esAppMod/infer.json', 'r', encoding='utf-8') as file:
    test = json.load(file)

# Queries (normalised mention strings) and their gold entity ids, both in the
# iteration order of test['data'].
x_test = [preprocess_text(d['mention']) for d in test['data'].values()]
y_test = [d['entity_id'] for d in test['data'].values()]

# Reference set: the normalised entity names of the training entities.
# Both lists follow the iteration order of train_entity_id_name, so
# labels[i] is the entity id whose name is train_entities[i].
train_entities = [
    preprocess_text(element) for element in train_entity_id_name.values()
]
labels = list(train_entity_id_name)
|
|
|
|
def batch_list(data, batch_size):
    """Lazily split *data* into consecutive slices of at most *batch_size* items.

    Slices are produced in order; the final one is shorter when ``len(data)``
    is not a multiple of ``batch_size``.

    Args:
        data: A sliceable sequence that supports ``len()``.
        batch_size: Step between slice starts and maximum slice length.

    Yields:
        ``data[start:start + batch_size]`` for each start offset.
    """
    yield from (
        data[start:start + batch_size]
        for start in range(0, len(data), batch_size)
    )
|
|
|
|
# Embed the normalised entity names in batches of 64.
batches = batch_list(train_entities, 64)

embedding_list = []
# Inference only: without no_grad() every forward pass records autograd state
# that is never used here and only costs memory.
with torch.no_grad():
    for batch in batches:
        # padding=True pads each batch to the length of its longest member.
        inputs = tokenizer(batch, padding=True, return_tensors='pt')
        outputs = model(
            input_ids=inputs['input_ids'].to(DEVICE),
            attention_mask=inputs['attention_mask'].to(DEVICE),
        )
        # Keep the hidden state at sequence position 0 of every input
        # (the [CLS] token for BERT-style tokenizers).
        output = outputs.last_hidden_state[:, 0, :]
        output = output.detach().cpu().numpy()
        embedding_list.append(output)

# One row per element of train_entities, in the same order.
cls = np.concatenate(embedding_list)

# %%
# Release Python garbage and cached CUDA blocks before the next embedding pass.
gc.collect()
torch.cuda.empty_cache()
|
|
|
|
# %%
|
|
|
|
# Embed the normalised test mentions in batches of 64.
batches = batch_list(x_test, 64)

embedding_list = []
# Inference only: without no_grad() every forward pass records autograd state
# that is never used here and only costs memory.
with torch.no_grad():
    for batch in batches:
        # padding=True pads each batch to the length of its longest member.
        inputs = tokenizer(batch, padding=True, return_tensors='pt')
        outputs = model(
            input_ids=inputs['input_ids'].to(DEVICE),
            attention_mask=inputs['attention_mask'].to(DEVICE),
        )
        # Keep the hidden state at sequence position 0 of every input
        # (the [CLS] token for BERT-style tokenizers).
        output = outputs.last_hidden_state[:, 0, :]
        output = output.detach().cpu().numpy()
        embedding_list.append(output)

# One row per element of x_test, in the same order.
cls_test = np.concatenate(embedding_list)
|
|
|
|
|
|
# %%
|
|
import os

# Index the entity-name embeddings under cosine distance. Only the
# neighbour search is used below; labels[i] is the entity id of row i of cls.
knn = KNeighborsClassifier(n_neighbors=1, metric='cosine').fit(cls, labels)
n_neighbors = [1, 3, 5, 10]

# open(..., "w") does not create missing parent directories; make sure the
# output folder exists so the run does not fail after all embeddings have
# already been computed.
os.makedirs("results", exist_ok=True)

with open("results/output.txt", "w") as f:
    for n in n_neighbors:
        distances, indices = knn.kneighbors(cls_test, n_neighbors=n)
        # A mention counts as a hit when its gold entity id is among the ids
        # of its n nearest entity embeddings.
        num = sum(
            gold_id in [labels[i] for i in neighbour_rows]
            for gold_id, neighbour_rows in zip(y_test, indices)
        )
        print(f'Top-{n:<3} accuracy: {num / len(y_test)}', file=f)
        # NOTE(review): the original indentation of this line was lost; it is
        # assumed to run once per n (smallest and largest neighbour distance
        # for that n) -- confirm against the original notebook.
        print(np.min(distances), np.max(distances), file=f)
|
|
|
|
# %%
|