81 lines
2.1 KiB
Python
81 lines
2.1 KiB
Python
# %%
|
|
import json
|
|
import pandas as pd
|
|
from utils import Retriever, cosine_similarity_chunked
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
##########################################
|
|
# %%
|
|
|
|
# Load the JSON file
data_path = '../esAppMod/tca_entities.json'
# Pin the encoding: without it open() falls back to the platform's locale
# default, which is not guaranteed to be UTF-8.
with open(data_path, 'r', encoding='utf-8') as file:
    data = json.load(file)

# Initialize an empty list to store the rows
rows = []
|
|
|
|
# %%
|
|
# Build one {"id", "name"} record per entity in the JSON.
# Only the values of data["data"] are needed, so iterate them directly.
# `rows` is rebuilt from scratch so that re-running this cell cannot append
# the same entities a second time.
rows = [
    {"id": entity_data['entity_id'], "name": entity_data['entity_name']}
    for entity_data in data["data"].values()
]

# Create a DataFrame from the rows; naming the columns keeps `id` and `name`
# present even when there are no entities at all.
df = pd.DataFrame(rows, columns=["id", "name"])
|
|
|
|
|
|
# %%
|
|
# df.to_csv('entity.csv', index=False)
|
|
|
|
|
|
# %%
|
|
# we want to automatically identify clusters
|
|
class Embedder():
    """Produce embeddings for the values in a DataFrame's ``name`` column."""

    input_df: pd.DataFrame
    # Annotation only: nothing in this class assigns or reads `fold`.
    fold: int

    def __init__(self, input_df: pd.DataFrame):
        """Store the frame whose ``name`` column will be embedded."""
        self.input_df = input_df

    def make_embedding(self, checkpoint_path: str):
        """Embed every entry of ``input_df['name']`` using the given checkpoint.

        Args:
            checkpoint_path: Model checkpoint, handed unchanged to ``Retriever``.

        Returns:
            The ``embeddings`` attribute of the ``Retriever`` after
            ``.to('cpu')`` has been applied to it.

        Raises:
            KeyError: If ``input_df`` has no ``name`` column.
        """
        # Read the column directly instead of walking the frame with
        # iterrows(): same values in the same order, without building a
        # Series object for every row.
        names = self.input_df['name'].tolist()

        # prepare reference embed
        retriever_train = Retriever(names, checkpoint_path)
        retriever_train.make_embedding(batch_size=64)
        return retriever_train.embeddings.to('cpu')
|
|
|
|
# Checkpoint handed to Embedder.make_embedding.
# Alternative checkpoint (disabled): 'google-bert/bert-base-cased'
model_checkpoint = '../train/class_bert_simple/checkpoint/checkpoint-4500'

# Embed the `name` column of the entity frame built above.
embedder = Embedder(df)
embeddings = embedder.make_embedding(checkpoint_path=model_checkpoint)
|
|
|
|
# %%
|
|
# Pairwise cosine similarity between all rows of `embeddings`
# (sklearn.metrics.pairwise.cosine_similarity called with a single input).
similarity_matrix = cosine_similarity(embeddings)
|
|
|
|
# %%
|
|
# Bare expression for interactive inspection: the value is only shown when
# the cell is run interactively; as a plain script it has no visible effect.
similarity_matrix.shape
|
|
|
|
# %%
|
|
from sklearn.cluster import AgglomerativeClustering

# Average-linkage agglomerative clustering on a precomputed distance matrix,
# using distance = 1 - similarity.
# NOTE(review): neither n_clusters nor distance_threshold is passed, so the
# estimator's defaults decide the number of clusters — confirm that this is
# the intended behaviour.
clustering = AgglomerativeClustering(
    linkage='average',
    metric='precomputed',
).fit(1 - similarity_matrix)

# Cluster assignments
print(clustering.labels_)
|
|
# %%
|