73 lines
1.9 KiB
Python
73 lines
1.9 KiB
Python
# %%
|
|
import pandas as pd
|
|
|
|
# %%
|
|
# Load the three CSV inputs and build the entity-id -> entity-name lookup.
# skipinitialspace=True makes pandas skip spaces that follow each delimiter.

# import training file
# (alternative input: '../esAppMod_data_import/parent_train.csv')
data_path = '../esAppMod_data_import/train.csv'
train_df = pd.read_csv(data_path, skipinitialspace=True)

# import test file
# (alternative input: '../esAppMod_data_import/parent_test.csv')
data_path = '../esAppMod_data_import/test.csv'
test_df = pd.read_csv(data_path, skipinitialspace=True)

# import entity file
data_path = '../esAppMod_data_import/entity.csv'
entity_df = pd.read_csv(data_path, skipinitialspace=True)

# Map each entity 'id' to its 'name'. Zipping the two columns keeps every
# column's own dtype; DataFrame.iterrows() builds one Series per row and may
# upcast the id while doing so. A repeated id keeps its last name.
id2label = dict(zip(entity_df['id'], entity_df['name']))

# Write the training rows, ordered by entity_id, to out.md as a markdown table.
train_df.sort_values(by=['entity_id']).to_markdown('out.md')
|
|
|
|
# %%
|
|
# Read the exported predictions file.
data_path = '../train/class_bert_augmentation/prediction/exports/result.csv'
prediction_df = pd.read_csv(data_path)

# Look up the entity name for every predicted class id. An id that is absent
# from id2label raises KeyError.
predicted_entity_list = [
    id2label[class_id] for class_id in prediction_df['class_prediction']
]
prediction_df['predicted_name'] = predicted_entity_list

# Put the test rows and the predictions side by side. concat(axis=1) pairs
# rows by index label.
# NOTE(review): assumes both frames share the same index and row order - confirm.
new_df = pd.concat([test_df, prediction_df], axis=1)

# Keep only the rows where entity_id and class_prediction disagree.
mismatch_mask = new_df['entity_id'].ne(new_df['class_prediction'])
mismatch_df = new_df.loc[mismatch_mask]

# Cell output: how many rows disagree.
len(mismatch_df)
|
|
|
|
# %%
|
|
# Print the ten entity_id values that occur most often among the mismatched
# rows (value_counts() orders by descending count).
#
# To leave entity ids 434, 451 and 452 out of the ranking, build masked_df as
#   mismatch_df[~mismatch_df['entity_id'].isin([434, 451, 452])]
# and run value_counts() on that frame instead.
print(mismatch_df['entity_id'].value_counts().head(10))

# No ids are excluded right now, so the "masked" frame is the full mismatch set.
masked_df = mismatch_df
|
|
|
|
|
|
# %%
|
|
# Order the mismatched rows by entity_id, render the whole frame as one
# markdown table string, and print it.
ordered_mismatches = masked_df.sort_values(by='entity_id')
print(ordered_mismatches.to_markdown())
|
|
|
|
# %%
|
|
# Write the mismatched rows to error.csv (path is relative to the working
# directory).
mismatch_df.to_csv(path_or_buf='error.csv')
|
|
|
|
# %%
|
|
# Show the mismatched rows whose entity_id equals the chosen value.
select_value = 268
select_mask = mismatch_df['entity_id'].eq(select_value)

# Cell output: the selected rows.
mismatch_df.loc[select_mask]
|
|
|
|
# %%
|
|
# Show the training rows whose entity_id equals the chosen value.
select_value = 130
select_mask = train_df['entity_id'].eq(select_value)

# Cell output: the selected rows.
train_df.loc[select_mask]
|
|
|
|
|
|
|