hipom_data_mapping/post_process/tfidf_class/2a.classifier_class_tfidf.i...

117 lines
4.4 KiB
Plaintext
Raw Normal View History

2024-08-26 19:51:11 +09:00
{
"cells": [
{
"cell_type": "code",
2024-09-25 08:52:30 +09:00
"execution_count": 4,
2024-08-26 19:51:11 +09:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-09-25 08:52:30 +09:00
"Accuracy (MDM=True) for Group 1: 73.50%\n",
"Accuracy (MDM=True) for Group 2: 78.04%\n",
"Accuracy (MDM=True) for Group 3: 81.73%\n",
"Accuracy (MDM=True) for Group 4: 79.83%\n",
"Accuracy (MDM=True) for Group 5: 81.31%\n",
"Average Accuracy (MDM=True) across all groups: 78.88%\n"
2024-08-26 19:51:11 +09:00
]
}
],
"source": [
"import pandas as pd\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
2024-09-25 08:52:30 +09:00
"from sklearn.metrics import pairwise_distances\n",
2024-08-26 19:51:11 +09:00
"from tqdm import tqdm\n",
"import os\n",
"\n",
"accuracies = []\n",
"\n",
"for group_number in range(1, 6):\n",
"    \n",
"    sdl_class_rdoc_path = f'0.class_document/{group_number}/sdl_class_rdoc.csv'\n",
"    test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
"    \n",
"    if not os.path.exists(test_path):\n",
"        print(f\"test file for Group {group_number} does not exist. Skipping...\")\n",
"        continue\n",
"    \n",
"    sdl_class_rdoc_csv = pd.read_csv(sdl_class_rdoc_path, low_memory=False)\n",
"    test_csv = pd.read_csv(test_path, low_memory=False)\n",
"    \n",
"    sdl_class_rdoc_csv['tag_description'] = sdl_class_rdoc_csv['tag_description'].fillna('')\n",
"    test_csv['tag_description'] = test_csv['tag_description'].fillna('')\n",
"    \n",
"    # Pre-create the prediction columns so they keep a fixed position in the output CSV.\n",
"    test_csv['c_thing'] = ''\n",
"    test_csv['c_property'] = ''\n",
"    test_csv['c_score'] = ''\n",
"    test_csv['c_duplicate'] = 0\n",
"    \n",
"    # Fit vocabulary/IDF on class documents and test descriptions together.\n",
"    combined_tag_descriptions = sdl_class_rdoc_csv['tag_description'].tolist() + test_csv['tag_description'].tolist()\n",
"    \n",
"    vectorizer = TfidfVectorizer(\n",
"        use_idf=True,\n",
"        token_pattern=r'\\S+',\n",
"        ngram_range=(1, 1),\n",
"    )\n",
"    \n",
"    vectorizer.fit(combined_tag_descriptions)\n",
"    \n",
"    sdl_class_rdoc_tfidf_matrix = vectorizer.transform(sdl_class_rdoc_csv['tag_description'])\n",
"    test_tfidf_matrix = vectorizer.transform(test_csv['tag_description'])\n",
"    \n",
"    # For each test row, pick the class document with the smallest cosine distance;\n",
"    # the reported score is the corresponding cosine similarity (1 - distance).\n",
"    distance_matrix = pairwise_distances(test_tfidf_matrix, sdl_class_rdoc_tfidf_matrix, metric='cosine')\n",
"    \n",
"    most_similar_indices = distance_matrix.argmin(axis=1)\n",
"    most_similar_scores = 1 - distance_matrix.min(axis=1)\n",
"    \n",
"    test_csv['c_thing'] = sdl_class_rdoc_csv.iloc[most_similar_indices]['thing'].values\n",
"    test_csv['c_property'] = sdl_class_rdoc_csv.iloc[most_similar_indices]['property'].values\n",
"    test_csv['c_score'] = most_similar_scores\n",
"    \n",
"    test_csv['cthing_correct'] = test_csv['thing'] == test_csv['c_thing']\n",
"    test_csv['cproperty_correct'] = test_csv['property'] == test_csv['c_property']\n",
"    test_csv['ctp_correct'] = test_csv['cthing_correct'] & test_csv['cproperty_correct']\n",
"    \n",
"    # Restrict numerator AND denominator to MDM=True rows so the metric matches its label;\n",
"    # correct predictions on non-MDM rows must not inflate the accuracy.\n",
"    mdm_mask = test_csv['MDM'] == True\n",
"    mdm_true_count = int(mdm_mask.sum())\n",
"    if mdm_true_count > 0:\n",
"        accuracy = (test_csv['ctp_correct'] & mdm_mask).sum() / mdm_true_count * 100\n",
"        accuracies.append(accuracy)\n",
"        print(f\"Accuracy (MDM=True) for Group {group_number}: {accuracy:.2f}%\")\n",
"    else:\n",
"        # Skip the metric rather than appending nan/inf, which would poison the average.\n",
"        print(f\"No MDM=True rows for Group {group_number}. Accuracy not computed.\")\n",
"    \n",
"    output_path = f'0.class_document/{group_number}/test_p_c.csv'\n",
"    test_csv.to_csv(output_path, index=False, encoding='utf-8-sig')\n",
"    \n",
"    # MDM=True rows whose predicted (thing, property) pair is wrong.\n",
"    false_positive_rows = test_csv[mdm_mask & ~test_csv['ctp_correct']]\n",
"    \n",
"    fp_output_path = f'0.class_document/{group_number}/fp_class.csv'\n",
"    false_positive_rows.to_csv(fp_output_path, index=False, encoding='utf-8-sig')\n",
"\n",
"# Unweighted mean of per-group accuracies; guard against every group being skipped.\n",
"if accuracies:\n",
"    average_accuracy = sum(accuracies) / len(accuracies)\n",
"    print(f\"Average Accuracy (MDM=True) across all groups: {average_accuracy:.2f}%\")\n",
"else:\n",
"    print(\"No groups were evaluated; average accuracy is undefined.\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}