2024-08-26 19:51:11 +09:00
|
|
|
{
|
|
|
|
"cells": [
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-09-25 08:52:30 +09:00
|
|
|
"execution_count": 5,
|
2024-08-26 19:51:11 +09:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
2024-09-25 08:52:30 +09:00
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Processing group 1...\n",
|
|
|
|
"Total updates where p_correct is False and ctp_correct is True (group 1): 55\n",
|
|
|
|
"Number of rows with duplicates in the same ships_idx (group 1): 34\n",
|
|
|
|
"Number of rows without duplicates in the same ships_idx (group 1): 21\n",
|
|
|
|
"Number of updates made (group 1): 427\n",
|
|
|
|
"Updated test CSV saved to 0.class_document/distilbert/1/test_p_c_r.csv\n",
|
|
|
|
"Refine CSV saved to 0.class_document/distilbert/1/refine.csv\n",
|
|
|
|
"Processing group 2...\n",
|
|
|
|
"Total updates where p_correct is False and ctp_correct is True (group 2): 63\n",
|
|
|
|
"Number of rows with duplicates in the same ships_idx (group 2): 21\n",
|
|
|
|
"Number of rows without duplicates in the same ships_idx (group 2): 42\n",
|
|
|
|
"Number of updates made (group 2): 225\n",
|
|
|
|
"Updated test CSV saved to 0.class_document/distilbert/2/test_p_c_r.csv\n",
|
|
|
|
"Refine CSV saved to 0.class_document/distilbert/2/refine.csv\n",
|
|
|
|
"Processing group 3...\n",
|
|
|
|
"Total updates where p_correct is False and ctp_correct is True (group 3): 32\n",
|
|
|
|
"Number of rows with duplicates in the same ships_idx (group 3): 10\n",
|
|
|
|
"Number of rows without duplicates in the same ships_idx (group 3): 22\n",
|
|
|
|
"Number of updates made (group 3): 343\n",
|
|
|
|
"Updated test CSV saved to 0.class_document/distilbert/3/test_p_c_r.csv\n",
|
|
|
|
"Refine CSV saved to 0.class_document/distilbert/3/refine.csv\n",
|
|
|
|
"Processing group 4...\n",
|
|
|
|
"Total updates where p_correct is False and ctp_correct is True (group 4): 37\n",
|
|
|
|
"Number of rows with duplicates in the same ships_idx (group 4): 25\n",
|
|
|
|
"Number of rows without duplicates in the same ships_idx (group 4): 12\n",
|
|
|
|
"Number of updates made (group 4): 596\n",
|
|
|
|
"Updated test CSV saved to 0.class_document/distilbert/4/test_p_c_r.csv\n",
|
|
|
|
"Refine CSV saved to 0.class_document/distilbert/4/refine.csv\n",
|
|
|
|
"Processing group 5...\n",
|
|
|
|
"Total updates where p_correct is False and ctp_correct is True (group 5): 40\n",
|
|
|
|
"Number of rows with duplicates in the same ships_idx (group 5): 19\n",
|
|
|
|
"Number of rows without duplicates in the same ships_idx (group 5): 21\n",
|
|
|
|
"Number of updates made (group 5): 379\n",
|
|
|
|
"Updated test CSV saved to 0.class_document/distilbert/5/test_p_c_r.csv\n",
|
|
|
|
"Refine CSV saved to 0.class_document/distilbert/5/refine.csv\n"
|
2024-08-26 19:51:11 +09:00
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"import os\n",
"import re\n",
"\n",
"import pandas as pd\n",
"from sklearn.feature_extraction.text import TfidfVectorizer  # NOTE(review): unused in this notebook\n",
"from sklearn.metrics.pairwise import cosine_similarity  # NOTE(review): unused in this notebook\n",
"from tqdm import tqdm  # NOTE(review): unused in this notebook\n",
"\n",
"# --- Configuration ---------------------------------------------------------\n",
"model = \"distilbert\"\n",
"GROUP_NUMBERS = range(1, 6)  # Group 1 to 5\n",
"# Rows whose classifier score c_score reaches this value have their prediction\n",
"# replaced by the classifier output (c_thing, c_property).\n",
"C_SCORE_THRESHOLD = 0.91\n",
"# NOTE(review): inputs are read from the 't5-tiny' sub-directory while outputs are\n",
"# written directly under the model directory -- confirm this is intended.\n",
"INPUT_PATH_TEMPLATE = '0.class_document/{model}/t5-tiny/{group}/test_p_c.csv'\n",
"OUTPUT_DIR_TEMPLATE = '0.class_document/{model}/{group}'\n",
"\n",
"\n",
"def count_correction_candidates(df):\n",
"    \"\"\"Count rows where the prediction is wrong but the classifier is right.\n",
"\n",
"    A candidate row has a falsy ``p_correct`` and a truthy ``ctp_correct``.\n",
"    Returns ``(total, duplicates, non_duplicates)`` where ``duplicates`` counts\n",
"    candidates whose ``(c_thing, c_property)`` already appears as the\n",
"    ``(p_thing, p_property)`` pair of some row with the same ``ships_idx``.\n",
"    The DataFrame is only read, never modified.\n",
"    \"\"\"\n",
"    # Build the lookup once (O(n)) instead of filtering the whole frame for every\n",
"    # candidate (O(n^2)). Rows with a missing key are dropped because NaN never\n",
"    # compares equal in an element-wise pandas comparison either.\n",
"    key_columns = ['ships_idx', 'p_thing', 'p_property']\n",
"    existing_pairs = set(df[key_columns].dropna().itertuples(index=False, name=None))\n",
"\n",
"    total = duplicates = 0\n",
"    columns = ['p_correct', 'ctp_correct', 'ships_idx', 'c_thing', 'c_property']\n",
"    for p_correct, ctp_correct, ships_idx, c_thing, c_property in df[columns].itertuples(index=False, name=None):\n",
"        if p_correct or not ctp_correct:\n",
"            continue\n",
"        total += 1\n",
"        if (ships_idx, c_thing, c_property) in existing_pairs:\n",
"            duplicates += 1\n",
"    return total, duplicates, total - duplicates\n",
"\n",
"\n",
"def _mask_digits(text):\n",
"    \"\"\"Replace every digit in ``text`` with '#' (e.g. 'NO2 ENGINE' -> 'NO# ENGINE').\"\"\"\n",
"    return re.sub(r'\\d', '#', text)\n",
"\n",
"\n",
"def refine_predictions(df, threshold=C_SCORE_THRESHOLD):\n",
"    \"\"\"Replace predictions with confident classifier output.\n",
"\n",
"    A row is refined when ``p_MDM`` is falsy, ``c_score >= threshold`` and\n",
"    ``(c_thing, c_property)`` differs from ``(p_thing, p_property)``. Refined\n",
"    rows get ``p_thing``/``p_property`` from ``c_thing``/``c_property``,\n",
"    ``p_MDM = True`` and a recomputed ``p_pattern`` (digits masked with '#').\n",
"\n",
"    Returns ``(refined_df, refine_rows)``: a refined copy of ``df`` (the input is\n",
"    not modified) and the selected rows as they were *before* refinement.\n",
"    \"\"\"\n",
"    mask = (\n",
"        ~df['p_MDM'].astype(bool)  # same truthiness as `not row['p_MDM']`\n",
"        & (df['c_score'] >= threshold)\n",
"        & ((df['p_thing'] != df['c_thing']) | (df['p_property'] != df['c_property']))\n",
"    )\n",
"    # Pre-update snapshot; unlike pd.DataFrame([]) it keeps the header when empty,\n",
"    # so refine.csv can always be read back with pd.read_csv.\n",
"    refine_rows = df.loc[mask].copy()\n",
"\n",
"    refined = df.copy()\n",
"    if mask.any():\n",
"        new_thing = df.loc[mask, 'c_thing']\n",
"        new_property = df.loc[mask, 'c_property']\n",
"        refined.loc[mask, 'p_thing'] = new_thing\n",
"        refined.loc[mask, 'p_property'] = new_property\n",
"        refined.loc[mask, 'p_MDM'] = True\n",
"        refined.loc[mask, 'p_pattern'] = [\n",
"            _mask_digits(thing) + \" \" + _mask_digits(prop)\n",
"            for thing, prop in zip(new_thing, new_property)\n",
"        ]\n",
"    return refined, refine_rows\n",
"\n",
"\n",
"for group_number in GROUP_NUMBERS:\n",
"    print(f\"Processing group {group_number}...\")\n",
"\n",
"    # Load test CSV for the current group\n",
"    test_path = INPUT_PATH_TEMPLATE.format(model=model, group=group_number)\n",
"    test_csv = pd.read_csv(test_path, low_memory=False)\n",
"\n",
"    # Count rows the classifier would fix (report only; nothing is modified here)\n",
"    update_count, duplicate_count, non_duplicate_count = count_correction_candidates(test_csv)\n",
"    print(f\"Total updates where p_correct is False and ctp_correct is True (group {group_number}): {update_count}\")\n",
"    print(f\"Number of rows with duplicates in the same ships_idx (group {group_number}): {duplicate_count}\")\n",
"    print(f\"Number of rows without duplicates in the same ships_idx (group {group_number}): {non_duplicate_count}\")\n",
"\n",
"    # Apply the high-confidence refinement\n",
"    refined_csv, refine_df = refine_predictions(test_csv)\n",
"\n",
"    output_dir = OUTPUT_DIR_TEMPLATE.format(model=model, group=group_number)\n",
"    os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"    # Save the refined rows (pre-update values) for the current group\n",
"    refine_output_path = f'{output_dir}/refine.csv'\n",
"    refine_df.to_csv(refine_output_path, index=False, encoding='utf-8-sig')\n",
"\n",
"    print(f\"Number of updates made (group {group_number}): {len(refine_df)}\")\n",
"\n",
"    # Save the updated test CSV for the current group\n",
"    output_file_path = f'{output_dir}/test_p_c_r.csv'\n",
"    refined_csv.to_csv(output_file_path, index=False, encoding='utf-8-sig')\n",
"\n",
"    print(f\"Updated test CSV saved to {output_file_path}\")\n",
"    print(f\"Refine CSV saved to {refine_output_path}\")\n"
|
2024-08-26 19:51:11 +09:00
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"metadata": {
|
|
|
|
"kernelspec": {
|
|
|
|
"display_name": "torch",
|
|
|
|
"language": "python",
|
|
|
|
"name": "python3"
|
|
|
|
},
|
|
|
|
"language_info": {
|
|
|
|
"codemirror_mode": {
|
|
|
|
"name": "ipython",
|
|
|
|
"version": 3
|
|
|
|
},
|
|
|
|
"file_extension": ".py",
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
"name": "python",
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
"version": "3.10.14"
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"nbformat": 4,
|
|
|
|
"nbformat_minor": 2
|
|
|
|
}
|