# domain_mapping/analysis/error_analysis.py

# 76 lines
# 1.6 KiB
# Python
# Raw Normal View History

# %%
import pandas as pd
# %%
# import training file
# (skipinitialspace=True drops blanks that follow a delimiter)
data_path = '../data_import/train.csv'
train_df = pd.read_csv(data_path, skipinitialspace=True)
# import test file
data_path = '../data_import/test.csv'
test_df = pd.read_csv(data_path, skipinitialspace=True)
# import entity file
data_path = '../data_import/entity.csv'
entity_df = pd.read_csv(data_path, skipinitialspace=True)
# lookup table: entity 'id' -> entity 'name'
# built directly from the two columns instead of a row-by-row iterrows() loop;
# if an id occurs more than once, the last row wins (same as the loop did)
id2label = dict(zip(entity_df['id'], entity_df['name']))
# %%
# write the training rows, ordered by entity_id, to out.md as a markdown table
train_sorted_df = train_df.sort_values(by='entity_id')
train_sorted_df.to_markdown('out.md')
# %%
# import prediction file
data_path = '../train/class_bert_process/prediction/exports/result.csv'
prediction_df = pd.read_csv(data_path)
# %%
# look up the entity name for every predicted class id;
# an id missing from id2label raises KeyError rather than being skipped
predicted_entity_list = [
    id2label[element] for element in prediction_df['class_prediction']
]
prediction_df['predicted_name'] = predicted_entity_list
# %%
# put the prediction columns next to the test columns
# NOTE(review): the column-wise concat pairs rows by index label, so this
# assumes the prediction file lists rows in the same order as the test file
# -- TODO confirm
new_df = pd.concat([test_df, prediction_df], axis='columns')
# %%
# keep only the rows where the predicted class differs from the true entity
mismatch_mask = new_df['entity_id'] != new_df['class_prediction']
mismatch_df = new_df.loc[mismatch_mask]
# %%
# number of mismatched rows
len(mismatch_df)
# %%
# print the top 10 offending classes
offender_counts = mismatch_df['entity_id'].value_counts()
print(offender_counts[:10])
# %%
# render every mismatched row, ordered by entity_id, as a markdown table
# and print it
sorted_mismatch_df = mismatch_df.sort_values(by='entity_id')
print(sorted_mismatch_df.to_markdown())
# %%
# save the mismatched rows to error.csv
mismatch_df.to_csv('error.csv')
# %%
# let us see the test mentions
select_value = 268
select_mask = mismatch_df['entity_id'] == select_value
mismatch_df[select_mask]
# %%
# let us see the train mentions
# separate names here: reusing select_mask would replace the mismatch_df mask
# with one indexed like train_df, and the next cell would then filter
# mismatch_df with a mask that does not belong to it
train_select_value = 452
train_select_mask = train_df['entity_id'] == train_select_value
train_df[train_select_mask]
# %%
# predicted classes for the mismatched test mentions selected above
# NOTE(review): assumes this cell is meant to inspect the test-side entity
# (select_value), the only selection built from mismatch_df -- confirm
mismatch_df[select_mask]['class_prediction'].to_list()
# %%
# %%