135 lines
5.8 KiB
Plaintext
135 lines
5.8 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Accuracy (MDM=True) for Group 1: 79.41%\n",
|
||
|
"Accuracy (MDM=True) for Group 2: 79.32%\n",
|
||
|
"Accuracy (MDM=True) for Group 3: 82.49%\n",
|
||
|
"Accuracy (MDM=True) for Group 4: 85.61%\n",
|
||
|
"Accuracy (MDM=True) for Group 5: 79.72%\n",
|
||
|
"Average Accuracy (MDM=True) across all groups: 81.31%\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import pandas as pd\n",
|
||
|
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
||
|
"from sklearn.metrics.pairwise import cosine_similarity\n",
|
||
|
"from tqdm import tqdm\n",
|
||
|
"import os\n",
|
||
|
"\n",
|
||
"# For each group 1..5: match every test tag_description to its most similar\n",
"# class-level document (TF-IDF + cosine similarity), score the predicted\n",
"# (thing, property) pair on the MDM = True rows, and write the result CSVs.\n",
"\n",
"# Accuracy in percent, one entry per group that could be scored\n",
"accuracies = []\n",
"\n",
"# Loop through group numbers from 1 to 5\n",
"for group_number in range(1, 6):\n",
"\n",
"    # Input files of the current group\n",
"    sdl_class_rdoc_path = f'0.class_document/{group_number}/sdl_class_rdoc.csv'\n",
"    test_path = f'../../data_preprocess/dataset/{group_number}/test.csv'\n",
"\n",
"    # Check if test file exists, if not, skip this iteration\n",
"    if not os.path.exists(test_path):\n",
"        print(f\"test file for Group {group_number} does not exist. Skipping...\")\n",
"        continue\n",
"\n",
"    # Same guard for the class document, so a missing file skips the group\n",
"    # instead of aborting the whole run inside read_csv\n",
"    if not os.path.exists(sdl_class_rdoc_path):\n",
"        print(f\"class document for Group {group_number} does not exist. Skipping...\")\n",
"        continue\n",
"\n",
"    sdl_class_rdoc_csv = pd.read_csv(sdl_class_rdoc_path, low_memory=False)\n",
"    test_csv = pd.read_csv(test_path, low_memory=False)\n",
"\n",
"    # With no rows on either side there is nothing to match, and the\n",
"    # similarity / argmax calls below would raise\n",
"    if sdl_class_rdoc_csv.empty or test_csv.empty:\n",
"        print(f\"Group {group_number} has no rows to compare. Skipping...\")\n",
"        continue\n",
"\n",
"    # Replace NaN values with empty strings in relevant columns\n",
"    sdl_class_rdoc_csv['tag_description'] = sdl_class_rdoc_csv['tag_description'].fillna('')\n",
"    test_csv['tag_description'] = test_csv['tag_description'].fillna('')\n",
"\n",
"    # The vectorizer is fitted on class documents AND test descriptions together,\n",
"    # so the vocabulary and IDF weights cover both sides\n",
"    combined_tag_descriptions = sdl_class_rdoc_csv['tag_description'].tolist() + test_csv['tag_description'].tolist()\n",
"\n",
"    # Create a TF-IDF Vectorizer\n",
"    vectorizer = TfidfVectorizer(\n",
"        token_pattern=r'\\S+',  # a token is any run of non-whitespace characters\n",
"        ngram_range=(1, 6),  # Use ngrams from 1 to 6\n",
"    )\n",
"    vectorizer.fit(combined_tag_descriptions)\n",
"\n",
"    # Transform both sdl_class_rdoc and test CSVs into TF-IDF matrices\n",
"    sdl_class_rdoc_tfidf_matrix = vectorizer.transform(sdl_class_rdoc_csv['tag_description'])\n",
"    test_tfidf_matrix = vectorizer.transform(test_csv['tag_description'])\n",
"\n",
"    # similarity_matrix[i, j]: cosine similarity of test row i and class document j\n",
"    similarity_matrix = cosine_similarity(test_tfidf_matrix, sdl_class_rdoc_tfidf_matrix)\n",
"\n",
"    # Best-matching class document (position and score) for each test row\n",
"    most_similar_indices = similarity_matrix.argmax(axis=1)\n",
"    most_similar_scores = similarity_matrix.max(axis=1)\n",
"\n",
"    # Attach the prediction columns; the assignment order fixes the CSV column order\n",
"    test_csv['c_thing'] = sdl_class_rdoc_csv['thing'].iloc[most_similar_indices].values\n",
"    test_csv['c_property'] = sdl_class_rdoc_csv['property'].iloc[most_similar_indices].values\n",
"    test_csv['c_score'] = most_similar_scores\n",
"    test_csv['c_duplicate'] = 0  # placeholder column; it is never updated in this cell\n",
"\n",
"    # Check if the predicted 'c_thing' and 'c_property' match the actual 'thing' and 'property'\n",
"    test_csv['cthing_correct'] = test_csv['thing'] == test_csv['c_thing']\n",
"    test_csv['cproperty_correct'] = test_csv['property'] == test_csv['c_property']\n",
"    test_csv['ctp_correct'] = test_csv['cthing_correct'] & test_csv['cproperty_correct']\n",
"\n",
"    # Accuracy based only on MDM = True: hits and total are both restricted to\n",
"    # those rows, so the ratio can never mix two different populations\n",
"    mdm_mask = test_csv['MDM'] == True\n",
"    mdm_true_count = int(mdm_mask.sum())\n",
"\n",
"    if mdm_true_count == 0:\n",
"        # No denominator: leave this group out of the average instead of dividing by zero\n",
"        print(f\"Group {group_number} has no MDM=True rows. Accuracy not computed.\")\n",
"    else:\n",
"        mdm_correct_count = int((test_csv['ctp_correct'] & mdm_mask).sum())\n",
"        accuracy = (mdm_correct_count / mdm_true_count) * 100\n",
"        accuracies.append(accuracy)\n",
"        print(f\"Accuracy (MDM=True) for Group {group_number}: {accuracy:.2f}%\")\n",
"\n",
"    # Save the full test frame with predictions and correctness flags\n",
"    output_path = f'0.class_document/{group_number}/test_p_c.csv'\n",
"    test_csv.to_csv(output_path, index=False, encoding='utf-8-sig')\n",
"\n",
"    # Rows where MDM is True and the predicted (thing, property) pair is wrong\n",
"    false_positive_rows = test_csv[mdm_mask & ~test_csv['ctp_correct']]\n",
"\n",
"    # Save them to a separate file\n",
"    fp_output_path = f'0.class_document/{group_number}/fp_class.csv'\n",
"    false_positive_rows.to_csv(fp_output_path, index=False, encoding='utf-8-sig')\n",
"\n",
"# Calculate and print the average accuracy across all scored groups\n",
"if accuracies:\n",
"    average_accuracy = sum(accuracies) / len(accuracies)\n",
"    print(f\"Average Accuracy (MDM=True) across all groups: {average_accuracy:.2f}%\")\n",
"else:\n",
"    # Every group was skipped; there is nothing to average\n",
"    print(\"No group could be scored; average accuracy is undefined.\")\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "torch",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.10.14"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 2
|
||
|
}
|