117 lines
17 KiB
Plaintext
117 lines
17 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Average Accuracy (MDM=True) across all groups with n_neighbors=1: 84.43%\n"
|
|
]
|
|
},
|
|
{
|
|
"ename": "KeyboardInterrupt",
|
|
"evalue": "",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
|
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, name, value)\u001b[0m\n\u001b[1;32m 6310\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6311\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 6312\u001b[0;31m \u001b[0mobject\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getattribute__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6313\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mobject\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__setattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;31mAttributeError\u001b[0m: 'Series' object has no attribute '_name'",
|
|
"\nDuring handling of the above exception, another exception occurred:\n",
|
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
|
"\u001b[0;32m/tmp/ipykernel_89094/2696322053.py\u001b[0m in \u001b[0;36m?\u001b[0;34m()\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0mdistances\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mknn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkneighbors\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_bow_matrix\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0mpredicted_things\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtrain_all_csv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'thing'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_csv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0mpredicted_properties\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtrain_all_csv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'property'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_csv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 37\u001b[0m \u001b[0mpredicted_scores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mdistances\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_csv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mtest_csv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'c_thing'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_csv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'c_property'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_csv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'c_score'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredicted_things\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpredicted_properties\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpredicted_scores\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m/tmp/ipykernel_89094/2696322053.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(.0)\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mpandas\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 1187\u001b[0m \u001b[0maxis\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maxis\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1189\u001b[0m \u001b[0mmaybe_callable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_if_callable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[0mmaybe_callable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_deprecated_callable_usage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmaybe_callable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1191\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getitem_axis\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmaybe_callable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, key, axis)\u001b[0m\n\u001b[1;32m 1750\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1751\u001b[0m \u001b[0;31m# validate the location\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1752\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_integer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1753\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1754\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_ixs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/frame.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, i, axis)\u001b[0m\n\u001b[1;32m 3996\u001b[0m \u001b[0mnew_mgr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_mgr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfast_xs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3997\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3998\u001b[0m \u001b[0;31m# if we are a copy, mark as such\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3999\u001b[0m \u001b[0mcopy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_mgr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnew_mgr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4000\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_constructor_sliced_from_mgr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_mgr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnew_mgr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maxes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4001\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4002\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__finalize__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4003\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_set_is_copy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/frame.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, mgr, axes)\u001b[0m\n\u001b[1;32m 678\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_constructor_sliced_from_mgr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmgr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxes\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mSeries\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 679\u001b[0m \u001b[0mser\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSeries\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_from_mgr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmgr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 680\u001b[0;31m \u001b[0mser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;31m# caller is responsible for setting real name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 681\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 682\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mDataFrame\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 683\u001b[0m \u001b[0;31m# This would also work `if self._constructor_sliced is Series`, but\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, name, value)\u001b[0m\n\u001b[1;32m 6308\u001b[0m \u001b[0;31m# e.g. ``obj.x`` and ``obj.x = 4`` will always reference/modify\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6309\u001b[0m \u001b[0;31m# the same attribute.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6310\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6311\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 6312\u001b[0;31m \u001b[0mobject\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getattribute__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6313\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mobject\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__setattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6314\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6315\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"from sklearn.feature_extraction.text import CountVectorizer\n",
|
|
"from sklearn.neighbors import NearestNeighbors\n",
|
|
"import os\n",
|
|
"\n",
|
|
"average_accuracies = []\n",
|
|
"\n",
|
|
"for n in range(1, 53):\n",
|
|
" accuracies = []\n",
|
|
" for group_number in range(1, 6):\n",
|
|
" train_all_path = f'../../data_preprocess/dataset/{group_number}/train_all.csv'\n",
|
|
" test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
|
|
"\n",
|
|
" if not os.path.exists(test_path):\n",
|
|
" print(f\"Test file for Group {group_number} does not exist. Skipping...\")\n",
|
|
" continue\n",
|
|
"\n",
|
|
" train_all_csv = pd.read_csv(train_all_path, low_memory=False)\n",
|
|
" test_csv = pd.read_csv(test_path, low_memory=False)\n",
|
|
"\n",
|
|
" train_all_csv['tag_description'] = train_all_csv['tag_description'].fillna('')\n",
|
|
" test_csv['tag_description'] = test_csv['tag_description'].fillna('')\n",
|
|
"\n",
|
|
" test_csv['c_thing'], test_csv['c_property'], test_csv['c_score'], test_csv['c_duplicate'] = '', '', '', 0\n",
|
|
"\n",
|
|
" vectorizer = CountVectorizer(token_pattern=r'\\S+', ngram_range=(1, 1))\n",
|
|
" train_all_bow_matrix = vectorizer.fit_transform(train_all_csv['tag_description'])\n",
|
|
" test_bow_matrix = vectorizer.transform(test_csv['tag_description'])\n",
|
|
"\n",
|
|
" knn = NearestNeighbors(n_neighbors=n, metric='euclidean', n_jobs=-1)\n",
|
|
" knn.fit(train_all_bow_matrix)\n",
|
|
"\n",
|
|
" distances, indices = knn.kneighbors(test_bow_matrix)\n",
|
|
"\n",
|
|
" predicted_things = [train_all_csv.iloc[indices[i][0]]['thing'] for i in range(len(test_csv))]\n",
|
|
" predicted_properties = [train_all_csv.iloc[indices[i][0]]['property'] for i in range(len(test_csv))]\n",
|
|
" predicted_scores = [1 - distances[i][0] for i in range(len(test_csv))]\n",
|
|
"\n",
|
|
" test_csv['c_thing'], test_csv['c_property'], test_csv['c_score'] = predicted_things, predicted_properties, predicted_scores\n",
|
|
"\n",
|
|
" test_csv['cthing_correct'] = test_csv['thing'] == test_csv['c_thing']\n",
|
|
" test_csv['cproperty_correct'] = test_csv['property'] == test_csv['c_property']\n",
|
|
" test_csv['ctp_correct'] = test_csv['cthing_correct'] & test_csv['cproperty_correct']\n",
|
|
"\n",
|
|
" mdm_true_count = len(test_csv[test_csv['MDM'] == True])\n",
|
|
" accuracies.append((test_csv['ctp_correct'].sum() / mdm_true_count) * 100)\n",
|
|
"\n",
|
|
" average_accuracy = sum(accuracies) / len(accuracies)\n",
|
|
" average_accuracies.append(average_accuracy)\n",
|
|
" print(f\"Average Accuracy (MDM=True) across all groups with n_neighbors={n}: {average_accuracy:.2f}%\")\n",
|
|
"\n",
|
|
"print(\"\\nFinal Results:\")\n",
|
|
"for n, avg_accuracy in zip(range(1, 53), average_accuracies):\n",
|
|
" print(f\"n_neighbors={n}, Average Accuracy: {avg_accuracy:.2f}%\")\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "torch",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.14"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|