hipom_data_mapping/post_process/tfidf_class/3.refine.ipynb

152 lines
7.0 KiB
Plaintext
Raw Normal View History

2024-08-26 19:51:11 +09:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing group 1...\n",
"Total updates where p_correct is False and ctp_correct is True (group 1): 55\n",
"Number of rows with duplicates in the same ships_idx (group 1): 34\n",
"Number of rows without duplicates in the same ships_idx (group 1): 21\n",
"Number of updates made (group 1): 427\n",
"Updated test CSV saved to 0.class_document/distilbert/1/test_p_c_r.csv\n",
"Refine CSV saved to 0.class_document/distilbert/1/refine.csv\n",
"Processing group 2...\n",
"Total updates where p_correct is False and ctp_correct is True (group 2): 63\n",
"Number of rows with duplicates in the same ships_idx (group 2): 21\n",
"Number of rows without duplicates in the same ships_idx (group 2): 42\n",
"Number of updates made (group 2): 225\n",
"Updated test CSV saved to 0.class_document/distilbert/2/test_p_c_r.csv\n",
"Refine CSV saved to 0.class_document/distilbert/2/refine.csv\n",
"Processing group 3...\n",
"Total updates where p_correct is False and ctp_correct is True (group 3): 32\n",
"Number of rows with duplicates in the same ships_idx (group 3): 10\n",
"Number of rows without duplicates in the same ships_idx (group 3): 22\n",
"Number of updates made (group 3): 343\n",
"Updated test CSV saved to 0.class_document/distilbert/3/test_p_c_r.csv\n",
"Refine CSV saved to 0.class_document/distilbert/3/refine.csv\n",
"Processing group 4...\n",
"Total updates where p_correct is False and ctp_correct is True (group 4): 37\n",
"Number of rows with duplicates in the same ships_idx (group 4): 25\n",
"Number of rows without duplicates in the same ships_idx (group 4): 12\n",
"Number of updates made (group 4): 596\n",
"Updated test CSV saved to 0.class_document/distilbert/4/test_p_c_r.csv\n",
"Refine CSV saved to 0.class_document/distilbert/4/refine.csv\n",
"Processing group 5...\n",
"Total updates where p_correct is False and ctp_correct is True (group 5): 40\n",
"Number of rows with duplicates in the same ships_idx (group 5): 19\n",
"Number of rows without duplicates in the same ships_idx (group 5): 21\n",
"Number of updates made (group 5): 379\n",
"Updated test CSV saved to 0.class_document/distilbert/5/test_p_c_r.csv\n",
"Refine CSV saved to 0.class_document/distilbert/5/refine.csv\n"
]
}
],
"source": [
"import pandas as pd\n",
"from sklearn.feature_extraction.text import TfidfVectorizer  # NOTE(review): unused in this notebook\n",
"from sklearn.metrics.pairwise import cosine_similarity  # NOTE(review): unused in this notebook\n",
"from tqdm import tqdm  # NOTE(review): unused in this notebook\n",
"import re\n",
"\n",
"# ---------------------------------------------------------------------------\n",
"# Configuration\n",
"# ---------------------------------------------------------------------------\n",
"model = \"distilbert\"\n",
"GROUP_NUMBERS = range(1, 6)  # Group 1 to 5\n",
"\n",
"# Minimum c_score a candidate (c_thing, c_property) needs to overwrite the\n",
"# prediction (p_thing, p_property) during refinement.\n",
"C_SCORE_THRESHOLD = 0.91\n",
"\n",
"# NOTE(review): input is read from the `t5-tiny` sub-directory, but outputs are\n",
"# written to `{model}/{group}` without it -- confirm this asymmetry is intended.\n",
"INPUT_PATH_TEMPLATE = '0.class_document/{model}/t5-tiny/{group}/test_p_c.csv'\n",
"REFINE_PATH_TEMPLATE = '0.class_document/{model}/{group}/refine.csv'\n",
"OUTPUT_PATH_TEMPLATE = '0.class_document/{model}/{group}/test_p_c_r.csv'\n",
"\n",
"\n",
"def count_duplicate_candidates(test_csv):\n",
"    \"\"\"Count rows whose prediction is wrong but whose candidate is right (report only).\n",
"\n",
"    A row qualifies when `p_correct` is falsy and `ctp_correct` is truthy. A qualifying\n",
"    row counts as a \"duplicate\" when some row with the same `ships_idx` already has\n",
"    `p_thing`/`p_property` equal to this row's `c_thing`/`c_property`.\n",
"    `test_csv` is not modified.\n",
"\n",
"    Returns:\n",
"        tuple: (qualifying_count, duplicate_count, non_duplicate_count)\n",
"    \"\"\"\n",
"    qualifying_count = 0\n",
"    duplicate_count = 0\n",
"    non_duplicate_count = 0\n",
"\n",
"    for _, row in test_csv.iterrows():\n",
"        if not row['p_correct'] and row['ctp_correct']:\n",
"            qualifying_count += 1\n",
"\n",
"            # Rows in the same ships_idx that already predict this row's candidate\n",
"            same_idx_rows = test_csv[(test_csv['ships_idx'] == row['ships_idx']) &\n",
"                                     (test_csv['p_thing'] == row['c_thing']) &\n",
"                                     (test_csv['p_property'] == row['c_property'])]\n",
"\n",
"            if len(same_idx_rows) > 0:\n",
"                duplicate_count += 1\n",
"            else:\n",
"                non_duplicate_count += 1\n",
"\n",
"    return qualifying_count, duplicate_count, non_duplicate_count\n",
"\n",
"\n",
"def refine_predictions(test_csv, threshold=C_SCORE_THRESHOLD):\n",
"    \"\"\"Replace predictions with confident candidates, modifying `test_csv` in place.\n",
"\n",
"    A row is refined when `p_MDM` is falsy, `c_score >= threshold`, and the candidate\n",
"    (`c_thing`, `c_property`) differs from the prediction (`p_thing`, `p_property`).\n",
"    A refined row gets the candidate values, `p_MDM = True`, and a recomputed\n",
"    `p_pattern` (every digit replaced by '#').\n",
"\n",
"    Args:\n",
"        test_csv: DataFrame loaded from test_p_c.csv; updated in place.\n",
"        threshold: minimum `c_score` required to apply a candidate.\n",
"\n",
"    Returns:\n",
"        DataFrame of the refined rows with their pre-update values and the same\n",
"        columns as `test_csv` (header-only when nothing was refined).\n",
"    \"\"\"\n",
"    refine_rows = []\n",
"\n",
"    for index, row in test_csv.iterrows():\n",
"        if (not row['p_MDM'] and row['c_score'] >= threshold and\n",
"                (row['p_thing'] != row['c_thing'] or row['p_property'] != row['c_property'])):\n",
"\n",
"            test_csv.at[index, 'p_thing'] = row['c_thing']\n",
"            test_csv.at[index, 'p_property'] = row['c_property']\n",
"            test_csv.at[index, 'p_MDM'] = True\n",
"\n",
"            updated_p_thing = test_csv.at[index, 'p_thing']\n",
"            updated_p_property = test_csv.at[index, 'p_property']\n",
"            test_csv.at[index, 'p_pattern'] = (re.sub(r'\\d', '#', updated_p_thing) + \" \" +\n",
"                                               re.sub(r'\\d', '#', updated_p_property))\n",
"\n",
"            # iterrows() yields a copy for a mixed-dtype frame (assumed here), so `row`\n",
"            # still holds the values from before the writes above.\n",
"            refine_rows.append(row)\n",
"\n",
"    # Passing `columns` keeps the header even when refine_rows is empty\n",
"    # (pd.DataFrame([]) has no columns, so its CSV would carry no column header).\n",
"    return pd.DataFrame(refine_rows, columns=test_csv.columns)\n",
"\n",
"\n",
"for group_number in GROUP_NUMBERS:\n",
"    print(f\"Processing group {group_number}...\")\n",
"\n",
"    # Load test CSV for the current group\n",
"    test_path = INPUT_PATH_TEMPLATE.format(model=model, group=group_number)\n",
"    test_csv = pd.read_csv(test_path, low_memory=False)\n",
"\n",
"    # Report-only statistics: no rows are changed by this pass\n",
"    qualifying_count, duplicate_count, non_duplicate_count = count_duplicate_candidates(test_csv)\n",
"    print(f\"Total updates where p_correct is False and ctp_correct is True (group {group_number}): {qualifying_count}\")\n",
"    print(f\"Number of rows with duplicates in the same ships_idx (group {group_number}): {duplicate_count}\")\n",
"    print(f\"Number of rows without duplicates in the same ships_idx (group {group_number}): {non_duplicate_count}\")\n",
"\n",
"    # Apply the refinement in place and save the refined rows\n",
"    refine_df = refine_predictions(test_csv)\n",
"    refine_output_path = REFINE_PATH_TEMPLATE.format(model=model, group=group_number)\n",
"    refine_df.to_csv(refine_output_path, index=False, encoding='utf-8-sig')\n",
"\n",
"    # Print the number of updates made\n",
"    print(f\"Number of updates made (group {group_number}): {len(refine_df)}\")\n",
"\n",
"    # Save the updated test CSV for the current group\n",
"    output_file_path = OUTPUT_PATH_TEMPLATE.format(model=model, group=group_number)\n",
"    test_csv.to_csv(output_file_path, index=False, encoding='utf-8-sig')\n",
"\n",
"    print(f\"Updated test CSV saved to {output_file_path}\")\n",
"    print(f\"Refine CSV saved to {refine_output_path}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}