hipom_data_mapping/evaluation/check_accuracy.ipynb

142 lines
5.3 KiB
Plaintext
Raw Normal View History

2024-08-26 19:51:11 +09:00
{
"cells": [
{
"cell_type": "code",
2024-09-25 08:52:30 +09:00
"execution_count": 6,
2024-08-26 19:51:11 +09:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-09-25 08:52:30 +09:00
"Performance for group 1 (test_s.csv):\n",
"TP: 1794, TN: 9954, FP: 1005, FN: 319\n",
"Precision: 0.6409, Recall: 0.8490, Accuracy: 0.8987, F1-Score: 0.7305\n",
"--------------------------------------------------\n",
"Performance for group 2 (test_s.csv):\n",
"TP: 1824, TN: 7716, FP: 866, FN: 316\n",
"Precision: 0.6781, Recall: 0.8523, Accuracy: 0.8898, F1-Score: 0.7553\n",
"--------------------------------------------------\n",
"Performance for group 3 (test_s.csv):\n",
"TP: 1804, TN: 6866, FP: 996, FN: 188\n",
"Precision: 0.6443, Recall: 0.9056, Accuracy: 0.8798, F1-Score: 0.7529\n",
"--------------------------------------------------\n",
"Performance for group 4 (test_s.csv):\n",
"TP: 1916, TN: 12360, FP: 989, FN: 186\n",
"Precision: 0.6596, Recall: 0.9115, Accuracy: 0.9240, F1-Score: 0.7653\n",
"--------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_2997916/1903646223.py:38: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise an error in a future version of pandas. Value '' has dtype incompatible with float64, please explicitly cast to a compatible dtype first.\n",
" test_s_csv.fillna('', inplace=True)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Performance for group 5 (test_s.csv):\n",
"TP: 1910, TN: 9800, FP: 955, FN: 273\n",
"Precision: 0.6667, Recall: 0.8749, Accuracy: 0.9051, F1-Score: 0.7567\n",
"--------------------------------------------------\n",
"Average performance across all groups:\n",
"Average Precision: 0.6579\n",
"Average Recall: 0.8787\n",
"Average Accuracy: 0.8995\n",
"Average F1-Score: 0.7521\n"
2024-08-26 19:51:11 +09:00
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"def evaluate_performance(test_csv):\n",
"    \"\"\"Count TP/TN/FP/FN for one test split and derive summary metrics.\n",
"\n",
"    Every row falls into exactly one bucket:\n",
"\n",
"    - TP: `s_correct` is truthy and `MDM` is truthy.\n",
"    - TN: `s_thing` is the empty string and `MDM` is falsy.\n",
"    - FP: `s_thing` is non-empty and `MDM` is falsy.\n",
"    - FN: `MDM` is truthy and `s_correct` is falsy (whether or not\n",
"      `s_thing` is empty).\n",
"\n",
"    Truthiness is Python's `bool()` applied to each value, so empty strings\n",
"    (e.g. cells blanked by `fillna('')`) count as False.\n",
"\n",
"    Args:\n",
"        test_csv: DataFrame with `s_correct`, `s_thing` and `MDM` columns.\n",
"\n",
"    Returns:\n",
"        Tuple `(TP, TN, FP, FN, precision, recall, accuracy, f1_score)`.\n",
"        Counts are ints; each ratio is 0 when its denominator is 0.\n",
"    \"\"\"\n",
"    # Vectorized masks instead of iterrows(): same classification without\n",
"    # materializing a Series per row. map(bool) reproduces `if value:`\n",
"    # truthiness exactly; astype(bool) keeps the masks boolean so `~` is a\n",
"    # logical NOT even when the frame is empty.\n",
"    mdm = test_csv['MDM'].map(bool).astype(bool)\n",
"    correct = test_csv['s_correct'].map(bool).astype(bool)\n",
"    thing_empty = (test_csv['s_thing'] == '').astype(bool)\n",
"\n",
"    TP = int((correct & mdm).sum())\n",
"    TN = int((thing_empty & ~mdm).sum())\n",
"    FP = int((~thing_empty & ~mdm).sum())\n",
"    FN = int((~correct & mdm).sum())\n",
"\n",
"    total = TP + TN + FP + FN\n",
"\n",
"    precision = TP / (TP + FP) if (TP + FP) > 0 else 0\n",
"    recall = TP / (TP + FN) if (TP + FN) > 0 else 0\n",
"    accuracy = (TP + TN) / total if total > 0 else 0\n",
"    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n",
"\n",
"    return TP, TN, FP, FN, precision, recall, accuracy, f1_score\n",
"\n",
"# Evaluate every fold (groups 1-5), print its metrics, then report the\n",
"# unweighted mean of each metric across folds.\n",
"RESULT_DIR = '../post_process/0.result'\n",
"GROUP_NUMBERS = range(1, 6)\n",
"\n",
"# Per-fold metrics, appended in group order\n",
"all_precisions = []\n",
"all_recalls = []\n",
"all_accuracies = []\n",
"all_f1_scores = []\n",
"\n",
"for group_number in GROUP_NUMBERS:\n",
"    test_s_path = f'{RESULT_DIR}/{group_number}/test_s.csv'\n",
"    # Non-inplace fillna: the in-place form emits a pandas FutureWarning when\n",
"    # it must upcast float64 columns (containing NaN) to hold '', and that\n",
"    # upcast is slated to become an error. Returning a new frame avoids it.\n",
"    test_s_csv = pd.read_csv(test_s_path, low_memory=False).fillna('')\n",
"\n",
"    tp, tn, fp, fn, precision, recall, accuracy, f1 = evaluate_performance(test_s_csv)\n",
"\n",
"    print(f\"Performance for group {group_number} (test_s.csv):\")\n",
"    print(f\"TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}\")\n",
"    print(f\"Precision: {precision:.4f}, Recall: {recall:.4f}, Accuracy: {accuracy:.4f}, F1-Score: {f1:.4f}\")\n",
"    print(\"-\" * 50)\n",
"\n",
"    all_precisions.append(precision)\n",
"    all_recalls.append(recall)\n",
"    all_accuracies.append(accuracy)\n",
"    all_f1_scores.append(f1)\n",
"\n",
"# Macro-average: each group counts equally regardless of its row count\n",
"average_precision = sum(all_precisions) / len(all_precisions)\n",
"average_recall = sum(all_recalls) / len(all_recalls)\n",
"average_accuracy = sum(all_accuracies) / len(all_accuracies)\n",
"average_f1_score = sum(all_f1_scores) / len(all_f1_scores)\n",
"\n",
"print(\"Average performance across all groups:\")\n",
"print(f\"Average Precision: {average_precision:.4f}\")\n",
"print(f\"Average Recall: {average_recall:.4f}\")\n",
"print(f\"Average Accuracy: {average_accuracy:.4f}\")\n",
"print(f\"Average F1-Score: {average_f1_score:.4f}\")\n"
2024-08-26 19:51:11 +09:00
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}