57 lines
1.6 KiB
Python
57 lines
1.6 KiB
Python
# %%
# Standard-library imports first, then third-party (PEP 8 grouping).
import json

import pandas as pd

# %%
# Predictions from the run trained WITHOUT data augmentation.
# The file has no header row: column 0 is labelled 'actual' (ground truth)
# and column 1 'predicted' (the model's output).
data_path = '../loss_comparisons_without_augmentation/results/predictions.txt'
df = pd.read_csv(data_path, header=None).rename(
    columns={0: 'actual', 1: 'predicted'}
)

# %%
# Build entity lookup tables from the esAppMod JSON files.
# Both files are expected to hold a top-level 'data' object whose values are
# records with an 'entity_id' key (entity records also carry 'entity_name';
# training records carry 'mentions'). JSON is UTF-8, so say so explicitly
# rather than relying on the platform's default text encoding.
with open('../esAppMod/tca_entities.json', 'r', encoding='utf-8') as file:
    entities = json.load(file)

# entity_id -> entity_name for every known entity. Only the values of
# entities['data'] are needed, so iterate them directly.
all_entity_id_name = {
    entity['entity_id']: entity['entity_name']
    for entity in entities['data'].values()
}

with open('../esAppMod/train.json', 'r', encoding='utf-8') as file:
    train = json.load(file)

# entity_id -> mentions for entities present in the training split.
# NOTE(review): if several training records share an entity_id, only the last
# record's mentions are kept — confirm ids are unique in train.json.
train_entity_id_mentions = {
    record['entity_id']: record['mentions']
    for record in train['data'].values()
}

# entity_id -> entity_name restricted to training entities; raises KeyError
# if train.json references an id missing from tca_entities.json.
train_entity_id_name = {
    record['entity_id']: all_entity_id_name[record['entity_id']]
    for record in train['data'].values()
}

# %%
# Attach the human-readable name for each predicted entity id; ids that are
# not keys of all_entity_id_name become NaN.
df['predicted_name'] = df.predicted.map(all_entity_id_name)

# %%
# Load the test split. Later cells join it to the prediction files by row
# position. To use the alternative parent_test.csv instead, swap which of the
# two data_path assignments is commented out.
data_path = '../esAppMod_data_import/test.csv'
# data_path = '../esAppMod_data_import/parent_test.csv'
test_df = pd.read_csv(filepath_or_buffer=data_path)

# %%
# Place test rows and baseline predictions side by side. pd.concat(axis=1)
# aligns on the index, so a row-count mismatch would silently pad with NaN
# and pair rows with the wrong predictions — fail loudly instead.
if len(test_df) != len(df):
    raise ValueError(
        f'test set has {len(test_df)} rows but baseline predictions have '
        f'{len(df)}'
    )
df_out = pd.concat([test_df, df], axis=1)

# %%
# Boolean mask: True where the baseline prediction differs from the label.
mask1 = df['actual'] != df['predicted']

# %%
# Print every baseline misprediction (with its test-set columns) as a
# Markdown table ordered by entity_id. DataFrame.to_markdown requires the
# optional 'tabulate' package.
print(
    df_out[mask1]
    .sort_values('entity_id')
    .to_markdown()
)

# %%
# Predictions from the run trained WITH data augmentation, loaded the same
# way as the baseline file (no header; columns 'actual', 'predicted').
data_path = '../loss_comparisons_with_augmentations/results/predictions.txt'
df2 = pd.read_csv(data_path, header=None).rename(
    columns={0: 'actual', 1: 'predicted'}
)
# Boolean mask: True where the augmented run's prediction differs from the
# label.
mask2 = df2['predicted'] != df2['actual']

# %%
# Find regressions introduced by augmentation: rows the baseline run got
# right (mask1 False) but the augmented run gets wrong (mask2 True).
# Both the mask combination and the concat below align on the index, so all
# three frames must have the same number of rows for the result to be valid.
if not len(test_df) == len(df) == len(df2):
    raise ValueError(
        f'row counts differ: test={len(test_df)}, baseline={len(df)}, '
        f'augmented={len(df2)}'
    )
mask_left = ~mask1 & mask2

# Series.map keeps the source name 'predicted'; rename it (matching the
# baseline cell's 'predicted_name' column) so the joined frame does not end
# up with two columns both called 'predicted'.
predicted_entity = (
    df2['predicted'].map(all_entity_id_name).rename('predicted_name')
)

df_out = pd.concat([test_df, df2, predicted_entity], axis=1)
print(df_out[mask_left].sort_values(by=['entity_id']).to_markdown())

# %%