{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Refine `p_*` predictions with confident `c_*` candidates\n",
    "\n",
    "For each group 1-5 this notebook:\n",
    "\n",
    "1. loads `0.class_document/<model>/t5-tiny/<group>/test_p_c.csv`;\n",
    "2. reports how many rows have `p_correct` false and `ctp_correct` true, split by whether their `(c_thing, c_property)` pair already appears as the `(p_thing, p_property)` pair of some row with the same `ships_idx` (diagnostic only, nothing is modified);\n",
    "3. for rows where `p_MDM` is false, `c_score >= 0.91` and the `c_*` pair differs from the `p_*` pair, copies `c_thing`/`c_property` into `p_thing`/`p_property`, sets `p_MDM` to `True` and rebuilds `p_pattern` with every digit replaced by `#`;\n",
    "4. writes the updated table to `0.class_document/<model>/<group>/test_p_c_r.csv` and the changed rows, as they were before the update, to `0.class_document/<model>/<group>/refine.csv`.\n",
    "\n",
    "**Note:** the input is read from the `t5-tiny/` sub-folder but the outputs are written without it. Confirm this is intended."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing group 1...\n",
      "Total updates where p_correct is False and ctp_correct is True (group 1): 55\n",
      "Number of rows with duplicates in the same ships_idx (group 1): 34\n",
      "Number of rows without duplicates in the same ships_idx (group 1): 21\n",
      "Number of updates made (group 1): 427\n",
      "Updated test CSV saved to 0.class_document/distilbert/1/test_p_c_r.csv\n",
      "Refine CSV saved to 0.class_document/distilbert/1/refine.csv\n",
      "Processing group 2...\n",
      "Total updates where p_correct is False and ctp_correct is True (group 2): 63\n",
      "Number of rows with duplicates in the same ships_idx (group 2): 21\n",
      "Number of rows without duplicates in the same ships_idx (group 2): 42\n",
      "Number of updates made (group 2): 225\n",
      "Updated test CSV saved to 0.class_document/distilbert/2/test_p_c_r.csv\n",
      "Refine CSV saved to 0.class_document/distilbert/2/refine.csv\n",
      "Processing group 3...\n",
      "Total updates where p_correct is False and ctp_correct is True (group 3): 32\n",
      "Number of rows with duplicates in the same ships_idx (group 3): 10\n",
      "Number of rows without duplicates in the same ships_idx (group 3): 22\n",
      "Number of updates made (group 3): 343\n",
      "Updated test CSV saved to 0.class_document/distilbert/3/test_p_c_r.csv\n",
      "Refine CSV saved to 0.class_document/distilbert/3/refine.csv\n",
      "Processing group 4...\n",
      "Total updates where p_correct is False and ctp_correct is True (group 4): 37\n",
      "Number of rows with duplicates in the same ships_idx (group 4): 25\n",
      "Number of rows without duplicates in the same ships_idx (group 4): 12\n",
      "Number of updates made (group 4): 596\n",
      "Updated test CSV saved to 0.class_document/distilbert/4/test_p_c_r.csv\n",
      "Refine CSV saved to 0.class_document/distilbert/4/refine.csv\n",
      "Processing group 5...\n",
      "Total updates where p_correct is False and ctp_correct is True (group 5): 40\n",
      "Number of rows with duplicates in the same ships_idx (group 5): 19\n",
      "Number of rows without duplicates in the same ships_idx (group 5): 21\n",
      "Number of updates made (group 5): 379\n",
      "Updated test CSV saved to 0.class_document/distilbert/5/test_p_c_r.csv\n",
      "Refine CSV saved to 0.class_document/distilbert/5/refine.csv\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import re\n",
    "\n",
    "import pandas as pd\n",
    "# NOTE(review): the three imports below are not used in this notebook; they are kept\n",
    "# so the kernel namespace is unchanged. Drop them if nothing else relies on them.\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Configuration\n",
    "model = \"distilbert\"\n",
    "GROUP_NUMBERS = range(1, 6)  # Group 1 to 5\n",
    "# Minimum c_score at which c_thing/c_property overwrite p_thing/p_property.\n",
    "C_SCORE_THRESHOLD = 0.91\n",
    "\n",
    "\n",
    "def mask_digits(text):\n",
    "    \"\"\"Return `text` with every digit replaced by '#'.\"\"\"\n",
    "    return re.sub(r'\\d', '#', text)\n",
    "\n",
    "\n",
    "def count_ctp_candidates(df):\n",
    "    \"\"\"Count rows where p_correct is falsy and ctp_correct is truthy.\n",
    "\n",
    "    Each such row is further classified by whether its (c_thing, c_property)\n",
    "    pair already appears as the (p_thing, p_property) pair of some row with the\n",
    "    same ships_idx. `df` is not modified; this is a diagnostic only.\n",
    "\n",
    "    Returns:\n",
    "        (total, with_duplicate, without_duplicate) row counts.\n",
    "    \"\"\"\n",
    "    total = with_duplicate = without_duplicate = 0\n",
    "    for _, row in df.iterrows():\n",
    "        if not row['p_correct'] and row['ctp_correct']:\n",
    "            total += 1\n",
    "            # .any() on the mask avoids materialising the matching sub-frame.\n",
    "            already_present = ((df['ships_idx'] == row['ships_idx']) &\n",
    "                               (df['p_thing'] == row['c_thing']) &\n",
    "                               (df['p_property'] == row['c_property'])).any()\n",
    "            if already_present:\n",
    "                with_duplicate += 1\n",
    "            else:\n",
    "                without_duplicate += 1\n",
    "    return total, with_duplicate, without_duplicate\n",
    "\n",
    "\n",
    "def refine_predictions(df, threshold=C_SCORE_THRESHOLD):\n",
    "    \"\"\"Copy c_thing/c_property into p_thing/p_property for confident rows.\n",
    "\n",
    "    A row is refined when p_MDM is falsy, c_score >= `threshold`, and its\n",
    "    (c_thing, c_property) pair differs from its (p_thing, p_property) pair.\n",
    "    For those rows p_MDM is set to True and p_pattern is rebuilt from the new\n",
    "    values with digits masked as '#'.\n",
    "\n",
    "    Args:\n",
    "        df: frame read from test_p_c.csv. Reads p_MDM, c_score, p_thing,\n",
    "            p_property, c_thing, c_property. It is not modified.\n",
    "        threshold: minimum c_score required for a row to be refined.\n",
    "\n",
    "    Returns:\n",
    "        (refined, changed_rows): a refined copy of `df`, and the rows that\n",
    "        were changed with their values as they were BEFORE the update. All\n",
    "        columns of `df` are kept even when no row qualifies, so the saved CSV\n",
    "        always has a header.\n",
    "    \"\"\"\n",
    "    refined = df.copy()\n",
    "    changed_index = []\n",
    "    for index, row in df.iterrows():\n",
    "        if (not row['p_MDM'] and row['c_score'] >= threshold and\n",
    "                (row['p_thing'] != row['c_thing'] or row['p_property'] != row['c_property'])):\n",
    "            refined.at[index, 'p_thing'] = row['c_thing']\n",
    "            refined.at[index, 'p_property'] = row['c_property']\n",
    "            refined.at[index, 'p_MDM'] = True\n",
    "            refined.at[index, 'p_pattern'] = mask_digits(row['c_thing']) + \" \" + mask_digits(row['c_property'])\n",
    "            changed_index.append(index)\n",
    "    return refined, df.loc[changed_index]\n",
    "\n",
    "\n",
    "for group_number in GROUP_NUMBERS:\n",
    "    print(f\"Processing group {group_number}...\")\n",
    "\n",
    "    # Load test CSV for the current group.\n",
    "    # NOTE(review): input comes from the t5-tiny/ sub-folder while outputs are\n",
    "    # written without it (see output_dir below). Confirm this is intended.\n",
    "    test_path = f'0.class_document/{model}/t5-tiny/{group_number}/test_p_c.csv'\n",
    "    test_csv = pd.read_csv(test_path, low_memory=False)\n",
    "\n",
    "    # Diagnostic counts only; nothing is modified here.\n",
    "    candidate_count, duplicate_count, non_duplicate_count = count_ctp_candidates(test_csv)\n",
    "    print(f\"Total updates where p_correct is False and ctp_correct is True (group {group_number}): {candidate_count}\")\n",
    "    print(f\"Number of rows with duplicates in the same ships_idx (group {group_number}): {duplicate_count}\")\n",
    "    print(f\"Number of rows without duplicates in the same ships_idx (group {group_number}): {non_duplicate_count}\")\n",
    "\n",
    "    # Apply the c_score-based refinement.\n",
    "    test_csv_refined, refine_df = refine_predictions(test_csv, C_SCORE_THRESHOLD)\n",
    "    update_count = len(refine_df)\n",
    "\n",
    "    # Make sure the output folder exists before writing.\n",
    "    output_dir = f'0.class_document/{model}/{group_number}'\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "    # Save the pre-update copies of the refined rows.\n",
    "    refine_output_path = f'{output_dir}/refine.csv'\n",
    "    refine_df.to_csv(refine_output_path, index=False, encoding='utf-8-sig')\n",
    "\n",
    "    print(f\"Number of updates made (group {group_number}): {update_count}\")\n",
    "\n",
    "    # Save the updated test CSV for the current group.\n",
    "    output_file_path = f'{output_dir}/test_p_c_r.csv'\n",
    "    test_csv_refined.to_csv(output_file_path, index=False, encoding='utf-8-sig')\n",
    "\n",
    "    print(f\"Updated test CSV saved to {output_file_path}\")\n",
    "    print(f\"Refine CSV saved to {refine_output_path}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}