# domain_mapping/analysis/error_analysis_baseline.py
#
# 57 lines
# 1.6 KiB
# Python

# %%
import pandas as pd
import json
# %%
# Baseline run (trained without augmentation). The file has no header row;
# column 0 is the actual entity id and column 1 the predicted entity id.
data_path = '../loss_comparisons_without_augmentation/results/predictions.txt'
df = pd.read_csv(data_path, header=None).rename(
    columns={0: 'actual', 1: 'predicted'}
)
# %%
# Entity catalogue: entity_id -> entity_name for every known entity.
# JSON is UTF-8 by spec; pass it explicitly so non-ASCII names decode the
# same way regardless of the platform's default locale encoding.
with open('../esAppMod/tca_entities.json', encoding='utf-8') as file:
    entities = json.load(file)
all_entity_id_name = {
    entity['entity_id']: entity['entity_name']
    for entity in entities['data'].values()
}

# Training split: mentions and catalogue names of the entities it contains.
# NOTE(review): if several train records share an entity_id, only the last
# record's mentions survive in this dict — confirm ids are unique in train.json.
with open('../esAppMod/train.json', encoding='utf-8') as file:
    train = json.load(file)
train_entity_id_mentions = {
    record['entity_id']: record['mentions'] for record in train['data'].values()
}
# Raises KeyError if a train entity_id is missing from the catalogue.
train_entity_id_name = {
    record['entity_id']: all_entity_id_name[record['entity_id']]
    for record in train['data'].values()
}
# %%
# Attach a readable name to each baseline prediction; ids absent from the
# entity catalogue map to NaN.
df['predicted_name'] = df['predicted'].map(all_entity_id_name)
# %%
# Load the test split whose rows the prediction files correspond to.
data_path = '../esAppMod_data_import/test.csv'
# Alternative test file (uncomment to use it instead):
# data_path = '../esAppMod_data_import/parent_test.csv'
test_df = pd.read_csv(data_path)
# %%
# The prediction file has no row key, so rows are joined to the test set purely
# by position (both frames carry a default RangeIndex). A length mismatch would
# silently pad with NaN and attach predictions to the wrong test rows.
if len(test_df) != len(df):
    raise ValueError(
        f'{data_path} has {len(test_df)} rows but the baseline predictions '
        f'have {len(df)}; they cannot be aligned row-by-row.'
    )
df_out = pd.concat([test_df, df], axis=1)
# %%
# Row mask: True where the baseline's predicted id differs from the actual id,
# i.e. the rows the baseline got wrong.
mask1 = df['predicted'].ne(df['actual'])
# %%
# Print those rows (with their test-set columns) as a markdown table, ordered
# by the test set's entity_id column.
print(df_out.loc[mask1].sort_values('entity_id').to_markdown())
# %%
# Augmented run: same header-less (actual, predicted) layout as the baseline.
data_path = '../loss_comparisons_with_augmentations/results/predictions.txt'
df2 = (
    pd.read_csv(data_path, header=None)
    .rename(columns={0: 'actual', 1: 'predicted'})
)
# Row mask: True where the augmented model predicted the wrong entity.
mask2 = df2['predicted'].ne(df2['actual'])
# %%
# Regressions introduced by augmentation: rows the baseline got right
# (~mask1, since mask1 marks baseline errors) that the augmented model gets
# wrong (mask2).
# The two masks are combined row-by-row, so both prediction files must list the
# same gold ids in the same order; otherwise the comparison is meaningless.
if not df2['actual'].equals(df['actual']):
    raise ValueError(
        'baseline and augmented predictions do not list the same actual ids '
        'in the same order; their rows cannot be compared one-to-one.'
    )
mask_left = ~mask1 & mask2
# Series.map keeps the source name ('predicted'); rename it so the output does
# not end up with two columns both called 'predicted' (ids and names).
predicted_entity = (
    df2['predicted'].map(all_entity_id_name).rename('predicted_name')
)
df_out = pd.concat([test_df, df2, predicted_entity], axis=1)
print(df_out[mask_left].sort_values(by=['entity_id']).to_markdown())
# %%