152 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
			
		
		
	
	
			152 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
{
 | 
						|
 "cells": [
 | 
						|
  {
 | 
						|
   "cell_type": "code",
 | 
						|
   "execution_count": 5,
 | 
						|
   "metadata": {},
 | 
						|
   "outputs": [
 | 
						|
    {
 | 
						|
     "name": "stdout",
 | 
						|
     "output_type": "stream",
 | 
						|
     "text": [
 | 
						|
      "Processing group 1...\n",
 | 
						|
      "Total updates where p_correct is False and ctp_correct is True (group 1): 55\n",
 | 
						|
      "Number of rows with duplicates in the same ships_idx (group 1): 34\n",
 | 
						|
      "Number of rows without duplicates in the same ships_idx (group 1): 21\n",
 | 
						|
      "Number of updates made (group 1): 427\n",
 | 
						|
      "Updated test CSV saved to 0.class_document/distilbert/1/test_p_c_r.csv\n",
 | 
						|
      "Refine CSV saved to 0.class_document/distilbert/1/refine.csv\n",
 | 
						|
      "Processing group 2...\n",
 | 
						|
      "Total updates where p_correct is False and ctp_correct is True (group 2): 63\n",
 | 
						|
      "Number of rows with duplicates in the same ships_idx (group 2): 21\n",
 | 
						|
      "Number of rows without duplicates in the same ships_idx (group 2): 42\n",
 | 
						|
      "Number of updates made (group 2): 225\n",
 | 
						|
      "Updated test CSV saved to 0.class_document/distilbert/2/test_p_c_r.csv\n",
 | 
						|
      "Refine CSV saved to 0.class_document/distilbert/2/refine.csv\n",
 | 
						|
      "Processing group 3...\n",
 | 
						|
      "Total updates where p_correct is False and ctp_correct is True (group 3): 32\n",
 | 
						|
      "Number of rows with duplicates in the same ships_idx (group 3): 10\n",
 | 
						|
      "Number of rows without duplicates in the same ships_idx (group 3): 22\n",
 | 
						|
      "Number of updates made (group 3): 343\n",
 | 
						|
      "Updated test CSV saved to 0.class_document/distilbert/3/test_p_c_r.csv\n",
 | 
						|
      "Refine CSV saved to 0.class_document/distilbert/3/refine.csv\n",
 | 
						|
      "Processing group 4...\n",
 | 
						|
      "Total updates where p_correct is False and ctp_correct is True (group 4): 37\n",
 | 
						|
      "Number of rows with duplicates in the same ships_idx (group 4): 25\n",
 | 
						|
      "Number of rows without duplicates in the same ships_idx (group 4): 12\n",
 | 
						|
      "Number of updates made (group 4): 596\n",
 | 
						|
      "Updated test CSV saved to 0.class_document/distilbert/4/test_p_c_r.csv\n",
 | 
						|
      "Refine CSV saved to 0.class_document/distilbert/4/refine.csv\n",
 | 
						|
      "Processing group 5...\n",
 | 
						|
      "Total updates where p_correct is False and ctp_correct is True (group 5): 40\n",
 | 
						|
      "Number of rows with duplicates in the same ships_idx (group 5): 19\n",
 | 
						|
      "Number of rows without duplicates in the same ships_idx (group 5): 21\n",
 | 
						|
      "Number of updates made (group 5): 379\n",
 | 
						|
      "Updated test CSV saved to 0.class_document/distilbert/5/test_p_c_r.csv\n",
 | 
						|
      "Refine CSV saved to 0.class_document/distilbert/5/refine.csv\n"
 | 
						|
     ]
 | 
						|
    }
 | 
						|
   ],
 | 
						|
   "source": [
 | 
						|
    "import pandas as pd\n",
 | 
						|
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
 | 
						|
    "from sklearn.metrics.pairwise import cosine_similarity\n",
 | 
						|
    "from tqdm import tqdm\n",
 | 
						|
    "import re\n",
 | 
						|
    "\n",
 | 
						|
    "model = \"distilbert\"\n",
 | 
						|
    "\n",
 | 
						|
    "for group_number in range(1, 6):  # Group 1 to 5\n",
 | 
						|
    "    print(f\"Processing group {group_number}...\")\n",
 | 
						|
    "\n",
 | 
						|
    "    # Load test CSV for the current group\n",
 | 
						|
    "    test_path = f'0.class_document/{model}/t5-tiny/{group_number}/test_p_c.csv'\n",
 | 
						|
    "    test_csv = pd.read_csv(test_path, low_memory=False)\n",
 | 
						|
    "\n",
 | 
						|
    "    # Initialize counters\n",
 | 
						|
    "    update_count = 0\n",
 | 
						|
    "    duplicate_count = 0\n",
 | 
						|
    "    non_duplicate_count = 0\n",
 | 
						|
    "\n",
 | 
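						|
    "    # Count rows where p_correct is False and ctp_correct is True, and check whether any row in the same ships_idx already has p_thing/p_property equal to this row's c_thing/c_property (nothing is modified in this loop)\n",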
 | 
						|
    "    for index, row in test_csv.iterrows():\n",
 | 
						|
    "        if not row['p_correct'] and row['ctp_correct']:\n",
 | 
						|
    "            update_count += 1  # Increment the counter\n",
 | 
						|
    "\n",
 | 
						|
    "            # Check for duplicates within the same ships_idx\n",
 | 
						|
    "            same_idx_rows = test_csv[(test_csv['ships_idx'] == row['ships_idx']) &\n",
 | 
						|
    "                                     (test_csv['p_thing'] == row['c_thing']) &\n",
 | 
						|
    "                                     (test_csv['p_property'] == row['c_property'])]\n",
 | 
						|
    "\n",
 | 
						|
    "            if len(same_idx_rows) > 0:\n",
 | 
						|
    "                duplicate_count += 1\n",
 | 
						|
    "            else:\n",
 | 
						|
    "                non_duplicate_count += 1\n",
 | 
						|
    "\n",
 | 
						|
    "    # Print the results for the current group\n",
 | 
						|
    "    print(f\"Total updates where p_correct is False and ctp_correct is True (group {group_number}): {update_count}\")\n",
 | 
						|
    "    print(f\"Number of rows with duplicates in the same ships_idx (group {group_number}): {duplicate_count}\")\n",
 | 
						|
    "    print(f\"Number of rows without duplicates in the same ships_idx (group {group_number}): {non_duplicate_count}\")\n",
 | 
						|
    "\n",
 | 
						|
    "    # Initialize a list to hold rows that meet the conditions for refinement\n",
 | 
						|
    "    refine_rows = []\n",
 | 
						|
    "    update_count = 0\n",
 | 
						|
    "\n",
 | 
						|
    "    # Assign c_thing, c_property to p_thing, p_property and set p_MDM to True if conditions are met\n",
 | 
						|
    "    for index, row in test_csv.iterrows():\n",
 | 
						|
    "        if (not row['p_MDM'] and row['c_score'] >= 0.91 and \n",
 | 
						|
    "            (row['p_thing'] != row['c_thing'] or row['p_property'] != row['c_property'])):\n",
 | 
						|
    "\n",
 | 
						|
    "            test_csv.at[index, 'p_thing'] = row['c_thing']\n",
 | 
						|
    "            test_csv.at[index, 'p_property'] = row['c_property']\n",
 | 
						|
    "            test_csv.at[index, 'p_MDM'] = True\n",
 | 
						|
    "\n",
 | 
						|
    "            updated_p_thing = test_csv.at[index, 'p_thing']\n",
 | 
						|
    "            updated_p_property = test_csv.at[index, 'p_property']\n",
 | 
						|
    "            p_pattern = re.sub(r'\\d', '#', updated_p_thing) + \" \" + re.sub(r'\\d', '#', updated_p_property)\n",
 | 
						|
    "            test_csv.at[index, 'p_pattern'] = p_pattern\n",
 | 
						|
    "            update_count += 1  # Increment the counter\n",
 | 
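						|
    "            refine_rows.append(row)  # Add the row to the refine list. NOTE(review): row is the iterrows() snapshot, not test_csv itself, so refine.csv presumably keeps the pre-update p_thing/p_property/p_MDM values - confirm this is intended\n",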
 | 
						|
    "\n",
 | 
						|
    "    # Convert the list of refine rows into a DataFrame\n",
 | 
						|
    "    refine_df = pd.DataFrame(refine_rows)\n",
 | 
						|
    "\n",
 | 
						|
    "    # Save the refine DataFrame to a CSV file for the current group\n",
 | 
						|
    "    refine_output_path = f'0.class_document/{model}/{group_number}/refine.csv'\n",
 | 
						|
    "    refine_df.to_csv(refine_output_path, index=False, encoding='utf-8-sig')\n",
 | 
						|
    "\n",
 | 
						|
    "    # Print the number of updates made\n",
 | 
						|
    "    print(f\"Number of updates made (group {group_number}): {update_count}\")\n",
 | 
						|
    "\n",
 | 
						|
    "    # Save the updated test CSV for the current group\n",
 | 
						|
    "    output_file_path = f'0.class_document/{model}/{group_number}/test_p_c_r.csv'\n",
 | 
						|
    "    test_csv.to_csv(output_file_path, index=False, encoding='utf-8-sig')\n",
 | 
						|
    "\n",
 | 
						|
    "    print(f\"Updated test CSV saved to {output_file_path}\")\n",
 | 
						|
    "    print(f\"Refine CSV saved to {refine_output_path}\")\n"
 | 
						|
   ]
 | 
						|
  }
 | 
						|
 ],
 | 
						|
 "metadata": {
 | 
						|
  "kernelspec": {
 | 
						|
   "display_name": "torch",
 | 
						|
   "language": "python",
 | 
						|
   "name": "python3"
 | 
						|
  },
 | 
						|
  "language_info": {
 | 
						|
   "codemirror_mode": {
 | 
						|
    "name": "ipython",
 | 
						|
    "version": 3
 | 
						|
   },
 | 
						|
   "file_extension": ".py",
 | 
						|
   "mimetype": "text/x-python",
 | 
						|
   "name": "python",
 | 
						|
   "nbconvert_exporter": "python",
 | 
						|
   "pygments_lexer": "ipython3",
 | 
						|
   "version": "3.10.14"
 | 
						|
  }
 | 
						|
 },
 | 
						|
 "nbformat": 4,
 | 
						|
 "nbformat_minor": 2
 | 
						|
}
 |