# domain_mapping/analysis/bert_label_clustering.py

# %%
import json
import pandas as pd
from utils import Retriever, cosine_similarity_chunked
from sklearn.metrics.pairwise import cosine_similarity
##########################################
# %%
# Load the TCA entity dump. Read it explicitly as UTF-8 (the JSON encoding)
# so non-ASCII entity names decode the same way regardless of the platform's
# default locale.
data_path = '../esAppMod/tca_entities.json'
with open(data_path, 'r', encoding='utf-8') as file:
    data = json.load(file)
# %%
# Build one row per entity from the values of data["data"]; the mapping keys
# are not needed. Each value must provide 'entity_id' and 'entity_name'.
rows = [
    {"id": entity_data['entity_id'], "name": entity_data['entity_name']}
    for entity_data in data["data"].values()
]
# Create a DataFrame from the rows. Passing `columns` keeps the "id"/"name"
# columns even when there are no entities, so later df['name'] lookups do
# not raise KeyError on an empty dump.
df = pd.DataFrame(rows, columns=["id", "name"])
# %%
# df.to_csv('entity.csv', index=False)
# %%
# Goal: automatically identify clusters of similar entity names.
class Embedder:
    """Embed one text column of a DataFrame using a model checkpoint.

    The embedding work is delegated to ``Retriever`` (imported from the local
    ``utils`` module); this class only prepares the list of input strings.
    """

    input_df: pd.DataFrame

    def __init__(self, input_df):
        """Store the DataFrame whose text column will be embedded."""
        self.input_df = input_df

    def make_embedding(self, checkpoint_path, batch_size=64, column='name'):
        """Embed every value of ``column`` and return the result on the CPU.

        Args:
            checkpoint_path: model checkpoint passed through to ``Retriever``.
            batch_size: forwarded to ``Retriever.make_embedding`` (default 64,
                the value that was previously hard-coded).
            column: DataFrame column holding the text to embed (default
                ``'name'``).

        Returns:
            ``Retriever.embeddings`` moved with ``.to('cpu')`` -- presumably a
            torch tensor with one row per input string; TODO confirm against
            ``utils.Retriever``.
        """
        # The column as a plain Python list, in DataFrame row order
        # (vectorised; avoids building a Series per row via iterrows).
        train_data = self.input_df[column].tolist()
        retriever_train = Retriever(train_data, checkpoint_path)
        retriever_train.make_embedding(batch_size=batch_size)
        return retriever_train.embeddings.to('cpu')
# Fine-tuned checkpoint used to embed the entity names. To compare against
# the untuned base model, use 'google-bert/bert-base-cased' instead.
model_checkpoint = '../train/class_bert_simple/checkpoint/checkpoint-4500'
embedder = Embedder(df)
embeddings = embedder.make_embedding(checkpoint_path=model_checkpoint)
# %%
# Pairwise cosine similarity between all embeddings: with no second argument,
# sklearn returns a square (n_embeddings x n_embeddings) matrix.
similarity_matrix = cosine_similarity(X=embeddings)
# %%
# Notebook-style cell: shows the matrix shape when run interactively and has
# no effect when the file is executed as a plain script.
similarity_matrix.shape
# %%
from sklearn.cluster import AgglomerativeClustering

# Clustering settings. AgglomerativeClustering defaults to n_clusters=2, which
# the original call relied on implicitly; it is spelled out here so the choice
# is visible. To let the dendrogram decide how many clusters there are, set
# DISTANCE_THRESHOLD to a cosine-distance cut-off (to be tuned). sklearn then
# requires n_clusters=None, which is handled below.
N_CLUSTERS = 2
DISTANCE_THRESHOLD = None

# Cosine distance = 1 - cosine similarity. Floating-point rounding can push a
# similarity slightly above 1, giving tiny negative "distances"; clip them to
# 0 so the precomputed matrix is a valid (non-negative) distance matrix.
distance_matrix = (1 - similarity_matrix).clip(min=0.0)
clustering = AgglomerativeClustering(
    n_clusters=None if DISTANCE_THRESHOLD is not None else N_CLUSTERS,
    distance_threshold=DISTANCE_THRESHOLD,
    metric='precomputed',
    linkage='average',
)
clustering.fit(distance_matrix)
print(clustering.labels_)  # Cluster label for each row of the distance matrix
# %%