hipom_data_mapping/post_process/tfidf_class/2a.classifier_knn_bow.ipynb

117 lines
17 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average Accuracy (MDM=True) across all groups with n_neighbors=1: 84.43%\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, name, value)\u001b[0m\n\u001b[1;32m 6310\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6311\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 6312\u001b[0;31m \u001b[0mobject\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getattribute__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6313\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mobject\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__setattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAttributeError\u001b[0m: 'Series' object has no attribute '_name'",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_89094/2696322053.py\u001b[0m in \u001b[0;36m?\u001b[0;34m()\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0mdistances\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mknn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkneighbors\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_bow_matrix\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0mpredicted_things\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtrain_all_csv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'thing'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_csv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0mpredicted_properties\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtrain_all_csv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'property'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_csv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 37\u001b[0m \u001b[0mpredicted_scores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mdistances\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_csv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mtest_csv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'c_thing'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_csv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'c_property'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_csv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'c_score'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredicted_things\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpredicted_properties\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpredicted_scores\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/tmp/ipykernel_89094/2696322053.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(.0)\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mpandas\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 1187\u001b[0m \u001b[0maxis\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maxis\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1189\u001b[0m \u001b[0mmaybe_callable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_if_callable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[0mmaybe_callable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_deprecated_callable_usage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmaybe_callable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1191\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getitem_axis\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmaybe_callable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, key, axis)\u001b[0m\n\u001b[1;32m 1750\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1751\u001b[0m \u001b[0;31m# validate the location\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1752\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_integer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1753\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1754\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_ixs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/frame.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, i, axis)\u001b[0m\n\u001b[1;32m 3996\u001b[0m \u001b[0mnew_mgr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_mgr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfast_xs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3997\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3998\u001b[0m \u001b[0;31m# if we are a copy, mark as such\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3999\u001b[0m \u001b[0mcopy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_mgr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnew_mgr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4000\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_constructor_sliced_from_mgr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_mgr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnew_mgr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maxes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4001\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4002\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__finalize__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4003\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_set_is_copy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/frame.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, mgr, axes)\u001b[0m\n\u001b[1;32m 678\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_constructor_sliced_from_mgr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmgr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxes\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mSeries\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 679\u001b[0m \u001b[0mser\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSeries\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_from_mgr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmgr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 680\u001b[0;31m \u001b[0mser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;31m# caller is responsible for setting real name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 681\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 682\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mDataFrame\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 683\u001b[0m \u001b[0;31m# This would also work `if self._constructor_sliced is Series`, but\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, name, value)\u001b[0m\n\u001b[1;32m 6308\u001b[0m \u001b[0;31m# e.g. ``obj.x`` and ``obj.x = 4`` will always reference/modify\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6309\u001b[0m \u001b[0;31m# the same attribute.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6310\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6311\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 6312\u001b[0;31m \u001b[0mobject\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getattribute__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6313\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mobject\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__setattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6314\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6315\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"import pandas as pd\n",
"from sklearn.feature_extraction.text import CountVectorizer\n",
"from sklearn.neighbors import NearestNeighbors\n",
"import os\n",
"\n",
"average_accuracies = []\n",
"\n",
"for n in range(1, 53):\n",
" accuracies = []\n",
" for group_number in range(1, 6):\n",
" train_all_path = f'../../data_preprocess/dataset/{group_number}/train_all.csv'\n",
" test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
"\n",
" if not os.path.exists(test_path):\n",
" print(f\"Test file for Group {group_number} does not exist. Skipping...\")\n",
" continue\n",
"\n",
" train_all_csv = pd.read_csv(train_all_path, low_memory=False)\n",
" test_csv = pd.read_csv(test_path, low_memory=False)\n",
"\n",
" train_all_csv['tag_description'] = train_all_csv['tag_description'].fillna('')\n",
" test_csv['tag_description'] = test_csv['tag_description'].fillna('')\n",
"\n",
" test_csv['c_thing'], test_csv['c_property'], test_csv['c_score'], test_csv['c_duplicate'] = '', '', '', 0\n",
"\n",
" vectorizer = CountVectorizer(token_pattern=r'\\S+', ngram_range=(1, 1))\n",
" train_all_bow_matrix = vectorizer.fit_transform(train_all_csv['tag_description'])\n",
" test_bow_matrix = vectorizer.transform(test_csv['tag_description'])\n",
"\n",
" knn = NearestNeighbors(n_neighbors=n, metric='euclidean', n_jobs=-1)\n",
" knn.fit(train_all_bow_matrix)\n",
"\n",
" distances, indices = knn.kneighbors(test_bow_matrix)\n",
"\n",
" predicted_things = [train_all_csv.iloc[indices[i][0]]['thing'] for i in range(len(test_csv))]\n",
" predicted_properties = [train_all_csv.iloc[indices[i][0]]['property'] for i in range(len(test_csv))]\n",
" predicted_scores = [1 - distances[i][0] for i in range(len(test_csv))]\n",
"\n",
" test_csv['c_thing'], test_csv['c_property'], test_csv['c_score'] = predicted_things, predicted_properties, predicted_scores\n",
"\n",
" test_csv['cthing_correct'] = test_csv['thing'] == test_csv['c_thing']\n",
" test_csv['cproperty_correct'] = test_csv['property'] == test_csv['c_property']\n",
" test_csv['ctp_correct'] = test_csv['cthing_correct'] & test_csv['cproperty_correct']\n",
"\n",
" mdm_true_count = len(test_csv[test_csv['MDM'] == True])\n",
" accuracies.append((test_csv['ctp_correct'].sum() / mdm_true_count) * 100)\n",
"\n",
" average_accuracy = sum(accuracies) / len(accuracies)\n",
" average_accuracies.append(average_accuracy)\n",
" print(f\"Average Accuracy (MDM=True) across all groups with n_neighbors={n}: {average_accuracy:.2f}%\")\n",
"\n",
"print(\"\\nFinal Results:\")\n",
"for n, avg_accuracy in zip(range(1, 53), average_accuracies):\n",
" print(f\"n_neighbors={n}, Average Accuracy: {avg_accuracy:.2f}%\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}