# Analyze the embeddings produced by the fine-tuned T5 encoder: embed the MDM
# test rows whose labels are among the 10 most frequent, project the embeddings
# to 2D with t-SNE, and scatter-plot them colored by label.

# %%
import pandas as pd
import os
from inference import Embedder_t5
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

checkpoint_directory = 'mapping_t5_complete_desc_unit/checkpoint'

BATCH_SIZE = 512

fold = 1
print(f"Inference for fold {fold}")

# import test data
data_path = f"../data_preprocess/exports/dataset/group_{fold}/test_all.csv"
df = pd.read_csv(data_path, skipinitialspace=True)
# Keep only the rows selected by the 'MDM' column (used as a boolean mask).
df = df[df['MDM']].reset_index(drop=True)

# Training data for the same fold.
# NOTE(review): train_df is not used anywhere else in this file; kept in case a
# later interactive cell relies on it.
data_path = f"../data_preprocess/exports/dataset/group_{fold}/train_all.csv"
train_df = pd.read_csv(data_path, skipinitialspace=True)
# processing to help with selection later
train_df['thing_property'] = train_df['thing'] + " " + train_df['property']

# assign labels: each distinct "thing property" string gets an integer id equal
# to its position in the sorted list of distinct strings
df['thing_property'] = df['thing'] + " " + df['property']
thing_property = df['thing_property'].to_list()
mdm_list = sorted(set(thing_property))


def generate_labels(df, mdm_list):
    """Map each row's ``thing_property`` value to its index in ``mdm_list``.

    Args:
        df: DataFrame with a ``thing_property`` column.
        mdm_list: Sequence of ``thing_property`` strings; a row's label is the
            position of its (stringified) value in this sequence.

    Returns:
        A list with one integer label per row of ``df``, in row order. Rows
        whose value does not appear in ``mdm_list`` get ``-1``, and an error
        message is printed for each of them.
    """
    # Build the lookup once instead of calling list.index() per row, which was
    # O(rows * labels). setdefault keeps the first position for duplicate
    # entries, matching list.index() semantics.
    label_of = {}
    for position, name in enumerate(mdm_list):
        label_of.setdefault(name, position)

    output_list = []
    for value in df['thing_property']:
        index = label_of.get(f"{value}", -1)
        if index == -1:
            print("Error: value not found in MDM list")
        output_list.append(index)
    return output_list


df['labels'] = generate_labels(df, mdm_list)

# rank labels by counts and keep only rows whose label is among the 10 most
# frequent. .head() is positional; `series[0:10]` on an integer index is
# deprecated in pandas and slated to become label-based.
top_10_labels = df['labels'].value_counts().head(10).index.to_list()
top_10_mask = df['labels'].isin(top_10_labels)
indices = df.index[top_10_mask].to_list()
input_df = df[top_10_mask].reset_index(drop=True)


# %%
def run(step, data=None):
    """Embed ``data`` with the encoder checkpoint saved at training ``step``.

    Args:
        step: Training step; the checkpoint is loaded from
            ``<checkpoint_directory>/checkpoint_<step>``.
        data: DataFrame passed to ``Embedder_t5.prepare_dataloader``.
            Defaults to the module-level ``input_df``.

    Returns:
        ``embedder.embeddings`` after ``create_embedding()``. Presumably one
        vector per row of ``data`` (t-SNE below relies on this) — TODO confirm
        against ``inference.Embedder_t5``.
    """
    if data is None:
        data = input_df
    checkpoint_path = os.path.join(checkpoint_directory, f'checkpoint_{step}')
    embedder = Embedder_t5(checkpoint_path)
    embedder.prepare_dataloader(data, batch_size=BATCH_SIZE, max_length=128)
    embedder.create_embedding()
    return embedder.embeddings


# %%
embeddings = run(step=1200)
labels = input_df['labels']

# Reducing dimensions with t-SNE.
# NOTE: scikit-learn requires perplexity to be smaller than the number of
# samples, so input_df must have more than 5 rows.
tsne = TSNE(n_components=2, random_state=0, perplexity=5)
embeddings_2d = tsne.fit_transform(embeddings)

# Create a color map from labels to colors
unique_labels = np.unique(labels)
colors = plt.cm.jet(np.linspace(0, 1, len(unique_labels)))
label_to_color = dict(zip(unique_labels, colors))

# Plotting
plt.figure(figsize=(8, 6))
for label in unique_labels:
    # Plain numpy mask so it indexes the numpy array without pandas alignment.
    idx = (labels == label).to_numpy()
    # Show the human-readable "thing property" next to the integer id;
    # -1 marks rows that were not found in mdm_list.
    name = mdm_list[label] if label >= 0 else "unknown"
    plt.scatter(
        embeddings_2d[idx, 0],
        embeddings_2d[idx, 1],
        color=label_to_color[label],
        label=f"{label}: {name}",
        alpha=0.7,
    )
plt.title('2D t-SNE Visualization of Embeddings')
plt.xlabel('Component 1')
plt.ylabel('Component 2')
# Place the legend outside the axes so long names do not hide points.
plt.legend(title='Group', bbox_to_anchor=(1.02, 1), loc='upper left')
plt.tight_layout()
plt.show()

# %%