[TASK] the entire paper work
This commit is contained in:
parent 3d2266cf65
commit 24829c7abf
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -2,7 +2,7 @@
 "cells": [
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 1,
 "metadata": {},
 "outputs": [
 {
@@ -17,15 +17,15 @@
 "Changes made in ships_idx 1005: 18\n",
 "Changes made in ships_idx 1008: 22\n",
 "Changes made in ships_idx 1009: 5\n",
-"Changes made in ships_idx 1010: 135\n",
+"Changes made in ships_idx 1010: 131\n",
 "Changes made in ships_idx 1011: 46\n",
 "Changes made in ships_idx 1012: 2\n",
 "Changes made in ships_idx 1013: 130\n",
 "Changes made in ships_idx 1014: 46\n",
-"Changes made in ships_idx 1015: 147\n",
+"Changes made in ships_idx 1015: 145\n",
 "Changes made in ships_idx 1016: 191\n",
 "Changes made in ships_idx 1017: 111\n",
-"Changes made in ships_idx 1018: 682\n",
+"Changes made in ships_idx 1018: 680\n",
 "Changes made in ships_idx 1019: 2\n",
 "Changes made in ships_idx 1020: 10\n",
 "Changes made in ships_idx 1021: 2\n",
@@ -42,21 +42,21 @@
 "Changes made in ships_idx 1032: 225\n",
 "Changes made in ships_idx 1033: 147\n",
 "Changes made in ships_idx 1035: 132\n",
-"Changes made in ships_idx 1036: 12\n",
+"Changes made in ships_idx 1036: 5\n",
 "Changes made in ships_idx 1037: 3\n",
-"Changes made in ships_idx 1038: 8\n",
+"Changes made in ships_idx 1038: 6\n",
 "Changes made in ships_idx 1039: 232\n",
 "Changes made in ships_idx 1042: 20\n",
 "Changes made in ships_idx 1043: 154\n",
-"Changes made in ships_idx 1044: 121\n",
-"Changes made in ships_idx 1045: 255\n",
+"Changes made in ships_idx 1044: 117\n",
+"Changes made in ships_idx 1045: 243\n",
 "Changes made in ships_idx 1046: 6\n",
 "Changes made in ships_idx 1047: 12\n",
 "Changes made in ships_idx 1048: 82\n",
 "Changes made in ships_idx 1049: 912\n",
 "Changes made in ships_idx 1050: 46\n",
-"Changes made in ships_idx 1051: 63\n",
-"Total number of changes made: 4951\n",
+"Changes made in ships_idx 1051: 57\n",
+"Total number of changes made: 4912\n",
 "Updated data saved to raw_data_add_tag.csv\n"
 ]
 }
@@ -19,7 +19,7 @@
 "\n",
 "# Load the data_mapping CSV file\n",
 "data_mapping_file_path = '../../data_import/raw_data.csv' # Adjust this path to your actual file location\n",
-"# data_mapping_file_path = 'raw_data_add_tag.csv' # Adjust this path to your actual file location\n",
+"data_mapping_file_path = 'raw_data_add_tag.csv' # Adjust this path to your actual file location\n",
 "data_mapping = pd.read_csv(data_mapping_file_path, dtype=str)\n",
 "\n",
 "# Backup the original tag_description\n",
@@ -2,7 +2,7 @@
 "cells": [
 {
 "cell_type": "code",
-"execution_count": 19,
+"execution_count": 1,
 "metadata": {},
 "outputs": [
 {
@@ -10,11 +10,11 @@
 "output_type": "stream",
 "text": [
 "Final Group Allocation:\n",
-"Group 1: Ships_idx = [1003, 1028, 1049, 1044, 1020, 1041, 1045, 1036, 1005, 1006], PD type = 537, PD = 2006, SD = 14719\n",
-"Group 2: Ships_idx = [1025, 1035, 1021, 1026, 1002, 1030, 1024, 1037, 1038, 1029], PD type = 537, PD = 1958, SD = 8173\n",
-"Group 3: Ships_idx = [1016, 1046, 1031, 1009, 1048, 1043, 1042, 1019, 1018, 1007, 1000], PD type = 534, PD = 2079, SD = 15310\n",
-"Group 4: Ships_idx = [1004, 1032, 1039, 1014, 1040, 1017, 1022, 1051, 1008, 1050, 1013], PD type = 532, PD = 2066, SD = 12882\n",
-"Group 5: Ships_idx = [1047, 1015, 1027, 1010, 1011, 1001, 1034, 1023, 1012, 1033], PD type = 531, PD = 2064, SD = 10988\n"
+"Group 1: Ships_idx = [1025, 1032, 1042, 1046, 1023, 1037, 1024, 1014, 1019, 1008], PD type = 529, PD = 1992, SD = 9855\n",
+"Group 2: Ships_idx = [1003, 1028, 1018, 1020, 1033, 1050, 1030, 1051, 1004, 1036], PD type = 528, PD = 2113, SD = 13074\n",
+"Group 3: Ships_idx = [1016, 1026, 1043, 1031, 1012, 1021, 1000, 1011, 1006, 1005, 1038], PD type = 521, PD = 2140, SD = 10722\n",
+"Group 4: Ships_idx = [1047, 1049, 1010, 1027, 1013, 1022, 1048, 1017, 1045, 1007], PD type = 521, PD = 2102, SD = 15451\n",
+"Group 5: Ships_idx = [1039, 1035, 1044, 1009, 1015, 1040, 1001, 1034, 1041, 1002, 1029], PD type = 500, PD = 2183, SD = 12969\n"
 ]
 }
 ],
@@ -259,7 +259,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 20,
+"execution_count": 2,
 "metadata": {},
 "outputs": [
 {
@@ -348,7 +348,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 21,
+"execution_count": 3,
 "metadata": {},
 "outputs": [
 {
@@ -2,74 +2,118 @@
 "cells": [
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": 6,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Performance for all_with_p_s.csv:\n",
-"TP: 1724, TN: 11907, FP: 919, FN: 272\n",
-"Precision: 0.6523, Recall: 0.8637, Accuracy: 0.9196\n"
+"Performance for group 1 (test_s.csv):\n",
+"TP: 1794, TN: 9954, FP: 1005, FN: 319\n",
+"Precision: 0.6409, Recall: 0.8490, Accuracy: 0.8987, F1-Score: 0.7305\n",
+"--------------------------------------------------\n",
+"Performance for group 2 (test_s.csv):\n",
+"TP: 1824, TN: 7716, FP: 866, FN: 316\n",
+"Precision: 0.6781, Recall: 0.8523, Accuracy: 0.8898, F1-Score: 0.7553\n",
+"--------------------------------------------------\n",
+"Performance for group 3 (test_s.csv):\n",
+"TP: 1804, TN: 6866, FP: 996, FN: 188\n",
+"Precision: 0.6443, Recall: 0.9056, Accuracy: 0.8798, F1-Score: 0.7529\n",
+"--------------------------------------------------\n",
+"Performance for group 4 (test_s.csv):\n",
+"TP: 1916, TN: 12360, FP: 989, FN: 186\n",
+"Precision: 0.6596, Recall: 0.9115, Accuracy: 0.9240, F1-Score: 0.7653\n",
+"--------------------------------------------------\n"
+]
+},
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/tmp/ipykernel_2997916/1903646223.py:38: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise an error in a future version of pandas. Value '' has dtype incompatible with float64, please explicitly cast to a compatible dtype first.\n",
+" test_s_csv.fillna('', inplace=True)\n"
+]
+},
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Performance for group 5 (test_s.csv):\n",
+"TP: 1910, TN: 9800, FP: 955, FN: 273\n",
+"Precision: 0.6667, Recall: 0.8749, Accuracy: 0.9051, F1-Score: 0.7567\n",
+"--------------------------------------------------\n",
+"Average performance across all groups:\n",
+"Average Precision: 0.6579\n",
+"Average Recall: 0.8787\n",
+"Average Accuracy: 0.8995\n",
+"Average F1-Score: 0.7521\n"
 ]
 }
 ],
 "source": [
 "import pandas as pd\n",
 "\n",
-"# Set the group number\n",
-"group_number = 1 # Change this to the desired group number\n",
-"\n",
-"# File paths for the two datasets\n",
-"test_s_path = f'../post_process/0.result/{group_number}/test_s.csv'\n",
-"\n",
-"# Load the CSV files\n",
-"test_s_csv = pd.read_csv(test_s_path, low_memory=False)\n",
-"test_s_csv.fillna('', inplace=True)\n",
-"\n",
 "def evaluate_performance(test_csv):\n",
-"    # Initialize counters for TP, TN, FP, FN\n",
 "    TP = 0\n",
 "    TN = 0\n",
 "    FP = 0\n",
 "    FN = 0\n",
 "\n",
-"    # Iterate over the DataFrame rows\n",
 "    for index, row in test_csv.iterrows():\n",
-"        # True Positive (TP): s_correct is True and MDM is True\n",
 "        if row['s_correct'] and row['MDM']:\n",
 "            TP += 1\n",
-"        # True Negative (TN): s_thing is null and MDM is False\n",
 "        elif row['s_thing'] == '' and not row['MDM']:\n",
 "            TN += 1\n",
-"        # False Positive (FP): \n",
-"        # 1) s_thing is not null and MDM is False \n",
-"        # OR \n",
-"        # 2) s_thing is not null and s_correct is False and MDM is True\n",
-"        elif (row['s_thing'] != '' and not row['MDM']) or (row['s_thing'] != '' and not row['s_correct'] and row['MDM']):\n",
+"        elif (row['s_thing'] != '' and not row['MDM']):\n",
 "            FP += 1\n",
-"        # False Negative (FN): s_thing is null and MDM is True\n",
-"        elif row['s_thing'] == '' and row['MDM']:\n",
+"        elif row['s_thing'] == '' and row['MDM'] or (row['s_thing'] != '' and not row['s_correct'] and row['MDM']):\n",
 "            FN += 1\n",
 "\n",
-"    # Calculate total\n",
 "    total = TP + TN + FP + FN\n",
 "\n",
-"    # Calculate Precision, Recall, and Accuracy\n",
 "    precision = TP / (TP + FP) if (TP + FP) > 0 else 0\n",
 "    recall = TP / (TP + FN) if (TP + FN) > 0 else 0\n",
 "    accuracy = (TP + TN) / total if total > 0 else 0\n",
+"    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n",
 "\n",
-"    return TP, TN, FP, FN, precision, recall, accuracy\n",
+"    return TP, TN, FP, FN, precision, recall, accuracy, f1_score\n",
 "\n",
-"# Evaluate both datasets\n",
-"tp_s_results = evaluate_performance(test_s_csv)\n",
+"# Lists to store performance metrics for all folds\n",
+"all_precisions = []\n",
+"all_recalls = []\n",
+"all_accuracies = []\n",
+"all_f1_scores = []\n",
 "\n",
-"# Print the results for both datasets\n",
-"print(\"Performance for all_with_p_s.csv:\")\n",
-"print(f\"TP: {tp_s_results[0]}, TN: {tp_s_results[1]}, FP: {tp_s_results[2]}, FN: {tp_s_results[3]}\")\n",
-"print(f\"Precision: {tp_s_results[4]:.4f}, Recall: {tp_s_results[5]:.4f}, Accuracy: {tp_s_results[6]:.4f}\")"
+"# Perform evaluation for group 1 to 5\n",
+"for group_number in range(1, 6):\n",
+"    test_s_path = f'../post_process/0.result/{group_number}/test_s.csv'\n",
+"    test_s_csv = pd.read_csv(test_s_path, low_memory=False)\n",
+"    test_s_csv.fillna('', inplace=True)\n",
+"\n",
+"    tp_s_results = evaluate_performance(test_s_csv)\n",
+"\n",
+"    print(f\"Performance for group {group_number} (test_s.csv):\")\n",
+"    print(f\"TP: {tp_s_results[0]}, TN: {tp_s_results[1]}, FP: {tp_s_results[2]}, FN: {tp_s_results[3]}\")\n",
+"    print(f\"Precision: {tp_s_results[4]:.4f}, Recall: {tp_s_results[5]:.4f}, Accuracy: {tp_s_results[6]:.4f}, F1-Score: {tp_s_results[7]:.4f}\")\n",
+"    print(\"-\" * 50)\n",
+"\n",
+"    all_precisions.append(tp_s_results[4])\n",
+"    all_recalls.append(tp_s_results[5])\n",
+"    all_accuracies.append(tp_s_results[6])\n",
+"    all_f1_scores.append(tp_s_results[7])\n",
+"\n",
+"# Calculate and print the averages across all groups\n",
+"average_precision = sum(all_precisions) / len(all_precisions)\n",
+"average_recall = sum(all_recalls) / len(all_recalls)\n",
+"average_accuracy = sum(all_accuracies) / len(all_accuracies)\n",
+"average_f1_score = sum(all_f1_scores) / len(all_f1_scores)\n",
+"\n",
+"print(\"Average performance across all groups:\")\n",
+"print(f\"Average Precision: {average_precision:.4f}\")\n",
+"print(f\"Average Recall: {average_recall:.4f}\")\n",
+"print(f\"Average Accuracy: {average_accuracy:.4f}\")\n",
+"print(f\"Average F1-Score: {average_f1_score:.4f}\")\n"
 ]
 }
 ],
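Note: as a quick sanity check on the metric definitions used in the updated cell above, the group 1 figures can be reproduced directly from the TP/TN/FP/FN counts printed in the output; the short sketch below assumes only those counts. The FutureWarning in the stderr output comes from filling float64 columns with an empty string; casting the frame to object dtype first (e.g. test_s_csv = test_s_csv.astype('object').fillna('')) is one way to avoid it.

# Sanity check: recompute the group 1 metrics from the printed counts
# (TP=1794, TN=9954, FP=1005, FN=319, taken from the cell output above).
TP, TN, FP, FN = 1794, 9954, 1005, 319

precision = TP / (TP + FP)                          # 1794 / 2799  -> 0.6409
recall = TP / (TP + FN)                             # 1794 / 2113  -> 0.8490
accuracy = (TP + TN) / (TP + TN + FP + FN)          # 11748 / 13072 -> 0.8987
f1 = 2 * precision * recall / (precision + recall)  # -> 0.7305

print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, "
      f"Accuracy: {accuracy:.4f}, F1-Score: {f1:.4f}")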
@ -0,0 +1,341 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 35,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at albert-base-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
||||||
|
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/optimization.py:521: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
||||||
|
" warnings.warn(\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 1 completed. Loss: 5.564172267913818\n",
|
||||||
|
"Epoch 2 completed. Loss: 4.88321590423584\n",
|
||||||
|
"Epoch 3 completed. Loss: 3.5059947967529297\n",
|
||||||
|
"Epoch 4 completed. Loss: 3.18548583984375\n",
|
||||||
|
"Epoch 5 completed. Loss: 2.8037068843841553\n",
|
||||||
|
"Epoch 6 completed. Loss: 2.2223541736602783\n",
|
||||||
|
"Epoch 7 completed. Loss: 1.8634291887283325\n",
|
||||||
|
"Epoch 8 completed. Loss: 1.3251842260360718\n",
|
||||||
|
"Epoch 9 completed. Loss: 0.6083177328109741\n",
|
||||||
|
"Epoch 10 completed. Loss: 0.9423710703849792\n",
|
||||||
|
"Epoch 11 completed. Loss: 0.5799884796142578\n",
|
||||||
|
"Epoch 12 completed. Loss: 0.6948736310005188\n",
|
||||||
|
"Epoch 13 completed. Loss: 0.5177479386329651\n",
|
||||||
|
"Epoch 14 completed. Loss: 0.47343072295188904\n",
|
||||||
|
"Epoch 15 completed. Loss: 0.26853761076927185\n",
|
||||||
|
"Epoch 16 completed. Loss: 0.19693760573863983\n",
|
||||||
|
"Epoch 17 completed. Loss: 0.3199688494205475\n",
|
||||||
|
"Epoch 18 completed. Loss: 0.23672448098659515\n",
|
||||||
|
"Epoch 19 completed. Loss: 0.40235987305641174\n",
|
||||||
|
"Epoch 20 completed. Loss: 0.28102293610572815\n",
|
||||||
|
"Epoch 21 completed. Loss: 0.17575399577617645\n",
|
||||||
|
"Epoch 22 completed. Loss: 0.24652625620365143\n",
|
||||||
|
"Epoch 23 completed. Loss: 0.109055295586586\n",
|
||||||
|
"Epoch 24 completed. Loss: 0.19015412032604218\n",
|
||||||
|
"Epoch 25 completed. Loss: 0.10130400210618973\n",
|
||||||
|
"Epoch 26 completed. Loss: 0.14203257858753204\n",
|
||||||
|
"Epoch 27 completed. Loss: 0.1248723715543747\n",
|
||||||
|
"Epoch 28 completed. Loss: 0.05851107835769653\n",
|
||||||
|
"Epoch 29 completed. Loss: 0.041425254195928574\n",
|
||||||
|
"Epoch 30 completed. Loss: 0.0353962741792202\n",
|
||||||
|
"Epoch 31 completed. Loss: 0.04445452615618706\n",
|
||||||
|
"Epoch 32 completed. Loss: 0.026403019204735756\n",
|
||||||
|
"Epoch 33 completed. Loss: 0.028079884126782417\n",
|
||||||
|
"Epoch 34 completed. Loss: 0.059587348252534866\n",
|
||||||
|
"Epoch 35 completed. Loss: 0.02851276472210884\n",
|
||||||
|
"Epoch 36 completed. Loss: 0.09271513670682907\n",
|
||||||
|
"Epoch 37 completed. Loss: 0.06418397277593613\n",
|
||||||
|
"Epoch 38 completed. Loss: 0.03638231381773949\n",
|
||||||
|
"Epoch 39 completed. Loss: 0.022959664463996887\n",
|
||||||
|
"Epoch 40 completed. Loss: 0.044602662324905396\n",
|
||||||
|
"Epoch 41 completed. Loss: 0.03491249307990074\n",
|
||||||
|
"Epoch 42 completed. Loss: 0.039797600358724594\n",
|
||||||
|
"Epoch 43 completed. Loss: 0.04217083007097244\n",
|
||||||
|
"Epoch 44 completed. Loss: 0.4122176170349121\n",
|
||||||
|
"Epoch 45 completed. Loss: 0.1664775162935257\n",
|
||||||
|
"Epoch 46 completed. Loss: 0.04505300521850586\n",
|
||||||
|
"Epoch 47 completed. Loss: 0.14913827180862427\n",
|
||||||
|
"Epoch 48 completed. Loss: 0.016096509993076324\n",
|
||||||
|
"Epoch 49 completed. Loss: 0.05338064581155777\n",
|
||||||
|
"Epoch 50 completed. Loss: 0.10259533673524857\n",
|
||||||
|
"Epoch 51 completed. Loss: 0.008849691599607468\n",
|
||||||
|
"Epoch 52 completed. Loss: 0.028069255873560905\n",
|
||||||
|
"Epoch 53 completed. Loss: 0.008924427442252636\n",
|
||||||
|
"Epoch 54 completed. Loss: 0.015527592971920967\n",
|
||||||
|
"Epoch 55 completed. Loss: 0.009189464151859283\n",
|
||||||
|
"Epoch 56 completed. Loss: 0.007252057082951069\n",
|
||||||
|
"Epoch 57 completed. Loss: 0.01684846170246601\n",
|
||||||
|
"Epoch 58 completed. Loss: 0.010840333066880703\n",
|
||||||
|
"Epoch 59 completed. Loss: 0.05179211124777794\n",
|
||||||
|
"Epoch 60 completed. Loss: 0.007003726437687874\n",
|
||||||
|
"Epoch 61 completed. Loss: 0.00555015355348587\n",
|
||||||
|
"Epoch 62 completed. Loss: 0.0065276664681732655\n",
|
||||||
|
"Epoch 63 completed. Loss: 0.007942711934447289\n",
|
||||||
|
"Epoch 64 completed. Loss: 0.00675524678081274\n",
|
||||||
|
"Epoch 65 completed. Loss: 0.010359193198382854\n",
|
||||||
|
"Epoch 66 completed. Loss: 0.00662408908829093\n",
|
||||||
|
"Epoch 67 completed. Loss: 0.007672889623790979\n",
|
||||||
|
"Epoch 68 completed. Loss: 0.004661311395466328\n",
|
||||||
|
"Epoch 69 completed. Loss: 0.014480670914053917\n",
|
||||||
|
"Epoch 70 completed. Loss: 0.05042335391044617\n",
|
||||||
|
"Epoch 71 completed. Loss: 0.035947512835264206\n",
|
||||||
|
"Epoch 72 completed. Loss: 0.01213429868221283\n",
|
||||||
|
"Epoch 73 completed. Loss: 0.033572785556316376\n",
|
||||||
|
"Epoch 74 completed. Loss: 0.009208262898027897\n",
|
||||||
|
"Epoch 75 completed. Loss: 0.08961852639913559\n",
|
||||||
|
"Epoch 76 completed. Loss: 4.632999897003174\n",
|
||||||
|
"Epoch 77 completed. Loss: 5.957398891448975\n",
|
||||||
|
"Epoch 78 completed. Loss: 5.970841407775879\n",
|
||||||
|
"Epoch 79 completed. Loss: 5.905709266662598\n",
|
||||||
|
"Epoch 80 completed. Loss: 5.864459037780762\n",
|
||||||
|
"Validation Accuracy: 0.14%\n",
|
||||||
|
"Accuracy (MDM=True) for Group 4: 0.48%\n",
|
||||||
|
"Results saved to 0.class_document/albert/4/test_p_c.csv\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from transformers import AlbertTokenizer, AlbertForSequenceClassification, AdamW\n",
|
||||||
|
"from sklearn.preprocessing import LabelEncoder\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from torch.utils.data import Dataset, DataLoader\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import torch.nn.functional as F\n",
|
||||||
|
"import os \n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"group_number = 4\n",
|
||||||
|
"train_path = f'../../data_preprocess/dataset/{group_number}/train.csv'\n",
|
||||||
|
"valid_path = f'../../data_preprocess/dataset/{group_number}/valid.csv'\n",
|
||||||
|
"test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
|
||||||
|
"output_path = f'0.class_document/albert/{group_number}/test_p_c.csv'\n",
|
||||||
|
"\n",
|
||||||
|
"train_data = pd.read_csv(train_path)\n",
|
||||||
|
"valid_data = pd.read_csv(valid_path)\n",
|
||||||
|
"test_data = pd.read_csv(test_path)\n",
|
||||||
|
"\n",
|
||||||
|
"train_data['thing_property'] = train_data['thing'] + '_' + train_data['property']\n",
|
||||||
|
"valid_data['thing_property'] = valid_data['thing'] + '_' + valid_data['property']\n",
|
||||||
|
"test_data['thing_property'] = test_data['thing'] + '_' + test_data['property']\n",
|
||||||
|
"\n",
|
||||||
|
"tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')\n",
|
||||||
|
"label_encoder = LabelEncoder()\n",
|
||||||
|
"label_encoder.fit(train_data['thing_property'])\n",
|
||||||
|
"\n",
|
||||||
|
"valid_data['thing_property'] = valid_data['thing_property'].apply(\n",
|
||||||
|
" lambda x: x if x in label_encoder.classes_ else 'unknown_label')\n",
|
||||||
|
"test_data['thing_property'] = test_data['thing_property'].apply(\n",
|
||||||
|
" lambda x: x if x in label_encoder.classes_ else 'unknown_label')\n",
|
||||||
|
"\n",
|
||||||
|
"label_encoder.classes_ = np.append(label_encoder.classes_, 'unknown_label')\n",
|
||||||
|
"\n",
|
||||||
|
"train_data['label'] = label_encoder.transform(train_data['thing_property'])\n",
|
||||||
|
"valid_data['label'] = label_encoder.transform(valid_data['thing_property'])\n",
|
||||||
|
"test_data['label'] = label_encoder.transform(test_data['thing_property'])\n",
|
||||||
|
"\n",
|
||||||
|
"train_texts, train_labels = train_data['tag_description'], train_data['label']\n",
|
||||||
|
"valid_texts, valid_labels = valid_data['tag_description'], valid_data['label']\n",
|
||||||
|
"\n",
|
||||||
|
"train_encodings = tokenizer(list(train_texts), truncation=True, padding=True, return_tensors='pt')\n",
|
||||||
|
"valid_encodings = tokenizer(list(valid_texts), truncation=True, padding=True, return_tensors='pt')\n",
|
||||||
|
"\n",
|
||||||
|
"train_labels = torch.tensor(train_labels.values)\n",
|
||||||
|
"valid_labels = torch.tensor(valid_labels.values)\n",
|
||||||
|
"\n",
|
||||||
|
"class CustomDataset(Dataset):\n",
|
||||||
|
" def __init__(self, encodings, labels):\n",
|
||||||
|
" self.encodings = encodings\n",
|
||||||
|
" self.labels = labels\n",
|
||||||
|
"\n",
|
||||||
|
" def __getitem__(self, idx):\n",
|
||||||
|
" item = {key: val[idx] for key, val in self.encodings.items()}\n",
|
||||||
|
" item['labels'] = self.labels[idx]\n",
|
||||||
|
" return item\n",
|
||||||
|
"\n",
|
||||||
|
" def __len__(self):\n",
|
||||||
|
" return len(self.labels)\n",
|
||||||
|
"\n",
|
||||||
|
"train_dataset = CustomDataset(train_encodings, train_labels)\n",
|
||||||
|
"valid_dataset = CustomDataset(valid_encodings, valid_labels)\n",
|
||||||
|
"\n",
|
||||||
|
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
|
||||||
|
"valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)\n",
|
||||||
|
"\n",
|
||||||
|
"model = AlbertForSequenceClassification.from_pretrained('albert-base-v2', num_labels=len(train_data['thing_property'].unique()))\n",
|
||||||
|
"optimizer = AdamW(model.parameters(), lr=5e-5)\n",
|
||||||
|
"\n",
|
||||||
|
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
||||||
|
"model.to(device)\n",
|
||||||
|
"\n",
|
||||||
|
"epochs = 80\n",
|
||||||
|
"for epoch in range(epochs):\n",
|
||||||
|
" model.train()\n",
|
||||||
|
" for batch in train_loader:\n",
|
||||||
|
" optimizer.zero_grad()\n",
|
||||||
|
" input_ids = batch['input_ids'].to(device)\n",
|
||||||
|
" attention_mask = batch['attention_mask'].to(device)\n",
|
||||||
|
" labels = batch['labels'].to(device)\n",
|
||||||
|
" outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
|
||||||
|
" loss = outputs.loss\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" optimizer.step()\n",
|
||||||
|
" print(f\"Epoch {epoch + 1} completed. Loss: {loss.item()}\")\n",
|
||||||
|
"\n",
|
||||||
|
"model.eval()\n",
|
||||||
|
"correct, total = 0, 0\n",
|
||||||
|
"\n",
|
||||||
|
"with torch.no_grad():\n",
|
||||||
|
" for batch in valid_loader:\n",
|
||||||
|
" input_ids, attention_mask, labels = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['labels'].to(device)\n",
|
||||||
|
" outputs = model(input_ids, attention_mask=attention_mask)\n",
|
||||||
|
" predictions = torch.argmax(outputs.logits, dim=-1)\n",
|
||||||
|
" correct += (predictions == labels).sum().item()\n",
|
||||||
|
" total += labels.size(0)\n",
|
||||||
|
"\n",
|
||||||
|
"valid_accuracy = correct / total\n",
|
||||||
|
"print(f'Validation Accuracy: {valid_accuracy * 100:.2f}%')\n",
|
||||||
|
"\n",
|
||||||
|
"# Test 데이터 예측 및 c_thing, c_property 추가\n",
|
||||||
|
"test_encodings = tokenizer(list(test_data['tag_description']), truncation=True, padding=True, return_tensors='pt')\n",
|
||||||
|
"test_dataset = CustomDataset(test_encodings, torch.zeros(len(test_data))) # 레이블은 사용되지 않으므로 임시로 0을 사용\n",
|
||||||
|
"\n",
|
||||||
|
"test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n",
|
||||||
|
"\n",
|
||||||
|
"model.eval()\n",
|
||||||
|
"predicted_thing_properties = []\n",
|
||||||
|
"predicted_scores = []\n",
|
||||||
|
"\n",
|
||||||
|
"with torch.no_grad():\n",
|
||||||
|
" for batch in test_loader:\n",
|
||||||
|
" input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)\n",
|
||||||
|
" outputs = model(input_ids, attention_mask=attention_mask)\n",
|
||||||
|
" softmax_scores = F.softmax(outputs.logits, dim=-1)\n",
|
||||||
|
" predictions = torch.argmax(softmax_scores, dim=-1)\n",
|
||||||
|
" predicted_thing_properties.extend(predictions.cpu().numpy())\n",
|
||||||
|
" predicted_scores.extend(softmax_scores[range(len(predictions)), predictions].cpu().numpy())\n",
|
||||||
|
"\n",
|
||||||
|
"# 예측된 thing_property를 레이블 인코더로 디코딩\n",
|
||||||
|
"predicted_thing_property_labels = label_encoder.inverse_transform(predicted_thing_properties)\n",
|
||||||
|
"\n",
|
||||||
|
"# thing_property를 thing과 property로 나눔\n",
|
||||||
|
"test_data['c_thing'] = [x.split('_')[0] for x in predicted_thing_property_labels]\n",
|
||||||
|
"test_data['c_property'] = [x.split('_')[1] for x in predicted_thing_property_labels]\n",
|
||||||
|
"test_data['c_score'] = predicted_scores\n",
|
||||||
|
"\n",
|
||||||
|
"test_data['cthing_correct'] = test_data['thing'] == test_data['c_thing']\n",
|
||||||
|
"test_data['cproperty_correct'] = test_data['property'] == test_data['c_property']\n",
|
||||||
|
"test_data['ctp_correct'] = test_data['cthing_correct'] & test_data['cproperty_correct']\n",
|
||||||
|
"\n",
|
||||||
|
"mdm_true_count = len(test_data[test_data['MDM'] == True])\n",
|
||||||
|
"accuracy = (test_data['ctp_correct'].sum() / mdm_true_count) * 100\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Accuracy (MDM=True) for Group {group_number}: {accuracy:.2f}%\")\n",
|
||||||
|
"\n",
|
||||||
|
"os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
|
||||||
|
"\n",
|
||||||
|
"test_data.to_csv(output_path, index=False)\n",
|
||||||
|
"print(f'Results saved to {output_path}')\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 29,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"ename": "AttributeError",
|
||||||
|
"evalue": "'AlbertForSequenceClassification' object has no attribute 'bert'",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
|
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
||||||
|
"Cell \u001b[0;32mIn[29], line 20\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\u001b[38;5;241m.\u001b[39mlast_hidden_state\u001b[38;5;241m.\u001b[39mmean(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy() \u001b[38;5;66;03m# 각 문장의 평균 임베딩 추출\u001b[39;00m\n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# BERT 모델로 임베딩 계산\u001b[39;00m\n\u001b[0;32m---> 20\u001b[0m bert_embeddings \u001b[38;5;241m=\u001b[39m \u001b[43mget_bert_embeddings\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfiltered_encodings\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;66;03m# t-SNE 차원 축소\u001b[39;00m\n\u001b[1;32m 23\u001b[0m tsne \u001b[38;5;241m=\u001b[39m TSNE(n_components\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m, random_state\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m42\u001b[39m)\n",
|
||||||
|
"Cell \u001b[0;32mIn[29], line 16\u001b[0m, in \u001b[0;36mget_bert_embeddings\u001b[0;34m(model, encodings, device)\u001b[0m\n\u001b[1;32m 14\u001b[0m input_ids \u001b[38;5;241m=\u001b[39m encodings[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 15\u001b[0m attention_mask \u001b[38;5;241m=\u001b[39m encodings[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m---> 16\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbert\u001b[49m(input_ids\u001b[38;5;241m=\u001b[39minput_ids, attention_mask\u001b[38;5;241m=\u001b[39mattention_mask)\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\u001b[38;5;241m.\u001b[39mlast_hidden_state\u001b[38;5;241m.\u001b[39mmean(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy()\n",
|
||||||
|
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/modules/module.py:1709\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1707\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1708\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1709\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
||||||
|
"\u001b[0;31mAttributeError\u001b[0m: 'AlbertForSequenceClassification' object has no attribute 'bert'"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from sklearn.manifold import TSNE\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"\n",
|
||||||
|
"# 'filtered_data_plot.csv' 읽기\n",
|
||||||
|
"filtered_data = pd.read_csv('filtered_data_plot.csv')\n",
|
||||||
|
"\n",
|
||||||
|
"# 데이터 토큰화\n",
|
||||||
|
"filtered_encodings = tokenizer(list(filtered_data['tag_description']), truncation=True, padding=True, return_tensors='pt')\n",
|
||||||
|
"\n",
|
||||||
|
"# BERT 임베딩 계산 함수\n",
|
||||||
|
"def get_bert_embeddings(model, encodings, device):\n",
|
||||||
|
" model.eval()\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" input_ids = encodings['input_ids'].to(device)\n",
|
||||||
|
" attention_mask = encodings['attention_mask'].to(device)\n",
|
||||||
|
" outputs = model.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
|
||||||
|
" return outputs.last_hidden_state.mean(dim=1).cpu().numpy() # 각 문장의 평균 임베딩 추출\n",
|
||||||
|
"\n",
|
||||||
|
"# BERT 모델로 임베딩 계산\n",
|
||||||
|
"bert_embeddings = get_bert_embeddings(model, filtered_encodings, device)\n",
|
||||||
|
"\n",
|
||||||
|
"# t-SNE 차원 축소\n",
|
||||||
|
"tsne = TSNE(n_components=2, random_state=42)\n",
|
||||||
|
"tsne_results = tsne.fit_transform(bert_embeddings)\n",
|
||||||
|
"\n",
|
||||||
|
"# 시각화를 위한 준비\n",
|
||||||
|
"unique_patterns = filtered_data['pattern'].unique()\n",
|
||||||
|
"color_map = plt.get_cmap('tab20', len(unique_patterns))\n",
|
||||||
|
"pattern_to_color = {pattern: idx for idx, pattern in enumerate(unique_patterns)}\n",
|
||||||
|
"\n",
|
||||||
|
"plt.figure(figsize=(14, 7))\n",
|
||||||
|
"\n",
|
||||||
|
"# 각 패턴별로 시각화\n",
|
||||||
|
"for pattern, color_idx in pattern_to_color.items():\n",
|
||||||
|
" pattern_indices = filtered_data['pattern'] == pattern\n",
|
||||||
|
" plt.scatter(tsne_results[pattern_indices, 0], tsne_results[pattern_indices, 1], \n",
|
||||||
|
" color=color_map(color_idx), marker='o', s=100, alpha=0.6, edgecolor='k', linewidth=1.2)\n",
|
||||||
|
"\n",
|
||||||
|
"# 그래프 설정\n",
|
||||||
|
"plt.xticks(fontsize=24)\n",
|
||||||
|
"plt.yticks(fontsize=24)\n",
|
||||||
|
"plt.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.6)\n",
|
||||||
|
"plt.tight_layout()\n",
|
||||||
|
"plt.show()\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "torch",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.14"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
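Note: the AttributeError captured in the last cell of this notebook is raised because the embedding helper calls model.bert, while AlbertForSequenceClassification exposes its underlying encoder as model.albert. A minimal sketch of the helper using that attribute (same mean-pooling as in the cell; the function name is illustrative):

import torch

def get_albert_embeddings(model, encodings, device):
    # Mean-pool the last hidden states of the ALBERT encoder
    model.eval()
    with torch.no_grad():
        input_ids = encodings['input_ids'].to(device)
        attention_mask = encodings['attention_mask'].to(device)
        outputs = model.albert(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state.mean(dim=1).cpu().numpy()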
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,437 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
|
||||||
|
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/optimization.py:521: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
||||||
|
" warnings.warn(\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 1 completed. Loss: 5.446770191192627\n",
|
||||||
|
"Validation Accuracy after Epoch 1: 18.30%\n",
|
||||||
|
"Epoch 2 completed. Loss: 3.8084073066711426\n",
|
||||||
|
"Validation Accuracy after Epoch 2: 40.87%\n",
|
||||||
|
"Epoch 3 completed. Loss: 3.0630860328674316\n",
|
||||||
|
"Validation Accuracy after Epoch 3: 65.36%\n",
|
||||||
|
"Epoch 4 completed. Loss: 1.5352345705032349\n",
|
||||||
|
"Validation Accuracy after Epoch 4: 73.26%\n",
|
||||||
|
"Epoch 5 completed. Loss: 0.8989766836166382\n",
|
||||||
|
"Validation Accuracy after Epoch 5: 78.01%\n",
|
||||||
|
"Epoch 6 completed. Loss: 0.9589817523956299\n",
|
||||||
|
"Validation Accuracy after Epoch 6: 81.65%\n",
|
||||||
|
"Epoch 7 completed. Loss: 0.7892795205116272\n",
|
||||||
|
"Validation Accuracy after Epoch 7: 83.85%\n",
|
||||||
|
"Epoch 8 completed. Loss: 0.5069147944450378\n",
|
||||||
|
"Validation Accuracy after Epoch 8: 86.97%\n",
|
||||||
|
"Epoch 9 completed. Loss: 0.524911642074585\n",
|
||||||
|
"Validation Accuracy after Epoch 9: 88.12%\n",
|
||||||
|
"Epoch 10 completed. Loss: 0.2070937305688858\n",
|
||||||
|
"Validation Accuracy after Epoch 10: 89.94%\n",
|
||||||
|
"Epoch 11 completed. Loss: 0.19738677144050598\n",
|
||||||
|
"Validation Accuracy after Epoch 11: 90.75%\n",
|
||||||
|
"Epoch 12 completed. Loss: 0.13339389860630035\n",
|
||||||
|
"Validation Accuracy after Epoch 12: 91.90%\n",
|
||||||
|
"Epoch 13 completed. Loss: 0.21022899448871613\n",
|
||||||
|
"Validation Accuracy after Epoch 13: 92.86%\n",
|
||||||
|
"Epoch 14 completed. Loss: 0.26752030849456787\n",
|
||||||
|
"Validation Accuracy after Epoch 14: 93.24%\n",
|
||||||
|
"Epoch 15 completed. Loss: 0.14866866171360016\n",
|
||||||
|
"Validation Accuracy after Epoch 15: 93.68%\n",
|
||||||
|
"Epoch 16 completed. Loss: 0.08989054709672928\n",
|
||||||
|
"Validation Accuracy after Epoch 16: 94.06%\n",
|
||||||
|
"Epoch 17 completed. Loss: 0.037873975932598114\n",
|
||||||
|
"Validation Accuracy after Epoch 17: 94.59%\n",
|
||||||
|
"Epoch 18 completed. Loss: 0.07367080450057983\n",
|
||||||
|
"Validation Accuracy after Epoch 18: 94.68%\n",
|
||||||
|
"Epoch 19 completed. Loss: 0.04101959988474846\n",
|
||||||
|
"Validation Accuracy after Epoch 19: 94.83%\n",
|
||||||
|
"Epoch 20 completed. Loss: 0.21339105069637299\n",
|
||||||
|
"Validation Accuracy after Epoch 20: 95.02%\n",
|
||||||
|
"Epoch 21 completed. Loss: 0.06965143978595734\n",
|
||||||
|
"Validation Accuracy after Epoch 21: 94.97%\n",
|
||||||
|
"Epoch 22 completed. Loss: 0.06043635308742523\n",
|
||||||
|
"Validation Accuracy after Epoch 22: 95.02%\n",
|
||||||
|
"Epoch 23 completed. Loss: 0.021217377856373787\n",
|
||||||
|
"Validation Accuracy after Epoch 23: 94.92%\n",
|
||||||
|
"Epoch 24 completed. Loss: 0.037467293441295624\n",
|
||||||
|
"Validation Accuracy after Epoch 24: 95.02%\n",
|
||||||
|
"Epoch 25 completed. Loss: 0.016836028546094894\n",
|
||||||
|
"Validation Accuracy after Epoch 25: 95.02%\n",
|
||||||
|
"Epoch 26 completed. Loss: 0.028664518147706985\n",
|
||||||
|
"Validation Accuracy after Epoch 26: 95.11%\n",
|
||||||
|
"Epoch 27 completed. Loss: 0.011028420180082321\n",
|
||||||
|
"Validation Accuracy after Epoch 27: 95.16%\n",
|
||||||
|
"Epoch 28 completed. Loss: 0.04282907024025917\n",
|
||||||
|
"Validation Accuracy after Epoch 28: 95.16%\n",
|
||||||
|
"Epoch 29 completed. Loss: 0.00940023921430111\n",
|
||||||
|
"Validation Accuracy after Epoch 29: 95.35%\n",
|
||||||
|
"Epoch 30 completed. Loss: 0.13019809126853943\n",
|
||||||
|
"Validation Accuracy after Epoch 30: 95.35%\n",
|
||||||
|
"Epoch 31 completed. Loss: 0.01270432397723198\n",
|
||||||
|
"Validation Accuracy after Epoch 31: 95.11%\n",
|
||||||
|
"Epoch 32 completed. Loss: 0.012832771986722946\n",
|
||||||
|
"Validation Accuracy after Epoch 32: 95.16%\n",
|
||||||
|
"Epoch 33 completed. Loss: 0.012174545787274837\n",
|
||||||
|
"Validation Accuracy after Epoch 33: 95.16%\n",
|
||||||
|
"Epoch 34 completed. Loss: 0.02090534381568432\n",
|
||||||
|
"Validation Accuracy after Epoch 34: 95.02%\n",
|
||||||
|
"Epoch 35 completed. Loss: 0.017653826624155045\n",
|
||||||
|
"Validation Accuracy after Epoch 35: 94.49%\n",
|
||||||
|
"Epoch 36 completed. Loss: 0.02190311811864376\n",
|
||||||
|
"Validation Accuracy after Epoch 36: 94.59%\n",
|
||||||
|
"Epoch 37 completed. Loss: 0.048320867121219635\n",
|
||||||
|
"Validation Accuracy after Epoch 37: 94.68%\n",
|
||||||
|
"Epoch 38 completed. Loss: 0.015598177909851074\n",
|
||||||
|
"Validation Accuracy after Epoch 38: 95.30%\n",
|
||||||
|
"Epoch 39 completed. Loss: 0.009368035942316055\n",
|
||||||
|
"Validation Accuracy after Epoch 39: 94.83%\n",
|
||||||
|
"Epoch 40 completed. Loss: 0.009023590944707394\n",
|
||||||
|
"Validation Accuracy after Epoch 40: 95.02%\n",
|
||||||
|
"Epoch 41 completed. Loss: 0.040157418698072433\n",
|
||||||
|
"Validation Accuracy after Epoch 41: 95.11%\n",
|
||||||
|
"Epoch 42 completed. Loss: 0.11878462135791779\n",
|
||||||
|
"Validation Accuracy after Epoch 42: 95.06%\n",
|
||||||
|
"Epoch 43 completed. Loss: 0.021250683814287186\n",
|
||||||
|
"Validation Accuracy after Epoch 43: 95.16%\n",
|
||||||
|
"Epoch 44 completed. Loss: 0.0023518940433859825\n",
|
||||||
|
"Validation Accuracy after Epoch 44: 95.16%\n",
|
||||||
|
"Epoch 45 completed. Loss: 0.00595875782892108\n",
|
||||||
|
"Validation Accuracy after Epoch 45: 95.16%\n",
|
||||||
|
"Epoch 46 completed. Loss: 0.0025296895764768124\n",
|
||||||
|
"Validation Accuracy after Epoch 46: 94.97%\n",
|
||||||
|
"Epoch 47 completed. Loss: 0.0753568485379219\n",
|
||||||
|
"Validation Accuracy after Epoch 47: 95.26%\n",
|
||||||
|
"Epoch 48 completed. Loss: 0.002112493384629488\n",
|
||||||
|
"Validation Accuracy after Epoch 48: 95.06%\n",
|
||||||
|
"Epoch 49 completed. Loss: 0.09600060433149338\n",
|
||||||
|
"Validation Accuracy after Epoch 49: 95.06%\n",
|
||||||
|
"Epoch 50 completed. Loss: 0.002454130444675684\n",
|
||||||
|
"Validation Accuracy after Epoch 50: 95.21%\n",
|
||||||
|
"Accuracy (MDM=True) for Group 5: 91.98%\n",
|
||||||
|
"Results saved to 0.class_document/distilbert/5/test_p_c.csv\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, AdamW\n",
|
||||||
|
"from sklearn.preprocessing import LabelEncoder\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from torch.utils.data import Dataset, DataLoader\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import torch.nn.functional as F\n",
|
||||||
|
"import os \n",
|
||||||
|
"\n",
|
||||||
|
"group_number = 5\n",
|
||||||
|
"train_path = f'../../data_preprocess/dataset/{group_number}/train.csv'\n",
|
||||||
|
"valid_path = f'../../data_preprocess/dataset/{group_number}/valid.csv'\n",
|
||||||
|
"test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
|
||||||
|
"output_path = f'0.class_document/distilbert/{group_number}/test_p_c.csv'\n",
|
||||||
|
"\n",
|
||||||
|
"train_data = pd.read_csv(train_path)\n",
|
||||||
|
"valid_data = pd.read_csv(valid_path)\n",
|
||||||
|
"test_data = pd.read_csv(test_path)\n",
|
||||||
|
"\n",
|
||||||
|
"train_data['thing_property'] = train_data['thing'] + '_' + train_data['property']\n",
|
||||||
|
"valid_data['thing_property'] = valid_data['thing'] + '_' + valid_data['property']\n",
|
||||||
|
"test_data['thing_property'] = test_data['thing'] + '_' + test_data['property']\n",
|
||||||
|
"\n",
|
||||||
|
"tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')\n",
|
||||||
|
"label_encoder = LabelEncoder()\n",
|
||||||
|
"label_encoder.fit(train_data['thing_property'])\n",
|
||||||
|
"\n",
|
||||||
|
"valid_data['thing_property'] = valid_data['thing_property'].apply(\n",
|
||||||
|
" lambda x: x if x in label_encoder.classes_ else 'unknown_label')\n",
|
||||||
|
"test_data['thing_property'] = test_data['thing_property'].apply(\n",
|
||||||
|
" lambda x: x if x in label_encoder.classes_ else 'unknown_label')\n",
|
||||||
|
"\n",
|
||||||
|
"label_encoder.classes_ = np.append(label_encoder.classes_, 'unknown_label')\n",
|
||||||
|
"\n",
|
||||||
|
"train_data['label'] = label_encoder.transform(train_data['thing_property'])\n",
|
||||||
|
"valid_data['label'] = label_encoder.transform(valid_data['thing_property'])\n",
|
||||||
|
"test_data['label'] = label_encoder.transform(test_data['thing_property'])\n",
|
||||||
|
"\n",
|
||||||
|
"train_texts, train_labels = train_data['tag_description'], train_data['label']\n",
|
||||||
|
"valid_texts, valid_labels = valid_data['tag_description'], valid_data['label']\n",
|
||||||
|
"\n",
|
||||||
|
"train_encodings = tokenizer(list(train_texts), truncation=True, padding=True, return_tensors='pt')\n",
|
||||||
|
"valid_encodings = tokenizer(list(valid_texts), truncation=True, padding=True, return_tensors='pt')\n",
|
||||||
|
"\n",
|
||||||
|
"train_labels = torch.tensor(train_labels.values)\n",
|
||||||
|
"valid_labels = torch.tensor(valid_labels.values)\n",
|
||||||
|
"\n",
|
||||||
|
"class CustomDataset(Dataset):\n",
|
||||||
|
" def __init__(self, encodings, labels):\n",
|
||||||
|
" self.encodings = encodings\n",
|
||||||
|
" self.labels = labels\n",
|
||||||
|
"\n",
|
||||||
|
" def __getitem__(self, idx):\n",
|
||||||
|
" item = {key: val[idx] for key, val in self.encodings.items()}\n",
|
||||||
|
" item['labels'] = self.labels[idx]\n",
|
||||||
|
" return item\n",
|
||||||
|
"\n",
|
||||||
|
" def __len__(self):\n",
|
||||||
|
" return len(self.labels)\n",
|
||||||
|
"\n",
|
||||||
|
"train_dataset = CustomDataset(train_encodings, train_labels)\n",
|
||||||
|
"valid_dataset = CustomDataset(valid_encodings, valid_labels)\n",
|
||||||
|
"\n",
|
||||||
|
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
|
||||||
|
"valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)\n",
|
||||||
|
"\n",
|
||||||
|
"model = DistilBertForSequenceClassification.from_pretrained(\n",
|
||||||
|
" 'distilbert-base-uncased', \n",
|
||||||
|
" num_labels=len(train_data['thing_property'].unique())\n",
|
||||||
|
")\n",
|
||||||
|
"optimizer = AdamW(model.parameters(), lr=5e-5)\n",
|
||||||
|
"\n",
|
||||||
|
"device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')\n",
|
||||||
|
"model.to(device)\n",
|
||||||
|
"\n",
|
||||||
|
"epochs = 50\n",
|
||||||
|
"for epoch in range(epochs):\n",
|
||||||
|
" model.train()\n",
|
||||||
|
" for batch in train_loader:\n",
|
||||||
|
" optimizer.zero_grad()\n",
|
||||||
|
" input_ids = batch['input_ids'].to(device)\n",
|
||||||
|
" attention_mask = batch['attention_mask'].to(device)\n",
|
||||||
|
" labels = batch['labels'].to(device)\n",
|
||||||
|
" outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
|
||||||
|
" loss = outputs.loss\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" optimizer.step()\n",
|
||||||
|
" print(f\"Epoch {epoch + 1} completed. Loss: {loss.item()}\")\n",
|
||||||
|
"\n",
|
||||||
|
" # 검증 루프\n",
|
||||||
|
" model.eval()\n",
|
||||||
|
" correct, total = 0, 0\n",
|
||||||
|
"\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" for batch in valid_loader:\n",
|
||||||
|
" input_ids = batch['input_ids'].to(device)\n",
|
||||||
|
" attention_mask = batch['attention_mask'].to(device)\n",
|
||||||
|
" labels = batch['labels'].to(device)\n",
|
||||||
|
" outputs = model(input_ids, attention_mask=attention_mask)\n",
|
||||||
|
" predictions = torch.argmax(outputs.logits, dim=-1)\n",
|
||||||
|
" correct += (predictions == labels).sum().item()\n",
|
||||||
|
" total += labels.size(0)\n",
|
||||||
|
"\n",
|
||||||
|
" valid_accuracy = correct / total\n",
|
||||||
|
" print(f'Validation Accuracy after Epoch {epoch + 1}: {valid_accuracy * 100:.2f}%')\n",
|
||||||
|
"\n",
|
||||||
|
"# Test 데이터 예측 및 c_thing, c_property 추가\n",
|
||||||
|
"test_encodings = tokenizer(list(test_data['tag_description']), truncation=True, padding=True, return_tensors='pt')\n",
|
||||||
|
"test_dataset = CustomDataset(test_encodings, torch.zeros(len(test_data))) # 레이블은 사용되지 않으므로 임시로 0을 사용\n",
|
||||||
|
"\n",
|
||||||
|
"test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n",
|
||||||
|
"\n",
|
||||||
|
"model.eval()\n",
|
||||||
|
"predicted_thing_properties = []\n",
|
||||||
|
"predicted_scores = []\n",
|
||||||
|
"\n",
|
||||||
|
"with torch.no_grad():\n",
|
||||||
|
" for batch in test_loader:\n",
|
||||||
|
" input_ids = batch['input_ids'].to(device)\n",
|
||||||
|
" attention_mask = batch['attention_mask'].to(device)\n",
|
||||||
|
" outputs = model(input_ids, attention_mask=attention_mask)\n",
|
||||||
|
" softmax_scores = F.softmax(outputs.logits, dim=-1)\n",
|
||||||
|
" predictions = torch.argmax(softmax_scores, dim=-1)\n",
|
||||||
|
" predicted_thing_properties.extend(predictions.cpu().numpy())\n",
|
||||||
|
" predicted_scores.extend(softmax_scores[range(len(predictions)), predictions].cpu().numpy())\n",
|
||||||
|
"\n",
|
||||||
|
"# 예측된 thing_property를 레이블 인코더로 디코딩\n",
|
||||||
|
"predicted_thing_property_labels = label_encoder.inverse_transform(predicted_thing_properties)\n",
|
||||||
|
"\n",
|
||||||
|
"# thing_property를 thing과 property로 나눔\n",
|
||||||
|
"test_data['c_thing'] = [x.split('_')[0] for x in predicted_thing_property_labels]\n",
|
||||||
|
"test_data['c_property'] = [x.split('_')[1] for x in predicted_thing_property_labels]\n",
|
||||||
|
"test_data['c_score'] = predicted_scores\n",
|
||||||
|
"\n",
|
||||||
|
"test_data['cthing_correct'] = test_data['thing'] == test_data['c_thing']\n",
|
||||||
|
"test_data['cproperty_correct'] = test_data['property'] == test_data['c_property']\n",
|
||||||
|
"test_data['ctp_correct'] = test_data['cthing_correct'] & test_data['cproperty_correct']\n",
|
||||||
|
"\n",
|
||||||
|
"mdm_true_count = len(test_data[test_data['MDM'] == True])\n",
|
||||||
|
"accuracy = (test_data['ctp_correct'].sum() / mdm_true_count) * 100\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Accuracy (MDM=True) for Group {group_number}: {accuracy:.2f}%\")\n",
|
||||||
|
"\n",
|
||||||
|
"# 결과를 저장하기 전에 폴더가 존재하는지 확인하고, 없으면 생성\n",
|
||||||
|
"os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
|
||||||
|
"\n",
|
||||||
|
"test_data.to_csv(output_path, index=False)\n",
|
||||||
|
"print(f'Results saved to {output_path}')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Validation Accuracy: 95.21%\n",
|
||||||
|
"Accuracy (MDM=True) for Group 5: 91.98%\n",
|
||||||
|
"Results saved to 0.class_document/distilbert/5/test_p_c.csv\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# 검증 루프\n",
|
||||||
|
"model.eval()\n",
|
||||||
|
"correct, total = 0, 0\n",
|
||||||
|
"\n",
|
||||||
|
"with torch.no_grad():\n",
|
||||||
|
" for batch in valid_loader:\n",
|
||||||
|
" input_ids, attention_mask, labels = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['labels'].to(device)\n",
|
||||||
|
" outputs = model(input_ids, attention_mask=attention_mask)\n",
|
||||||
|
" predictions = torch.argmax(outputs.logits, dim=-1)\n",
|
||||||
|
" correct += (predictions == labels).sum().item()\n",
|
||||||
|
" total += labels.size(0)\n",
|
||||||
|
"\n",
|
||||||
|
"valid_accuracy = correct / total\n",
|
||||||
|
"print(f'Validation Accuracy: {valid_accuracy * 100:.2f}%')\n",
|
||||||
|
"\n",
|
||||||
|
"# Test 데이터 예측 및 c_thing, c_property 추가\n",
|
||||||
|
"test_encodings = tokenizer(list(test_data['tag_description']), truncation=True, padding=True, return_tensors='pt')\n",
|
||||||
|
"test_dataset = CustomDataset(test_encodings, torch.zeros(len(test_data))) # 레이블은 사용되지 않으므로 임시로 0을 사용\n",
|
||||||
|
"\n",
|
||||||
|
"test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n",
|
||||||
|
"\n",
|
||||||
|
"model.eval()\n",
|
||||||
|
"predicted_thing_properties = []\n",
|
||||||
|
"predicted_scores = []\n",
|
||||||
|
"\n",
|
||||||
|
"with torch.no_grad():\n",
|
||||||
|
" for batch in test_loader:\n",
|
||||||
|
" input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)\n",
|
||||||
|
" outputs = model(input_ids, attention_mask=attention_mask)\n",
|
||||||
|
" softmax_scores = F.softmax(outputs.logits, dim=-1)\n",
|
||||||
|
" predictions = torch.argmax(softmax_scores, dim=-1)\n",
|
||||||
|
" predicted_thing_properties.extend(predictions.cpu().numpy())\n",
|
||||||
|
" predicted_scores.extend(softmax_scores[range(len(predictions)), predictions].cpu().numpy())\n",
|
||||||
|
"\n",
|
||||||
|
"# 예측된 thing_property를 레이블 인코더로 디코딩\n",
|
||||||
|
"predicted_thing_property_labels = label_encoder.inverse_transform(predicted_thing_properties)\n",
|
||||||
|
"\n",
|
||||||
|
"# thing_property를 thing과 property로 나눔\n",
|
||||||
|
"test_data['c_thing'] = [x.split('_')[0] for x in predicted_thing_property_labels]\n",
|
||||||
|
"test_data['c_property'] = [x.split('_')[1] for x in predicted_thing_property_labels]\n",
|
||||||
|
"test_data['c_score'] = predicted_scores\n",
|
||||||
|
"\n",
|
||||||
|
"test_data['cthing_correct'] = test_data['thing'] == test_data['c_thing']\n",
|
||||||
|
"test_data['cproperty_correct'] = test_data['property'] == test_data['c_property']\n",
|
||||||
|
"test_data['ctp_correct'] = test_data['cthing_correct'] & test_data['cproperty_correct']\n",
|
||||||
|
"\n",
|
||||||
|
"mdm_true_count = len(test_data[test_data['MDM'] == True])\n",
|
||||||
|
"accuracy = (test_data['ctp_correct'].sum() / mdm_true_count) * 100\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Accuracy (MDM=True) for Group {group_number}: {accuracy:.2f}%\")\n",
|
||||||
|
"\n",
|
||||||
|
"os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
|
||||||
|
"\n",
|
||||||
|
"test_data.to_csv(output_path, index=False)\n",
|
||||||
|
"print(f'Results saved to {output_path}')\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"ename": "NameError",
|
||||||
|
"evalue": "name 'pd' is not defined",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
|
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||||
|
"Cell \u001b[0;32mIn[3], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# 'filtered_data_plot.csv' 읽기\u001b[39;00m\n\u001b[0;32m----> 5\u001b[0m filtered_data \u001b[38;5;241m=\u001b[39m \u001b[43mpd\u001b[49m\u001b[38;5;241m.\u001b[39mread_csv(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfiltered_data_plot.csv\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# 데이터 토큰화\u001b[39;00m\n\u001b[1;32m 8\u001b[0m filtered_encodings \u001b[38;5;241m=\u001b[39m tokenizer(\u001b[38;5;28mlist\u001b[39m(filtered_data[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtag_description\u001b[39m\u001b[38;5;124m'\u001b[39m]), truncation\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
"\u001b[0;31mNameError\u001b[0m: name 'pd' is not defined"
]
}
],
"source": [
|
||||||
|
"from sklearn.manifold import TSNE\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"\n",
|
||||||
|
"# 'filtered_data_plot.csv' 읽기\n",
|
||||||
|
"filtered_data = pd.read_csv('filtered_data_plot.csv')\n",
|
||||||
|
"\n",
|
||||||
|
"# 데이터 토큰화\n",
|
||||||
|
"filtered_encodings = tokenizer(list(filtered_data['tag_description']), truncation=True, padding=True, return_tensors='pt')\n",
|
||||||
|
"\n",
|
||||||
|
"# BERT 임베딩 계산 함수\n",
|
||||||
|
"def get_bert_embeddings(model, encodings, device):\n",
|
||||||
|
" model.eval()\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" input_ids = encodings['input_ids'].to(device)\n",
|
||||||
|
" attention_mask = encodings['attention_mask'].to(device)\n",
|
||||||
|
" outputs = model.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
|
||||||
|
" return outputs.last_hidden_state.mean(dim=1).cpu().numpy() # 각 문장의 평균 임베딩 추출\n",
|
||||||
|
"\n",
|
||||||
|
"# BERT 모델로 임베딩 계산\n",
|
||||||
|
"bert_embeddings = get_bert_embeddings(model, filtered_encodings, device)\n",
|
||||||
|
"\n",
|
||||||
|
"# t-SNE 차원 축소\n",
|
||||||
|
"tsne = TSNE(n_components=2, random_state=42)\n",
|
||||||
|
"tsne_results = tsne.fit_transform(bert_embeddings)\n",
|
||||||
|
"\n",
|
||||||
|
"# 시각화를 위한 준비\n",
|
||||||
|
"unique_patterns = filtered_data['pattern'].unique()\n",
|
||||||
|
"color_map = plt.get_cmap('tab20', len(unique_patterns))\n",
|
||||||
|
"pattern_to_color = {pattern: idx for idx, pattern in enumerate(unique_patterns)}\n",
|
||||||
|
"\n",
|
||||||
|
"plt.figure(figsize=(14, 7))\n",
|
||||||
|
"\n",
|
||||||
|
"# 각 패턴별로 시각화\n",
|
||||||
|
"for pattern, color_idx in pattern_to_color.items():\n",
|
||||||
|
" pattern_indices = filtered_data['pattern'] == pattern\n",
|
||||||
|
" plt.scatter(tsne_results[pattern_indices, 0], tsne_results[pattern_indices, 1], \n",
|
||||||
|
" color=color_map(color_idx), marker='o', s=100, alpha=0.6, edgecolor='k', linewidth=1.2)\n",
|
||||||
|
"\n",
|
||||||
|
"# 그래프 설정\n",
|
||||||
|
"plt.xticks(fontsize=24)\n",
|
||||||
|
"plt.yticks(fontsize=24)\n",
|
||||||
|
"plt.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.6)\n",
|
||||||
|
"plt.tight_layout()\n",
|
||||||
|
"plt.show()\n"
|
||||||
|
]
|
||||||
|
}
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
File diff suppressed because one or more lines are too long
|
@ -2,40 +2,36 @@
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Accuracy (MDM=True) for Group 1: 79.41%\n",
|
"Accuracy (MDM=True) for Group 1: 73.50%\n",
|
||||||
"Accuracy (MDM=True) for Group 2: 79.32%\n",
|
"Accuracy (MDM=True) for Group 2: 78.04%\n",
|
||||||
"Accuracy (MDM=True) for Group 3: 82.49%\n",
|
"Accuracy (MDM=True) for Group 3: 81.73%\n",
|
||||||
"Accuracy (MDM=True) for Group 4: 85.61%\n",
|
"Accuracy (MDM=True) for Group 4: 79.83%\n",
|
||||||
"Accuracy (MDM=True) for Group 5: 79.72%\n",
|
"Accuracy (MDM=True) for Group 5: 81.31%\n",
|
||||||
"Average Accuracy (MDM=True) across all groups: 81.31%\n"
|
"Average Accuracy (MDM=True) across all groups: 78.88%\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import pandas as pd\n",
|
"import pandas as pd\n",
|
||||||
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
||||||
"from sklearn.metrics.pairwise import cosine_similarity\n",
|
"from sklearn.metrics import pairwise_distances\n",
|
||||||
"from tqdm import tqdm\n",
|
"from tqdm import tqdm\n",
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Initialize a list to store the accuracies for each group\n",
|
|
||||||
"accuracies = []\n",
|
"accuracies = []\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Loop through group numbers from 1 to 5\n",
|
|
||||||
"for group_number in range(1, 6):\n",
|
"for group_number in range(1, 6):\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Load the CSV files from the specified group\n",
|
|
||||||
" sdl_class_rdoc_path = f'0.class_document/{group_number}/sdl_class_rdoc.csv'\n",
|
" sdl_class_rdoc_path = f'0.class_document/{group_number}/sdl_class_rdoc.csv'\n",
|
||||||
" test_path = f'../../data_preprocess/dataset/{group_number}/test.csv'\n",
|
" test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Check if test file exists, if not, skip this iteration\n",
|
|
||||||
" if not os.path.exists(test_path):\n",
|
" if not os.path.exists(test_path):\n",
|
||||||
" print(f\"test file for Group {group_number} does not exist. Skipping...\")\n",
|
" print(f\"test file for Group {group_number} does not exist. Skipping...\")\n",
|
||||||
" continue\n",
|
" continue\n",
|
||||||
|
@ -43,68 +39,54 @@
|
||||||
" sdl_class_rdoc_csv = pd.read_csv(sdl_class_rdoc_path, low_memory=False)\n",
|
" sdl_class_rdoc_csv = pd.read_csv(sdl_class_rdoc_path, low_memory=False)\n",
|
||||||
" test_csv = pd.read_csv(test_path, low_memory=False)\n",
|
" test_csv = pd.read_csv(test_path, low_memory=False)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Replace NaN values with empty strings in relevant columns\n",
|
|
||||||
" sdl_class_rdoc_csv['tag_description'] = sdl_class_rdoc_csv['tag_description'].fillna('')\n",
|
" sdl_class_rdoc_csv['tag_description'] = sdl_class_rdoc_csv['tag_description'].fillna('')\n",
|
||||||
" test_csv['tag_description'] = test_csv['tag_description'].fillna('')\n",
|
" test_csv['tag_description'] = test_csv['tag_description'].fillna('')\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Initialize new columns in test_csv\n",
|
|
||||||
" test_csv['c_thing'] = ''\n",
|
" test_csv['c_thing'] = ''\n",
|
||||||
" test_csv['c_property'] = ''\n",
|
" test_csv['c_property'] = ''\n",
|
||||||
" test_csv['c_score'] = ''\n",
|
" test_csv['c_score'] = ''\n",
|
||||||
" test_csv['c_duplicate'] = 0 # Initialize c_duplicate to store duplicate counts\n",
|
" test_csv['c_duplicate'] = 0\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Combine both sdl_class_rdoc and test CSVs tag_descriptions for TF-IDF Vectorizer training\n",
|
|
||||||
" combined_tag_descriptions = sdl_class_rdoc_csv['tag_description'].tolist() + test_csv['tag_description'].tolist()\n",
|
" combined_tag_descriptions = sdl_class_rdoc_csv['tag_description'].tolist() + test_csv['tag_description'].tolist()\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Create a TF-IDF Vectorizer\n",
|
|
||||||
" vectorizer = TfidfVectorizer(\n",
|
" vectorizer = TfidfVectorizer(\n",
|
||||||
|
" use_idf=True, \n",
|
||||||
" token_pattern=r'\\S+',\n",
|
" token_pattern=r'\\S+',\n",
|
||||||
" ngram_range=(1, 6), # Use ngrams from 1 to 6\n",
|
" ngram_range=(1, 1),\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Fit the TF-IDF vectorizer on the combined tag_descriptions\n",
|
|
||||||
" vectorizer.fit(combined_tag_descriptions)\n",
|
" vectorizer.fit(combined_tag_descriptions)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Transform both sdl_class_rdoc and test CSVs into TF-IDF matrices\n",
|
|
||||||
" sdl_class_rdoc_tfidf_matrix = vectorizer.transform(sdl_class_rdoc_csv['tag_description'])\n",
|
" sdl_class_rdoc_tfidf_matrix = vectorizer.transform(sdl_class_rdoc_csv['tag_description'])\n",
|
||||||
" test_tfidf_matrix = vectorizer.transform(test_csv['tag_description'])\n",
|
" test_tfidf_matrix = vectorizer.transform(test_csv['tag_description'])\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Calculate cosine similarity between test and class-level sdl_class_rdoc vectors\n",
|
" distance_matrix = pairwise_distances(test_tfidf_matrix, sdl_class_rdoc_tfidf_matrix, metric='cosine')\n",
|
||||||
" similarity_matrix = cosine_similarity(test_tfidf_matrix, sdl_class_rdoc_tfidf_matrix)\n",
|
|
||||||
" \n",
|
" \n",
|
||||||
" # Find the most similar class-level tag_description for each test description\n",
|
" most_similar_indices = distance_matrix.argmin(axis=1)\n",
|
||||||
" most_similar_indices = similarity_matrix.argmax(axis=1)\n",
|
" most_similar_scores = 1 - distance_matrix.min(axis=1)\n",
|
||||||
" most_similar_scores = similarity_matrix.max(axis=1)\n",
|
|
||||||
" \n",
|
" \n",
|
||||||
" # Assign the corresponding thing, property, and similarity score to the test CSV\n",
|
|
||||||
" test_csv['c_thing'] = sdl_class_rdoc_csv.iloc[most_similar_indices]['thing'].values\n",
|
" test_csv['c_thing'] = sdl_class_rdoc_csv.iloc[most_similar_indices]['thing'].values\n",
|
||||||
" test_csv['c_property'] = sdl_class_rdoc_csv.iloc[most_similar_indices]['property'].values\n",
|
" test_csv['c_property'] = sdl_class_rdoc_csv.iloc[most_similar_indices]['property'].values\n",
|
||||||
" test_csv['c_score'] = most_similar_scores\n",
|
" test_csv['c_score'] = most_similar_scores\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Check if the predicted 'c_thing' and 'c_property' match the actual 'thing' and 'property'\n",
|
|
||||||
" test_csv['cthing_correct'] = test_csv['thing'] == test_csv['c_thing']\n",
|
" test_csv['cthing_correct'] = test_csv['thing'] == test_csv['c_thing']\n",
|
||||||
" test_csv['cproperty_correct'] = test_csv['property'] == test_csv['c_property']\n",
|
" test_csv['cproperty_correct'] = test_csv['property'] == test_csv['c_property']\n",
|
||||||
" test_csv['ctp_correct'] = test_csv['cthing_correct'] & test_csv['cproperty_correct']\n",
|
" test_csv['ctp_correct'] = test_csv['cthing_correct'] & test_csv['cproperty_correct']\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Calculate accuracy based only on MDM = True\n",
|
|
||||||
" mdm_true_count = len(test_csv[test_csv['MDM'] == True])\n",
|
" mdm_true_count = len(test_csv[test_csv['MDM'] == True])\n",
|
||||||
" accuracy = (test_csv['ctp_correct'].sum() / mdm_true_count) * 100\n",
|
" accuracy = (test_csv['ctp_correct'].sum() / mdm_true_count) * 100\n",
|
||||||
" accuracies.append(accuracy)\n",
|
" accuracies.append(accuracy)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" print(f\"Accuracy (MDM=True) for Group {group_number}: {accuracy:.2f}%\")\n",
|
" print(f\"Accuracy (MDM=True) for Group {group_number}: {accuracy:.2f}%\")\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Specify output file paths\n",
|
|
||||||
" output_path = f'0.class_document/{group_number}/test_p_c.csv'\n",
|
" output_path = f'0.class_document/{group_number}/test_p_c.csv'\n",
|
||||||
" test_csv.to_csv(output_path, index=False, encoding='utf-8-sig')\n",
|
" test_csv.to_csv(output_path, index=False, encoding='utf-8-sig')\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Filter for rows where MDM is True and ctp_correct is False\n",
|
|
||||||
" false_positive_rows = test_csv[(test_csv['MDM'] == True) & (test_csv['ctp_correct'] == False)]\n",
|
" false_positive_rows = test_csv[(test_csv['MDM'] == True) & (test_csv['ctp_correct'] == False)]\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Save false positives to a separate file\n",
|
|
||||||
" fp_output_path = f'0.class_document/{group_number}/fp_class.csv'\n",
|
" fp_output_path = f'0.class_document/{group_number}/fp_class.csv'\n",
|
||||||
" false_positive_rows.to_csv(fp_output_path, index=False, encoding='utf-8-sig')\n",
|
" false_positive_rows.to_csv(fp_output_path, index=False, encoding='utf-8-sig')\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Calculate and print the average accuracy across all groups\n",
|
|
||||||
"average_accuracy = sum(accuracies) / len(accuracies)\n",
|
"average_accuracy = sum(accuracies) / len(accuracies)\n",
|
||||||
"print(f\"Average Accuracy (MDM=True) across all groups: {average_accuracy:.2f}%\")\n"
|
"print(f\"Average Accuracy (MDM=True) across all groups: {average_accuracy:.2f}%\")\n"
|
||||||
]
|
]
|
|
@ -0,0 +1,116 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Average Accuracy (MDM=True) across all groups with n_neighbors=1: 84.43%\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ename": "KeyboardInterrupt",
|
||||||
|
"evalue": "",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
|
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
||||||
|
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, name, value)\u001b[0m\n\u001b[1;32m 6310\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6311\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 6312\u001b[0;31m \u001b[0mobject\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getattribute__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6313\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mobject\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__setattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0;31mAttributeError\u001b[0m: 'Series' object has no attribute '_name'",
|
||||||
|
"\nDuring handling of the above exception, another exception occurred:\n",
|
||||||
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||||||
|
"\u001b[0;32m/tmp/ipykernel_89094/2696322053.py\u001b[0m in \u001b[0;36m?\u001b[0;34m()\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0mdistances\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mknn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkneighbors\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_bow_matrix\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0mpredicted_things\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtrain_all_csv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'thing'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_csv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0mpredicted_properties\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtrain_all_csv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'property'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_csv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 37\u001b[0m \u001b[0mpredicted_scores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mdistances\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_csv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mtest_csv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'c_thing'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_csv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'c_property'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_csv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'c_score'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredicted_things\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpredicted_properties\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpredicted_scores\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0;32m/tmp/ipykernel_89094/2696322053.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(.0)\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mpandas\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
||||||
|
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 1187\u001b[0m \u001b[0maxis\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maxis\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1189\u001b[0m \u001b[0mmaybe_callable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_if_callable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[0mmaybe_callable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_deprecated_callable_usage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmaybe_callable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1191\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getitem_axis\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmaybe_callable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
||||||
|
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, key, axis)\u001b[0m\n\u001b[1;32m 1750\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1751\u001b[0m \u001b[0;31m# validate the location\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1752\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_integer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1753\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1754\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_ixs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
||||||
|
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/frame.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, i, axis)\u001b[0m\n\u001b[1;32m 3996\u001b[0m \u001b[0mnew_mgr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_mgr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfast_xs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3997\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3998\u001b[0m \u001b[0;31m# if we are a copy, mark as such\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3999\u001b[0m \u001b[0mcopy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_mgr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnew_mgr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4000\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_constructor_sliced_from_mgr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_mgr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnew_mgr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maxes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4001\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4002\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__finalize__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4003\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_set_is_copy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/frame.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, mgr, axes)\u001b[0m\n\u001b[1;32m 678\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_constructor_sliced_from_mgr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmgr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxes\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mSeries\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 679\u001b[0m \u001b[0mser\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSeries\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_from_mgr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmgr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 680\u001b[0;31m \u001b[0mser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;31m# caller is responsible for setting real name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 681\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 682\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mDataFrame\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 683\u001b[0m \u001b[0;31m# This would also work `if self._constructor_sliced is Series`, but\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, name, value)\u001b[0m\n\u001b[1;32m 6308\u001b[0m \u001b[0;31m# e.g. ``obj.x`` and ``obj.x = 4`` will always reference/modify\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6309\u001b[0m \u001b[0;31m# the same attribute.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6310\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6311\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 6312\u001b[0;31m \u001b[0mobject\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getattribute__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6313\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mobject\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__setattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6314\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6315\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from sklearn.feature_extraction.text import CountVectorizer\n",
|
||||||
|
"from sklearn.neighbors import NearestNeighbors\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"average_accuracies = []\n",
|
||||||
|
"\n",
|
||||||
|
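"# Sweep n_neighbors from 1 to 52 and record the average accuracy for each value\n",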
"for n in range(1, 53):\n",
|
||||||
|
" accuracies = []\n",
|
||||||
|
" for group_number in range(1, 6):\n",
|
||||||
|
" train_all_path = f'../../data_preprocess/dataset/{group_number}/train_all.csv'\n",
|
||||||
|
" test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
|
||||||
|
"\n",
|
||||||
|
" if not os.path.exists(test_path):\n",
|
||||||
|
" print(f\"Test file for Group {group_number} does not exist. Skipping...\")\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
" train_all_csv = pd.read_csv(train_all_path, low_memory=False)\n",
|
||||||
|
" test_csv = pd.read_csv(test_path, low_memory=False)\n",
|
||||||
|
"\n",
|
||||||
|
" train_all_csv['tag_description'] = train_all_csv['tag_description'].fillna('')\n",
|
||||||
|
" test_csv['tag_description'] = test_csv['tag_description'].fillna('')\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['c_thing'], test_csv['c_property'], test_csv['c_score'], test_csv['c_duplicate'] = '', '', '', 0\n",
|
||||||
|
"\n",
|
||||||
|
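" # Bag-of-words counts over whitespace-delimited tokens (unigrams only)\n",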
" vectorizer = CountVectorizer(token_pattern=r'\\S+', ngram_range=(1, 1))\n",
|
||||||
|
" train_all_bow_matrix = vectorizer.fit_transform(train_all_csv['tag_description'])\n",
|
||||||
|
" test_bow_matrix = vectorizer.transform(test_csv['tag_description'])\n",
|
||||||
|
"\n",
|
||||||
|
" knn = NearestNeighbors(n_neighbors=n, metric='euclidean', n_jobs=-1)\n",
|
||||||
|
" knn.fit(train_all_bow_matrix)\n",
|
||||||
|
"\n",
|
||||||
|
" distances, indices = knn.kneighbors(test_bow_matrix)\n",
|
||||||
|
"\n",
|
||||||
|
" predicted_things = [train_all_csv.iloc[indices[i][0]]['thing'] for i in range(len(test_csv))]\n",
|
||||||
|
" predicted_properties = [train_all_csv.iloc[indices[i][0]]['property'] for i in range(len(test_csv))]\n",
|
||||||
|
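" # Score each prediction as 1 minus the Euclidean distance to its nearest neighbour\n",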
" predicted_scores = [1 - distances[i][0] for i in range(len(test_csv))]\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['c_thing'], test_csv['c_property'], test_csv['c_score'] = predicted_things, predicted_properties, predicted_scores\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['cthing_correct'] = test_csv['thing'] == test_csv['c_thing']\n",
|
||||||
|
" test_csv['cproperty_correct'] = test_csv['property'] == test_csv['c_property']\n",
|
||||||
|
" test_csv['ctp_correct'] = test_csv['cthing_correct'] & test_csv['cproperty_correct']\n",
|
||||||
|
"\n",
|
||||||
|
" mdm_true_count = len(test_csv[test_csv['MDM'] == True])\n",
|
||||||
|
" accuracies.append((test_csv['ctp_correct'].sum() / mdm_true_count) * 100)\n",
|
||||||
|
"\n",
|
||||||
|
" average_accuracy = sum(accuracies) / len(accuracies)\n",
|
||||||
|
" average_accuracies.append(average_accuracy)\n",
|
||||||
|
" print(f\"Average Accuracy (MDM=True) across all groups with n_neighbors={n}: {average_accuracy:.2f}%\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"\\nFinal Results:\")\n",
|
||||||
|
"for n, avg_accuracy in zip(range(1, 53), average_accuracies):\n",
|
||||||
|
" print(f\"n_neighbors={n}, Average Accuracy: {avg_accuracy:.2f}%\")\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "torch",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.14"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
|
@ -0,0 +1,142 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Average Accuracy (MDM=True) across all groups with n_neighbors=5: 86.09%\n",
|
||||||
|
"\n",
|
||||||
|
"Final Results:\n",
|
||||||
|
"n_neighbors=1, Average Accuracy: 86.09%\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from sklearn.feature_extraction.text import CountVectorizer\n",
|
||||||
|
"from sklearn.neighbors import NearestNeighbors\n",
|
||||||
|
"import os\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from joblib import Parallel, delayed\n",
|
||||||
|
"\n",
|
||||||
|
"# Initialize variables to store overall accuracy results\n",
|
||||||
|
"average_accuracies = []\n",
|
||||||
|
"\n",
|
||||||
|
"# Function to process each group (parallelized later)\n",
|
||||||
|
"def process_group(n, group_number):\n",
|
||||||
|
" train_all_path = f'../../data_preprocess/dataset/{group_number}/train_all.csv'\n",
|
||||||
|
" test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
|
||||||
|
"\n",
|
||||||
|
" if not os.path.exists(test_path):\n",
|
||||||
|
" print(f\"Test file for Group {group_number} does not exist. Skipping...\")\n",
|
||||||
|
" return None\n",
|
||||||
|
"\n",
|
||||||
|
" # Load the train_all and test CSVs\n",
|
||||||
|
" train_all_csv = pd.read_csv(train_all_path, low_memory=False)\n",
|
||||||
|
" test_csv = pd.read_csv(test_path, low_memory=False)\n",
|
||||||
|
"\n",
|
||||||
|
" train_all_csv['tag_description'] = train_all_csv['tag_description'].fillna('')\n",
|
||||||
|
" test_csv['tag_description'] = test_csv['tag_description'].fillna('')\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['c_thing'] = ''\n",
|
||||||
|
" test_csv['c_property'] = ''\n",
|
||||||
|
" test_csv['c_score'] = ''\n",
|
||||||
|
" test_csv['c_duplicate'] = 0\n",
|
||||||
|
"\n",
|
||||||
|
" combined_tag_descriptions = train_all_csv['tag_description'].tolist()\n",
|
||||||
|
"\n",
|
||||||
|
" # BoW를 Boolean 방식으로 변환\n",
|
||||||
|
" vectorizer = CountVectorizer(token_pattern=r'\\S+', binary=True)\n",
|
||||||
|
" vectorizer.fit(combined_tag_descriptions)\n",
|
||||||
|
"\n",
|
||||||
|
" train_all_bow_matrix = vectorizer.transform(train_all_csv['tag_description']).toarray().astype(bool) # bool로 변환\n",
|
||||||
|
" test_bow_matrix = vectorizer.transform(test_csv['tag_description']).toarray().astype(bool)\n",
|
||||||
|
"\n",
|
||||||
|
" # NearestNeighbors에서 Jaccard 유사도를 사용 (모든 CPU 사용)\n",
|
||||||
|
" knn = NearestNeighbors(n_neighbors=n, metric='jaccard', n_jobs=-1) # n_jobs=-1로 모든 CPU 사용\n",
|
||||||
|
" knn.fit(train_all_bow_matrix)\n",
|
||||||
|
"\n",
|
||||||
|
" distances, indices = knn.kneighbors(test_bow_matrix)\n",
|
||||||
|
"\n",
|
||||||
|
" predicted_things = []\n",
|
||||||
|
" predicted_properties = []\n",
|
||||||
|
" predicted_scores = []\n",
|
||||||
|
"\n",
|
||||||
|
" for i in range(len(test_csv)):\n",
|
||||||
|
" neighbor_index = indices[i][0]\n",
|
||||||
|
" distance = distances[i][0]\n",
|
||||||
|
"\n",
|
||||||
|
" neighbor_thing = train_all_csv.iloc[neighbor_index]['thing']\n",
|
||||||
|
" neighbor_property = train_all_csv.iloc[neighbor_index]['property']\n",
|
||||||
|
"\n",
|
||||||
|
" predicted_things.append(neighbor_thing)\n",
|
||||||
|
" predicted_properties.append(neighbor_property)\n",
|
||||||
|
"\n",
|
||||||
|
" # Jaccard 유사도는 1 - 거리로 계산\n",
|
||||||
|
" predicted_score = 1 - distance\n",
|
||||||
|
" predicted_scores.append(predicted_score)\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['c_thing'] = predicted_things\n",
|
||||||
|
" test_csv['c_property'] = predicted_properties\n",
|
||||||
|
" test_csv['c_score'] = predicted_scores\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['cthing_correct'] = test_csv['thing'] == test_csv['c_thing']\n",
|
||||||
|
" test_csv['cproperty_correct'] = test_csv['property'] == test_csv['c_property']\n",
|
||||||
|
" test_csv['ctp_correct'] = test_csv['cthing_correct'] & test_csv['cproperty_correct']\n",
|
||||||
|
"\n",
|
||||||
|
" mdm_true_count = len(test_csv[test_csv['MDM'] == True])\n",
|
||||||
|
" accuracy = (test_csv['ctp_correct'].sum() / mdm_true_count) * 100\n",
|
||||||
|
" if(n==5): \n",
|
||||||
|
" output_path = f'0.class_document/{group_number}/test_p_c.csv'\n",
|
||||||
|
" test_csv.to_csv(output_path, index=False, encoding='utf-8-sig')\n",
|
||||||
|
"\n",
|
||||||
|
" return accuracy\n",
|
||||||
|
"\n",
|
||||||
|
"# Loop through n_neighbors values from 1 to 52\n",
|
||||||
|
"for n in range(5, 6):\n",
|
||||||
|
" # Parallel processing for groups\n",
|
||||||
|
" results = Parallel(n_jobs=-1)(delayed(process_group)(n, group_number) for group_number in range(1, 6))\n",
|
||||||
|
"\n",
|
||||||
|
" # Filter out None results (in case of missing files)\n",
|
||||||
|
" accuracies = [result for result in results if result is not None]\n",
|
||||||
|
"\n",
|
||||||
|
" if accuracies:\n",
|
||||||
|
" average_accuracy = sum(accuracies) / len(accuracies)\n",
|
||||||
|
" average_accuracies.append(average_accuracy)\n",
|
||||||
|
" print(f\"Average Accuracy (MDM=True) across all groups with n_neighbors={n}: {average_accuracy:.2f}%\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Print overall results for all n_neighbors values\n",
|
||||||
|
"print(\"\\nFinal Results:\")\n",
|
||||||
|
"for n, avg_accuracy in zip(range(1, 53), average_accuracies):\n",
|
||||||
|
" print(f\"n_neighbors={n}, Average Accuracy: {avg_accuracy:.2f}%\")\n",
|
||||||
|
" \n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "torch",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.14"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
|
@ -0,0 +1,148 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"test_p_c.csv saved for Group 1 at 0.class_document/knn_tfidf/1/test_p_c.csv\n",
|
||||||
|
"test_p_c.csv saved for Group 2 at 0.class_document/knn_tfidf/2/test_p_c.csv\n",
|
||||||
|
"test_p_c.csv saved for Group 3 at 0.class_document/knn_tfidf/3/test_p_c.csv\n",
|
||||||
|
"test_p_c.csv saved for Group 4 at 0.class_document/knn_tfidf/4/test_p_c.csv\n",
|
||||||
|
"test_p_c.csv saved for Group 5 at 0.class_document/knn_tfidf/5/test_p_c.csv\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with n_neighbors=5: 84.37%\n",
|
||||||
|
"\n",
|
||||||
|
"Final Results:\n",
|
||||||
|
"n_neighbors=1, Average Accuracy: 84.37%\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
||||||
|
"from sklearn.neighbors import NearestNeighbors\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"# Initialize variables to store overall accuracy results\n",
|
||||||
|
"average_accuracies = []\n",
|
||||||
|
"\n",
|
||||||
|
"# Loop through n_neighbors values from 1 to 52\n",
|
||||||
|
"for n in range(5, 6):\n",
|
||||||
|
" accuracies = [] # Store accuracy for each group\n",
|
||||||
|
"\n",
|
||||||
|
" # Loop through group numbers from 1 to 5\n",
|
||||||
|
" for group_number in range(1, 6):\n",
|
||||||
|
" train_all_path = f'../../data_preprocess/dataset/{group_number}/train_all.csv'\n",
|
||||||
|
" test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
|
||||||
|
"\n",
|
||||||
|
" if not os.path.exists(test_path):\n",
|
||||||
|
" print(f\"Test file for Group {group_number} does not exist. Skipping...\")\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
" # Load the train_all and test CSVs\n",
|
||||||
|
" train_all_csv = pd.read_csv(train_all_path, low_memory=False)\n",
|
||||||
|
" test_csv = pd.read_csv(test_path, low_memory=False)\n",
|
||||||
|
"\n",
|
||||||
|
" train_all_csv['tag_description'] = train_all_csv['tag_description'].fillna('')\n",
|
||||||
|
" test_csv['tag_description'] = test_csv['tag_description'].fillna('')\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['c_thing'] = ''\n",
|
||||||
|
" test_csv['c_property'] = ''\n",
|
||||||
|
" test_csv['c_score'] = ''\n",
|
||||||
|
" test_csv['c_duplicate'] = 0\n",
|
||||||
|
"\n",
|
||||||
|
" combined_tag_descriptions = train_all_csv['tag_description'].tolist()\n",
|
||||||
|
"\n",
|
||||||
|
" # TfidfVectorizer 사용\n",
|
||||||
|
" vectorizer = TfidfVectorizer(token_pattern=r'\\S+', ngram_range=(1, 1), use_idf=True)\n",
|
||||||
|
" vectorizer.fit(combined_tag_descriptions)\n",
|
||||||
|
"\n",
|
||||||
|
" train_all_tfidf_matrix = vectorizer.transform(train_all_csv['tag_description'])\n",
|
||||||
|
" test_tfidf_matrix = vectorizer.transform(test_csv['tag_description'])\n",
|
||||||
|
"\n",
|
||||||
|
" # KNN에서 유클리디안 거리를 이용\n",
|
||||||
|
" knn = NearestNeighbors(n_neighbors=n, metric='cosine', n_jobs=-1)\n",
|
||||||
|
" knn.fit(train_all_tfidf_matrix)\n",
|
||||||
|
"\n",
|
||||||
|
" distances, indices = knn.kneighbors(test_tfidf_matrix)\n",
|
||||||
|
"\n",
|
||||||
|
" predicted_things = []\n",
|
||||||
|
" predicted_properties = []\n",
|
||||||
|
" predicted_scores = []\n",
|
||||||
|
"\n",
|
||||||
|
" for i in range(len(test_csv)):\n",
|
||||||
|
" neighbor_index = indices[i][0]\n",
|
||||||
|
" distance = distances[i][0]\n",
|
||||||
|
"\n",
|
||||||
|
" neighbor_thing = train_all_csv.iloc[neighbor_index]['thing']\n",
|
||||||
|
" neighbor_property = train_all_csv.iloc[neighbor_index]['property']\n",
|
||||||
|
"\n",
|
||||||
|
" predicted_things.append(neighbor_thing)\n",
|
||||||
|
" predicted_properties.append(neighbor_property)\n",
|
||||||
|
"\n",
|
||||||
|
" # 거리 기반으로 유사도 점수 계산\n",
|
||||||
|
" predicted_score = 1 - distance\n",
|
||||||
|
" predicted_scores.append(predicted_score)\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['c_thing'] = predicted_things\n",
|
||||||
|
" test_csv['c_property'] = predicted_properties\n",
|
||||||
|
" test_csv['c_score'] = predicted_scores\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['cthing_correct'] = test_csv['thing'] == test_csv['c_thing']\n",
|
||||||
|
" test_csv['cproperty_correct'] = test_csv['property'] == test_csv['c_property']\n",
|
||||||
|
" test_csv['ctp_correct'] = test_csv['cthing_correct'] & test_csv['cproperty_correct']\n",
|
||||||
|
"\n",
|
||||||
|
" mdm_true_count = len(test_csv[test_csv['MDM'] == True])\n",
|
||||||
|
" accuracy = (test_csv['ctp_correct'].sum() / mdm_true_count) * 100\n",
|
||||||
|
" accuracies.append(accuracy)\n",
|
||||||
|
"\n",
|
||||||
|
" # n_neighbors가 5일 때, test_csv를 지정된 경로에 저장\n",
|
||||||
|
" if n == 5:\n",
|
||||||
|
" output_path = f'0.class_document/knn_tfidf/{group_number}/test_p_c.csv'\n",
|
||||||
|
" os.makedirs(os.path.dirname(output_path), exist_ok=True) # 폴더가 없을 경우 생성\n",
|
||||||
|
" test_csv.to_csv(output_path, index=False)\n",
|
||||||
|
" print(f\"test_p_c.csv saved for Group {group_number} at {output_path}\")\n",
|
||||||
|
"\n",
|
||||||
|
" # Calculate the average accuracy for the current n_neighbors value\n",
|
||||||
|
" average_accuracy = sum(accuracies) / len(accuracies)\n",
|
||||||
|
" average_accuracies.append(average_accuracy)\n",
|
||||||
|
" print(f\"Average Accuracy (MDM=True) across all groups with n_neighbors={n}: {average_accuracy:.2f}%\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Print overall results for all n_neighbors values\n",
|
||||||
|
"print(\"\\nFinal Results:\")\n",
|
||||||
|
"for n, avg_accuracy in zip(range(1, 53), average_accuracies):\n",
|
||||||
|
" print(f\"n_neighbors={n}, Average Accuracy: {avg_accuracy:.2f}%\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "torch",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.14"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
|
@ -0,0 +1,174 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Average Accuracy (MDM=True) across all groups with n_neighbors=1: 85.69%\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with n_neighbors=2: 86.04%\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with n_neighbors=3: 85.85%\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with n_neighbors=4: 85.88%\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with n_neighbors=5: 85.84%\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with n_neighbors=6: 85.81%\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with n_neighbors=7: 85.84%\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with n_neighbors=8: 85.86%\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with n_neighbors=9: 85.84%\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with n_neighbors=10: 85.91%\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ename": "KeyboardInterrupt",
|
||||||
|
"evalue": "",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||||||
|
"Cell \u001b[0;32mIn[1], line 53\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[38;5;66;03m# Compute Word2Vec vectors for the train and test data\u001b[39;00m\n\u001b[1;32m 52\u001b[0m train_all_vectors \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray([compute_sentence_vector(desc, model, vector_size) \u001b[38;5;28;01mfor\u001b[39;00m desc \u001b[38;5;129;01min\u001b[39;00m train_all_csv[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtag_description\u001b[39m\u001b[38;5;124m'\u001b[39m]])\n\u001b[0;32m---> 53\u001b[0m test_vectors \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray([compute_sentence_vector(desc, model, vector_size) \u001b[38;5;28;01mfor\u001b[39;00m desc \u001b[38;5;129;01min\u001b[39;00m test_csv[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtag_description\u001b[39m\u001b[38;5;124m'\u001b[39m]])\n\u001b[1;32m 55\u001b[0m \u001b[38;5;66;03m# KNN에서 코사인 거리를 이용\u001b[39;00m\n\u001b[1;32m 56\u001b[0m knn \u001b[38;5;241m=\u001b[39m NearestNeighbors(n_neighbors\u001b[38;5;241m=\u001b[39mn, metric\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124meuclidean\u001b[39m\u001b[38;5;124m'\u001b[39m, n_jobs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
|
||||||
|
"Cell \u001b[0;32mIn[1], line 53\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[38;5;66;03m# Compute Word2Vec vectors for the train and test data\u001b[39;00m\n\u001b[1;32m 52\u001b[0m train_all_vectors \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray([compute_sentence_vector(desc, model, vector_size) \u001b[38;5;28;01mfor\u001b[39;00m desc \u001b[38;5;129;01min\u001b[39;00m train_all_csv[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtag_description\u001b[39m\u001b[38;5;124m'\u001b[39m]])\n\u001b[0;32m---> 53\u001b[0m test_vectors \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray([\u001b[43mcompute_sentence_vector\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdesc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvector_size\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m desc \u001b[38;5;129;01min\u001b[39;00m test_csv[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtag_description\u001b[39m\u001b[38;5;124m'\u001b[39m]])\n\u001b[1;32m 55\u001b[0m \u001b[38;5;66;03m# KNN에서 코사인 거리를 이용\u001b[39;00m\n\u001b[1;32m 56\u001b[0m knn \u001b[38;5;241m=\u001b[39m NearestNeighbors(n_neighbors\u001b[38;5;241m=\u001b[39mn, metric\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124meuclidean\u001b[39m\u001b[38;5;124m'\u001b[39m, n_jobs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
|
||||||
|
"Cell \u001b[0;32mIn[1], line 12\u001b[0m, in \u001b[0;36mcompute_sentence_vector\u001b[0;34m(sentence, model, vector_size)\u001b[0m\n\u001b[1;32m 10\u001b[0m word_vectors \u001b[38;5;241m=\u001b[39m [model\u001b[38;5;241m.\u001b[39mwv[word] \u001b[38;5;28;01mfor\u001b[39;00m word \u001b[38;5;129;01min\u001b[39;00m words \u001b[38;5;28;01mif\u001b[39;00m word \u001b[38;5;129;01min\u001b[39;00m model\u001b[38;5;241m.\u001b[39mwv]\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(word_vectors) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[43mword_vectors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m np\u001b[38;5;241m.\u001b[39mzeros(vector_size)\n",
|
||||||
|
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/numpy/core/fromnumeric.py:3504\u001b[0m, in \u001b[0;36mmean\u001b[0;34m(a, axis, dtype, out, keepdims, where)\u001b[0m\n\u001b[1;32m 3501\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3502\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m mean(axis\u001b[38;5;241m=\u001b[39maxis, dtype\u001b[38;5;241m=\u001b[39mdtype, out\u001b[38;5;241m=\u001b[39mout, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m-> 3504\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_methods\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_mean\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maxis\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3505\u001b[0m \u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||||
|
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/numpy/core/_methods.py:118\u001b[0m, in \u001b[0;36m_mean\u001b[0;34m(a, axis, dtype, out, keepdims, where)\u001b[0m\n\u001b[1;32m 115\u001b[0m dtype \u001b[38;5;241m=\u001b[39m mu\u001b[38;5;241m.\u001b[39mdtype(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mf4\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 116\u001b[0m is_float16_result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 118\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mumr_sum\u001b[49m\u001b[43m(\u001b[49m\u001b[43marr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeepdims\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwhere\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwhere\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(ret, mu\u001b[38;5;241m.\u001b[39mndarray):\n\u001b[1;32m 120\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _no_nep50_warning():\n",
|
||||||
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from gensim.models import Word2Vec\n",
|
||||||
|
"from sklearn.neighbors import NearestNeighbors\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"# Function to compute the average Word2Vec vector for a sentence\n",
|
||||||
|
"def compute_sentence_vector(sentence, model, vector_size):\n",
|
||||||
|
" words = sentence.split()\n",
|
||||||
|
" word_vectors = [model.wv[word] for word in words if word in model.wv]\n",
|
||||||
|
" if len(word_vectors) > 0:\n",
|
||||||
|
" return np.mean(word_vectors, axis=0)\n",
|
||||||
|
" else:\n",
|
||||||
|
" return np.zeros(vector_size)\n",
|
||||||
|
"\n",
|
||||||
|
"# Initialize variables to store overall accuracy results\n",
|
||||||
|
"average_accuracies = []\n",
|
||||||
|
"\n",
|
||||||
|
"# Loop through n_neighbors values from 1 to 52\n",
|
||||||
|
"for n in range(1, 53):\n",
|
||||||
|
" accuracies = [] # Store accuracy for each group\n",
|
||||||
|
"\n",
|
||||||
|
" # Loop through group numbers from 1 to 5\n",
|
||||||
|
" for group_number in range(1, 6):\n",
|
||||||
|
" train_all_path = f'../../data_preprocess/dataset/{group_number}/train_all.csv'\n",
|
||||||
|
" test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
|
||||||
|
"\n",
|
||||||
|
" if not os.path.exists(test_path):\n",
|
||||||
|
" print(f\"Test file for Group {group_number} does not exist. Skipping...\")\n",
|
||||||
|
" continue\n",
|
||||||
|
"\n",
|
||||||
|
" # Load the train_all and test CSVs\n",
|
||||||
|
" train_all_csv = pd.read_csv(train_all_path, low_memory=False)\n",
|
||||||
|
" test_csv = pd.read_csv(test_path, low_memory=False)\n",
|
||||||
|
"\n",
|
||||||
|
" train_all_csv['tag_description'] = train_all_csv['tag_description'].fillna('')\n",
|
||||||
|
" test_csv['tag_description'] = test_csv['tag_description'].fillna('')\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['c_thing'] = ''\n",
|
||||||
|
" test_csv['c_property'] = ''\n",
|
||||||
|
" test_csv['c_score'] = ''\n",
|
||||||
|
" test_csv['c_duplicate'] = 0\n",
|
||||||
|
"\n",
|
||||||
|
" combined_tag_descriptions = train_all_csv['tag_description'].tolist() + test_csv['tag_description'].tolist()\n",
|
||||||
|
"\n",
|
||||||
|
" # Train Word2Vec model on combined descriptions\n",
|
||||||
|
" sentences = [desc.split() for desc in combined_tag_descriptions]\n",
|
||||||
|
" vector_size = 200 # You can set the vector size as needed\n",
|
||||||
|
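" # Word2Vec hyperparameters: window=3 limits context to nearby tokens, min_count=1 keeps every token\n",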
" model = Word2Vec(sentences, vector_size=vector_size, window=3, min_count=1, workers=-1)\n",
|
||||||
|
"\n",
|
||||||
|
" # Compute Word2Vec vectors for the train and test data\n",
|
||||||
|
" train_all_vectors = np.array([compute_sentence_vector(desc, model, vector_size) for desc in train_all_csv['tag_description']])\n",
|
||||||
|
" test_vectors = np.array([compute_sentence_vector(desc, model, vector_size) for desc in test_csv['tag_description']])\n",
|
||||||
|
"\n",
|
||||||
|
" # KNN에서 코사인 거리를 이용\n",
|
||||||
|
" knn = NearestNeighbors(n_neighbors=n, metric='euclidean', n_jobs=-1)\n",
|
||||||
|
" knn.fit(train_all_vectors)\n",
|
||||||
|
"\n",
|
||||||
|
" distances, indices = knn.kneighbors(test_vectors)\n",
|
||||||
|
"\n",
|
||||||
|
" predicted_things = []\n",
|
||||||
|
" predicted_properties = []\n",
|
||||||
|
" predicted_scores = []\n",
|
||||||
|
"\n",
|
||||||
|
" for i in range(len(test_csv)):\n",
|
||||||
|
" neighbor_index = indices[i][0]\n",
|
||||||
|
" distance = distances[i][0]\n",
|
||||||
|
"\n",
|
||||||
|
" neighbor_thing = train_all_csv.iloc[neighbor_index]['thing']\n",
|
||||||
|
" neighbor_property = train_all_csv.iloc[neighbor_index]['property']\n",
|
||||||
|
"\n",
|
||||||
|
" predicted_things.append(neighbor_thing)\n",
|
||||||
|
" predicted_properties.append(neighbor_property)\n",
|
||||||
|
"\n",
|
||||||
|
" # 거리 기반으로 유사도 점수 계산\n",
|
||||||
|
" predicted_score = 1 - distance\n",
|
||||||
|
" predicted_scores.append(predicted_score)\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['c_thing'] = predicted_things\n",
|
||||||
|
" test_csv['c_property'] = predicted_properties\n",
|
||||||
|
" test_csv['c_score'] = predicted_scores\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['cthing_correct'] = test_csv['thing'] == test_csv['c_thing']\n",
|
||||||
|
" test_csv['cproperty_correct'] = test_csv['property'] == test_csv['c_property']\n",
|
||||||
|
" test_csv['ctp_correct'] = test_csv['cthing_correct'] & test_csv['cproperty_correct']\n",
|
||||||
|
"\n",
|
||||||
|
" mdm_true_count = len(test_csv[test_csv['MDM'] == True])\n",
|
||||||
|
" accuracy = (test_csv['ctp_correct'].sum() / mdm_true_count) * 100\n",
|
||||||
|
" accuracies.append(accuracy)\n",
|
||||||
|
"\n",
|
||||||
|
" # Calculate the average accuracy for the current n_neighbors value\n",
|
||||||
|
" average_accuracy = sum(accuracies) / len(accuracies)\n",
|
||||||
|
" average_accuracies.append(average_accuracy)\n",
|
||||||
|
" print(f\"Average Accuracy (MDM=True) across all groups with n_neighbors={n}: {average_accuracy:.2f}%\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Print overall results for all n_neighbors values\n",
|
||||||
|
"print(\"\\nFinal Results:\")\n",
|
||||||
|
"for n, avg_accuracy in zip(range(1, 53), average_accuracies):\n",
|
||||||
|
" print(f\"n_neighbors={n}, Average Accuracy: {avg_accuracy:.2f}%\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "torch",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.14"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
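The cell above averages per-word Word2Vec vectors into one fixed-length vector per tag_description and then retrieves the closest training description with NearestNeighbors. A minimal, self-contained sketch of that idea follows; the tag descriptions are made up for illustration, and the positive workers count is an assumption (gensim expects a positive thread count rather than -1), so this is a sketch rather than the notebook's exact pipeline.

import numpy as np
from gensim.models import Word2Vec
from sklearn.neighbors import NearestNeighbors

train_descriptions = ["ME1 CYL1 EXH GAS OUT TEMP", "GE2 LO INLET PRESS"]  # hypothetical examples
sentences = [d.split() for d in train_descriptions]
w2v = Word2Vec(sentences, vector_size=50, window=3, min_count=1, workers=4)

def sentence_vector(text, model, size):
    # average the vectors of the words the model knows; zeros if none match
    vecs = [model.wv[w] for w in text.split() if w in model.wv]
    return np.mean(vecs, axis=0) if vecs else np.zeros(size)

train_vecs = np.array([sentence_vector(d, w2v, 50) for d in train_descriptions])
knn = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(train_vecs)
dist, idx = knn.kneighbors([sentence_vector("ME1 EXH GAS TEMP", w2v, 50)])
print(train_descriptions[idx[0][0]], 1 - dist[0][0])  # nearest training row and its 1 - distance score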
|
|
@ -0,0 +1,147 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Running SVM with C=1000\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with C=1000: 89.36%\n",
|
||||||
|
"Running SVM with C=10000\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with C=10000: 89.36%\n",
|
||||||
|
"Running SVM with C=100000\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with C=100000: 89.36%\n",
|
||||||
|
"Running SVM with C=1000000\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with C=1000000: 89.36%\n",
|
||||||
|
"\n",
|
||||||
|
"Final Results for each C value:\n",
|
||||||
|
"C=1000, Average Accuracy: 89.36%\n",
|
||||||
|
"C=10000, Average Accuracy: 89.36%\n",
|
||||||
|
"C=100000, Average Accuracy: 89.36%\n",
|
||||||
|
"C=1000000, Average Accuracy: 89.36%\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from sklearn.feature_extraction.text import CountVectorizer\n",
|
||||||
|
"from sklearn.svm import SVC\n",
|
||||||
|
"import os\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from joblib import Parallel, delayed\n",
|
||||||
|
"\n",
|
||||||
|
"# Initialize variables to store overall accuracy results\n",
|
||||||
|
"average_accuracies = {}\n",
|
||||||
|
"\n",
|
||||||
|
"# Function to process each group (parallelized later)\n",
|
||||||
|
"def process_group(C_value, group_number):\n",
|
||||||
|
" train_all_path = f'../../data_preprocess/dataset/{group_number}/train_all.csv'\n",
|
||||||
|
" test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
|
||||||
|
"\n",
|
||||||
|
" if not os.path.exists(test_path):\n",
|
||||||
|
" print(f\"Test file for Group {group_number} does not exist. Skipping...\")\n",
|
||||||
|
" return None\n",
|
||||||
|
"\n",
|
||||||
|
" # Load the train_all and test CSVs\n",
|
||||||
|
" train_all_csv = pd.read_csv(train_all_path, low_memory=False)\n",
|
||||||
|
" test_csv = pd.read_csv(test_path, low_memory=False)\n",
|
||||||
|
"\n",
|
||||||
|
" train_all_csv['tag_description'] = train_all_csv['tag_description'].fillna('')\n",
|
||||||
|
" test_csv['tag_description'] = test_csv['tag_description'].fillna('')\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['c_thing'] = ''\n",
|
||||||
|
" test_csv['c_property'] = ''\n",
|
||||||
|
" test_csv['c_score'] = ''\n",
|
||||||
|
" test_csv['c_duplicate'] = 0\n",
|
||||||
|
"\n",
|
||||||
|
" combined_tag_descriptions = train_all_csv['tag_description'].tolist()\n",
|
||||||
|
"\n",
|
||||||
|
" # BoW를 Boolean 방식으로 변환\n",
|
||||||
|
" vectorizer = CountVectorizer(token_pattern=r'\\S+', binary=True)\n",
|
||||||
|
" vectorizer.fit(combined_tag_descriptions)\n",
|
||||||
|
"\n",
|
||||||
|
" train_all_bow_matrix = vectorizer.transform(train_all_csv['tag_description']).toarray().astype(bool) # bool로 변환\n",
|
||||||
|
" test_bow_matrix = vectorizer.transform(test_csv['tag_description']).toarray().astype(bool)\n",
|
||||||
|
"\n",
|
||||||
|
" # SVM 모델 학습 및 예측\n",
|
||||||
|
" svm_model_thing = SVC(kernel='linear', probability=True, C=C_value)\n",
|
||||||
|
" svm_model_property = SVC(kernel='linear', probability=True, C=C_value)\n",
|
||||||
|
"\n",
|
||||||
|
" # SVM을 이용하여 'thing' 및 'property' 예측 모델 학습\n",
|
||||||
|
" svm_model_thing.fit(train_all_bow_matrix, train_all_csv['thing'])\n",
|
||||||
|
" svm_model_property.fit(train_all_bow_matrix, train_all_csv['property'])\n",
|
||||||
|
"\n",
|
||||||
|
" # 'thing' 및 'property' 예측\n",
|
||||||
|
" predicted_things = svm_model_thing.predict(test_bow_matrix)\n",
|
||||||
|
" predicted_properties = svm_model_property.predict(test_bow_matrix)\n",
|
||||||
|
" \n",
|
||||||
|
" predicted_scores_thing = svm_model_thing.predict_proba(test_bow_matrix)[:, 1] # 'thing'의 예측 확률 점수\n",
|
||||||
|
" predicted_scores_property = svm_model_property.predict_proba(test_bow_matrix)[:, 1] # 'property'의 예측 확률 점수\n",
|
||||||
|
"\n",
|
||||||
|
" predicted_scores = (predicted_scores_thing + predicted_scores_property) / 2 # 평균 점수로 결합\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['c_thing'] = predicted_things\n",
|
||||||
|
" test_csv['c_property'] = predicted_properties\n",
|
||||||
|
" test_csv['c_score'] = predicted_scores\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['cthing_correct'] = test_csv['thing'] == test_csv['c_thing']\n",
|
||||||
|
" test_csv['cproperty_correct'] = test_csv['property'] == test_csv['c_property']\n",
|
||||||
|
" test_csv['ctp_correct'] = test_csv['cthing_correct'] & test_csv['cproperty_correct']\n",
|
||||||
|
"\n",
|
||||||
|
" mdm_true_count = len(test_csv[test_csv['MDM'] == True])\n",
|
||||||
|
" accuracy = (test_csv['ctp_correct'].sum() / mdm_true_count) * 100 if mdm_true_count > 0 else 0\n",
|
||||||
|
" return accuracy\n",
|
||||||
|
"\n",
|
||||||
|
"# C 값들에 대해 실험할 값 설정 (log 스케일)\n",
|
||||||
|
"C_values = [0.01, 0.1, 1, 10, 100]\n",
|
||||||
|
"C_values = [1000, 10000, 100000, 1000000]\n",
|
||||||
|
"# 각 C 값에 대해 실험\n",
|
||||||
|
"for C_value in C_values:\n",
|
||||||
|
" print(f\"Running SVM with C={C_value}\")\n",
|
||||||
|
" average_accuracies[C_value] = []\n",
|
||||||
|
"\n",
|
||||||
|
" # Parallel processing for groups\n",
|
||||||
|
" results = Parallel(n_jobs=-1)(delayed(process_group)(C_value, group_number) for group_number in range(1, 6))\n",
|
||||||
|
"\n",
|
||||||
|
" # Filter out None results (in case of missing files)\n",
|
||||||
|
" accuracies = [result for result in results if result is not None]\n",
|
||||||
|
"\n",
|
||||||
|
" if accuracies:\n",
|
||||||
|
" average_accuracy = sum(accuracies) / len(accuracies)\n",
|
||||||
|
" average_accuracies[C_value].append(average_accuracy)\n",
|
||||||
|
" print(f\"Average Accuracy (MDM=True) across all groups with C={C_value}: {average_accuracy:.2f}%\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Print overall results for all C values\n",
|
||||||
|
"print(\"\\nFinal Results for each C value:\")\n",
|
||||||
|
"for C_value, accuracies in average_accuracies.items():\n",
|
||||||
|
" avg_acc = np.mean(accuracies)\n",
|
||||||
|
" print(f\"C={C_value}, Average Accuracy: {avg_acc:.2f}%\")\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "torch",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.13"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
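The notebook above encodes each tag_description as a Boolean bag-of-words (a token is either present or absent) and fits one linear SVM per target column. A short sketch of that encoding with made-up rows and labels, not the project's data:

import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.svm import SVC

train = pd.DataFrame({
    'tag_description': ["ME1 EXH GAS TEMP", "GE2 LO PRESS", "ME1 JACKET CW TEMP"],
    'thing': ["ME1Cylinder", "GE2LOSystem", "ME1CoolingWater"],  # hypothetical labels
})

# token_pattern=r'\S+' splits on whitespace only; binary=True keeps presence/absence
vectorizer = CountVectorizer(token_pattern=r'\S+', binary=True)
X = vectorizer.fit_transform(train['tag_description']).astype(bool)

clf = SVC(kernel='linear', C=1000)
clf.fit(X, train['thing'])
print(clf.predict(vectorizer.transform(["ME1 EXH GAS OUT TEMP"])))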
|
|
@ -0,0 +1,147 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Running SVM with C=1000\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with C=1000: 89.87%\n",
|
||||||
|
"Running SVM with C=10000\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with C=10000: 89.33%\n",
|
||||||
|
"Running SVM with C=100000\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with C=100000: 89.18%\n",
|
||||||
|
"Running SVM with C=1000000\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with C=1000000: 89.18%\n",
|
||||||
|
"\n",
|
||||||
|
"Final Results for each C value:\n",
|
||||||
|
"C=1000, Average Accuracy: 89.87%\n",
|
||||||
|
"C=10000, Average Accuracy: 89.33%\n",
|
||||||
|
"C=100000, Average Accuracy: 89.18%\n",
|
||||||
|
"C=1000000, Average Accuracy: 89.18%\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
||||||
|
"from sklearn.svm import SVC\n",
|
||||||
|
"import os\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from joblib import Parallel, delayed\n",
|
||||||
|
"\n",
|
||||||
|
"# Initialize variables to store overall accuracy results\n",
|
||||||
|
"average_accuracies = {}\n",
|
||||||
|
"\n",
|
||||||
|
"# Function to process each group (parallelized later)\n",
|
||||||
|
"def process_group(C_value, group_number):\n",
|
||||||
|
" train_all_path = f'../../data_preprocess/dataset/{group_number}/train_all.csv'\n",
|
||||||
|
" test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
|
||||||
|
"\n",
|
||||||
|
" if not os.path.exists(test_path):\n",
|
||||||
|
" print(f\"Test file for Group {group_number} does not exist. Skipping...\")\n",
|
||||||
|
" return None\n",
|
||||||
|
"\n",
|
||||||
|
" # Load the train_all and test CSVs\n",
|
||||||
|
" train_all_csv = pd.read_csv(train_all_path, low_memory=False)\n",
|
||||||
|
" test_csv = pd.read_csv(test_path, low_memory=False)\n",
|
||||||
|
"\n",
|
||||||
|
" train_all_csv['tag_description'] = train_all_csv['tag_description'].fillna('')\n",
|
||||||
|
" test_csv['tag_description'] = test_csv['tag_description'].fillna('')\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['c_thing'] = ''\n",
|
||||||
|
" test_csv['c_property'] = ''\n",
|
||||||
|
" test_csv['c_score'] = ''\n",
|
||||||
|
" test_csv['c_duplicate'] = 0\n",
|
||||||
|
"\n",
|
||||||
|
" combined_tag_descriptions = train_all_csv['tag_description'].tolist()\n",
|
||||||
|
"\n",
|
||||||
|
" # TF-IDF 벡터화\n",
|
||||||
|
" vectorizer = TfidfVectorizer(token_pattern=r'\\S+')\n",
|
||||||
|
" vectorizer.fit(combined_tag_descriptions)\n",
|
||||||
|
"\n",
|
||||||
|
" train_all_tfidf_matrix = vectorizer.transform(train_all_csv['tag_description']).toarray() # TF-IDF로 변환\n",
|
||||||
|
" test_tfidf_matrix = vectorizer.transform(test_csv['tag_description']).toarray()\n",
|
||||||
|
"\n",
|
||||||
|
" # SVM 모델 학습 및 예측\n",
|
||||||
|
" svm_model_thing = SVC(kernel='linear', probability=True, C=C_value)\n",
|
||||||
|
" svm_model_property = SVC(kernel='linear', probability=True, C=C_value)\n",
|
||||||
|
"\n",
|
||||||
|
" # SVM을 이용하여 'thing' 및 'property' 예측 모델 학습\n",
|
||||||
|
" svm_model_thing.fit(train_all_tfidf_matrix, train_all_csv['thing'])\n",
|
||||||
|
" svm_model_property.fit(train_all_tfidf_matrix, train_all_csv['property'])\n",
|
||||||
|
"\n",
|
||||||
|
" # 'thing' 및 'property' 예측\n",
|
||||||
|
" predicted_things = svm_model_thing.predict(test_tfidf_matrix)\n",
|
||||||
|
" predicted_properties = svm_model_property.predict(test_tfidf_matrix)\n",
|
||||||
|
" \n",
|
||||||
|
" predicted_scores_thing = svm_model_thing.predict_proba(test_tfidf_matrix)[:, 1] # 'thing'의 예측 확률 점수\n",
|
||||||
|
" predicted_scores_property = svm_model_property.predict_proba(test_tfidf_matrix)[:, 1] # 'property'의 예측 확률 점수\n",
|
||||||
|
"\n",
|
||||||
|
" predicted_scores = (predicted_scores_thing + predicted_scores_property) / 2 # 평균 점수로 결합\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['c_thing'] = predicted_things\n",
|
||||||
|
" test_csv['c_property'] = predicted_properties\n",
|
||||||
|
" test_csv['c_score'] = predicted_scores\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['cthing_correct'] = test_csv['thing'] == test_csv['c_thing']\n",
|
||||||
|
" test_csv['cproperty_correct'] = test_csv['property'] == test_csv['c_property']\n",
|
||||||
|
" test_csv['ctp_correct'] = test_csv['cthing_correct'] & test_csv['cproperty_correct']\n",
|
||||||
|
"\n",
|
||||||
|
" mdm_true_count = len(test_csv[test_csv['MDM'] == True])\n",
|
||||||
|
" accuracy = (test_csv['ctp_correct'].sum() / mdm_true_count) * 100 if mdm_true_count > 0 else 0\n",
|
||||||
|
" return accuracy\n",
|
||||||
|
"\n",
|
||||||
|
"# C 값들에 대해 실험할 값 설정 (log 스케일)\n",
|
||||||
|
"C_values = [0.1, 1, 10, 100]\n",
|
||||||
|
"C_values = [1000, 10000, 100000, 1000000]\n",
|
||||||
|
"# 각 C 값에 대해 실험\n",
|
||||||
|
"for C_value in C_values:\n",
|
||||||
|
" print(f\"Running SVM with C={C_value}\")\n",
|
||||||
|
" average_accuracies[C_value] = []\n",
|
||||||
|
"\n",
|
||||||
|
" # Parallel processing for groups\n",
|
||||||
|
" results = Parallel(n_jobs=-1)(delayed(process_group)(C_value, group_number) for group_number in range(1, 6))\n",
|
||||||
|
"\n",
|
||||||
|
" # Filter out None results (in case of missing files)\n",
|
||||||
|
" accuracies = [result for result in results if result is not None]\n",
|
||||||
|
"\n",
|
||||||
|
" if accuracies:\n",
|
||||||
|
" average_accuracy = sum(accuracies) / len(accuracies)\n",
|
||||||
|
" average_accuracies[C_value].append(average_accuracy)\n",
|
||||||
|
" print(f\"Average Accuracy (MDM=True) across all groups with C={C_value}: {average_accuracy:.2f}%\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Print overall results for all C values\n",
|
||||||
|
"print(\"\\nFinal Results for each C value:\")\n",
|
||||||
|
"for C_value, accuracies in average_accuracies.items():\n",
|
||||||
|
" avg_acc = np.mean(accuracies)\n",
|
||||||
|
" print(f\"C={C_value}, Average Accuracy: {avg_acc:.2f}%\")\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "torch",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.14"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
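In the SVM cells here and in the neighbouring notebooks, c_score is taken from predict_proba(...)[:, 1], i.e. the probability assigned to the second class in the classifier's classes_ ordering. With more than two 'thing' or 'property' classes, the confidence of the predicted label itself would be the row-wise maximum instead; a small sketch of the difference, with made-up probabilities:

import numpy as np

# one row per test sample, one column per class (ordered like clf.classes_),
# e.g. proba = clf.predict_proba(test_tfidf_matrix)
proba = np.array([[0.1, 0.7, 0.2],
                  [0.6, 0.3, 0.1]])   # hypothetical values for two rows, three classes

second_class_score = proba[:, 1]           # what the notebook records as c_score
predicted_class_score = proba.max(axis=1)  # confidence of the winning class
print(second_class_score, predicted_class_score)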
|
|
@ -0,0 +1,161 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Running SVM with C=10000000\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with C=10000000: 86.77%\n",
|
||||||
|
"Running SVM with C=100000000\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with C=100000000: 86.64%\n",
|
||||||
|
"Running SVM with C=1000000000\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with C=1000000000: 86.68%\n",
|
||||||
|
"Running SVM with C=10000000000\n",
|
||||||
|
"Average Accuracy (MDM=True) across all groups with C=10000000000: 86.90%\n",
|
||||||
|
"\n",
|
||||||
|
"Final Results for each C value:\n",
|
||||||
|
"C=10000000, Average Accuracy: 86.77%\n",
|
||||||
|
"C=100000000, Average Accuracy: 86.64%\n",
|
||||||
|
"C=1000000000, Average Accuracy: 86.68%\n",
|
||||||
|
"C=10000000000, Average Accuracy: 86.90%\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from gensim.models import Word2Vec\n",
|
||||||
|
"from sklearn.svm import SVC\n",
|
||||||
|
"from sklearn.metrics import pairwise_distances\n",
|
||||||
|
"import os\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from joblib import Parallel, delayed\n",
|
||||||
|
"\n",
|
||||||
|
"# Function to compute the average Word2Vec vector for a sentence\n",
|
||||||
|
"def compute_sentence_vector(sentence, model, vector_size):\n",
|
||||||
|
" words = sentence.split()\n",
|
||||||
|
" word_vectors = [model.wv[word] for word in words if word in model.wv]\n",
|
||||||
|
" if len(word_vectors) > 0:\n",
|
||||||
|
" return np.mean(word_vectors, axis=0)\n",
|
||||||
|
" else:\n",
|
||||||
|
" return np.zeros(vector_size)\n",
|
||||||
|
"\n",
|
||||||
|
"# Initialize variables to store overall accuracy results\n",
|
||||||
|
"average_accuracies = {}\n",
|
||||||
|
"\n",
|
||||||
|
"# Function to process each group (parallelized later)\n",
|
||||||
|
"def process_group(C_value, group_number):\n",
|
||||||
|
" train_all_path = f'../../data_preprocess/dataset/{group_number}/train_all.csv'\n",
|
||||||
|
" test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
|
||||||
|
"\n",
|
||||||
|
" if not os.path.exists(test_path):\n",
|
||||||
|
" print(f\"Test file for Group {group_number} does not exist. Skipping...\")\n",
|
||||||
|
" return None\n",
|
||||||
|
"\n",
|
||||||
|
" # Load the train_all and test CSVs\n",
|
||||||
|
" train_all_csv = pd.read_csv(train_all_path, low_memory=False)\n",
|
||||||
|
" test_csv = pd.read_csv(test_path, low_memory=False)\n",
|
||||||
|
"\n",
|
||||||
|
" train_all_csv['tag_description'] = train_all_csv['tag_description'].fillna('')\n",
|
||||||
|
" test_csv['tag_description'] = test_csv['tag_description'].fillna('')\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['c_thing'] = ''\n",
|
||||||
|
" test_csv['c_property'] = ''\n",
|
||||||
|
" test_csv['c_score'] = ''\n",
|
||||||
|
" test_csv['c_duplicate'] = 0\n",
|
||||||
|
"\n",
|
||||||
|
" combined_tag_descriptions = train_all_csv['tag_description'].tolist() + test_csv['tag_description'].tolist()\n",
|
||||||
|
" sentences = [desc.split() for desc in combined_tag_descriptions]\n",
|
||||||
|
" \n",
|
||||||
|
" vector_size = 200 # 벡터 크기 설정\n",
|
||||||
|
" model = Word2Vec(sentences, vector_size=vector_size, window=3, min_count=1, workers=-1)\n",
|
||||||
|
"\n",
|
||||||
|
" # Train data vectors\n",
|
||||||
|
" train_all_vectors = np.array([compute_sentence_vector(desc, model, vector_size) for desc in train_all_csv['tag_description']])\n",
|
||||||
|
" # Test data vectors\n",
|
||||||
|
" test_vectors = np.array([compute_sentence_vector(desc, model, vector_size) for desc in test_csv['tag_description']])\n",
|
||||||
|
"\n",
|
||||||
|
" # SVM 모델 학습 및 예측\n",
|
||||||
|
" svm_model_thing = SVC(kernel='linear', probability=True, C=C_value)\n",
|
||||||
|
" svm_model_property = SVC(kernel='linear', probability=True, C=C_value)\n",
|
||||||
|
"\n",
|
||||||
|
" # SVM을 이용하여 'thing' 및 'property' 예측 모델 학습\n",
|
||||||
|
" svm_model_thing.fit(train_all_vectors, train_all_csv['thing'])\n",
|
||||||
|
" svm_model_property.fit(train_all_vectors, train_all_csv['property'])\n",
|
||||||
|
"\n",
|
||||||
|
" # 'thing' 및 'property' 예측\n",
|
||||||
|
" predicted_things = svm_model_thing.predict(test_vectors)\n",
|
||||||
|
" predicted_properties = svm_model_property.predict(test_vectors)\n",
|
||||||
|
" \n",
|
||||||
|
" predicted_scores_thing = svm_model_thing.predict_proba(test_vectors)[:, 1] # 'thing'의 예측 확률 점수\n",
|
||||||
|
" predicted_scores_property = svm_model_property.predict_proba(test_vectors)[:, 1] # 'property'의 예측 확률 점수\n",
|
||||||
|
"\n",
|
||||||
|
" predicted_scores = (predicted_scores_thing + predicted_scores_property) / 2 # 평균 점수로 결합\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['c_thing'] = predicted_things\n",
|
||||||
|
" test_csv['c_property'] = predicted_properties\n",
|
||||||
|
" test_csv['c_score'] = predicted_scores\n",
|
||||||
|
"\n",
|
||||||
|
" test_csv['cthing_correct'] = test_csv['thing'] == test_csv['c_thing']\n",
|
||||||
|
" test_csv['cproperty_correct'] = test_csv['property'] == test_csv['c_property']\n",
|
||||||
|
" test_csv['ctp_correct'] = test_csv['cthing_correct'] & test_csv['cproperty_correct']\n",
|
||||||
|
"\n",
|
||||||
|
" mdm_true_count = len(test_csv[test_csv['MDM'] == True])\n",
|
||||||
|
" accuracy = (test_csv['ctp_correct'].sum() / mdm_true_count) * 100 if mdm_true_count > 0 else 0\n",
|
||||||
|
" return accuracy\n",
|
||||||
|
"\n",
|
||||||
|
"# C 값들에 대해 실험할 값 설정 (log 스케일)\n",
|
||||||
|
"C_values = [0.1, 1, 10, 100]\n",
|
||||||
|
"C_values = [1000, 10000, 100000, 1000000]\n",
|
||||||
|
"C_values = [10000000, 100000000, 1000000000, 10000000000]\n",
|
||||||
|
"\n",
|
||||||
|
"# 각 C 값에 대해 실험\n",
|
||||||
|
"for C_value in C_values:\n",
|
||||||
|
" print(f\"Running SVM with C={C_value}\")\n",
|
||||||
|
" average_accuracies[C_value] = []\n",
|
||||||
|
"\n",
|
||||||
|
" # Parallel processing for groups\n",
|
||||||
|
" results = Parallel(n_jobs=-1)(delayed(process_group)(C_value, group_number) for group_number in range(1, 6))\n",
|
||||||
|
"\n",
|
||||||
|
" # Filter out None results (in case of missing files)\n",
|
||||||
|
" accuracies = [result for result in results if result is not None]\n",
|
||||||
|
"\n",
|
||||||
|
" if accuracies:\n",
|
||||||
|
" average_accuracy = sum(accuracies) / len(accuracies)\n",
|
||||||
|
" average_accuracies[C_value].append(average_accuracy)\n",
|
||||||
|
" print(f\"Average Accuracy (MDM=True) across all groups with C={C_value}: {average_accuracy:.2f}%\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Print overall results for all C values\n",
|
||||||
|
"print(\"\\nFinal Results for each C value:\")\n",
|
||||||
|
"for C_value, accuracies in average_accuracies.items():\n",
|
||||||
|
" avg_acc = np.mean(accuracies)\n",
|
||||||
|
" print(f\"C={C_value}, Average Accuracy: {avg_acc:.2f}%\")\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "torch",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.14"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import os\n",
|
||||||
|
"group_number = 5\n",
|
||||||
|
"class_model = 'distilbert'\n",
|
||||||
|
"gen_model = 't5-tiny'\n",
|
||||||
|
"# 경로 설정\n",
|
||||||
|
"test_path = f'../../translation/0.result/{group_number}/test_p.csv'\n",
|
||||||
|
"class_path = f'0.class_document/{class_model}/{group_number}/test_p_c.csv'\n",
|
||||||
|
"output_path = f'0.class_document/{class_model}/{gen_model}/{group_number}/test_p_c.csv'\n",
|
||||||
|
"\n",
|
||||||
|
"# 파일 읽기\n",
|
||||||
|
"test_df = pd.read_csv(test_path)\n",
|
||||||
|
"class_df = pd.read_csv(class_path)\n",
|
||||||
|
"\n",
|
||||||
|
"# 필요한 필드 선택\n",
|
||||||
|
"fields_to_copy = ['c_thing', 'c_property', 'c_score', 'cthing_correct', 'cproperty_correct', 'ctp_correct']\n",
|
||||||
|
"class_df_subset = class_df[fields_to_copy]\n",
|
||||||
|
"\n",
|
||||||
|
"# test_path에 필드 복사\n",
|
||||||
|
"merged_df = pd.concat([test_df, class_df_subset], axis=1)\n",
|
||||||
|
"\n",
|
||||||
|
"# 결과 저장\n",
|
||||||
|
"os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
|
||||||
|
"merged_df.to_csv(output_path, index=False)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "torch",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.14"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
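The cell above copies the classifier columns onto the test rows with pd.concat(..., axis=1), which aligns on the DataFrame index; it pairs rows correctly only when both CSVs are written in the same row order. A tiny illustration with hypothetical values:

import pandas as pd

test_df = pd.DataFrame({'tag_description': ['ME1 EXH GAS TEMP', 'GE2 LO PRESS']})
class_df = pd.DataFrame({'c_thing': ['ME1Cylinder', 'GE2LOSystem'], 'c_score': [0.97, 0.88]})

# axis=1 concatenation matches rows by index position here, not by content
merged = pd.concat([test_df, class_df[['c_thing', 'c_score']]], axis=1)
print(merged)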
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -2,29 +2,48 @@
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"ename": "KeyError",
|
"name": "stdout",
|
||||||
"evalue": "'p_correct'",
|
"output_type": "stream",
|
||||||
"output_type": "error",
|
"text": [
|
||||||
"traceback": [
|
"Processing group 1...\n",
|
||||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
"Total updates where p_correct is False and ctp_correct is True (group 1): 55\n",
|
||||||
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
|
"Number of rows with duplicates in the same ships_idx (group 1): 34\n",
|
||||||
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/indexes/base.py:3805\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 3804\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3805\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcasted_key\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3806\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
|
"Number of rows without duplicates in the same ships_idx (group 1): 21\n",
|
||||||
"File \u001b[0;32mindex.pyx:167\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
|
"Number of updates made (group 1): 427\n",
|
||||||
"File \u001b[0;32mindex.pyx:196\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
|
"Updated test CSV saved to 0.class_document/distilbert/1/test_p_c_r.csv\n",
|
||||||
"File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:7081\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
|
"Refine CSV saved to 0.class_document/distilbert/1/refine.csv\n",
|
||||||
"File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:7089\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
|
"Processing group 2...\n",
|
||||||
"\u001b[0;31mKeyError\u001b[0m: 'p_correct'",
|
"Total updates where p_correct is False and ctp_correct is True (group 2): 63\n",
|
||||||
"\nThe above exception was the direct cause of the following exception:\n",
|
"Number of rows with duplicates in the same ships_idx (group 2): 21\n",
|
||||||
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
|
"Number of rows without duplicates in the same ships_idx (group 2): 42\n",
|
||||||
"Cell \u001b[0;32mIn[11], line 22\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# Assign c_thing, c_property to p_thing, p_property and set p_MDM to True if conditions are met\u001b[39;00m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m index, row \u001b[38;5;129;01min\u001b[39;00m test_csv\u001b[38;5;241m.\u001b[39miterrows():\n\u001b[0;32m---> 22\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[43mrow\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mp_correct\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m \u001b[38;5;129;01mand\u001b[39;00m row[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mctp_correct\u001b[39m\u001b[38;5;124m'\u001b[39m]:\n\u001b[1;32m 23\u001b[0m update_count \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;66;03m# Increment the counter\u001b[39;00m\n\u001b[1;32m 25\u001b[0m \u001b[38;5;66;03m# Check for duplicates within the same ships_idx\u001b[39;00m\n",
|
"Number of updates made (group 2): 225\n",
|
||||||
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/series.py:1121\u001b[0m, in \u001b[0;36mSeries.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 1118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_values[key]\n\u001b[1;32m 1120\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m key_is_scalar:\n\u001b[0;32m-> 1121\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_value\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1123\u001b[0m \u001b[38;5;66;03m# Convert generator to list before going through hashable part\u001b[39;00m\n\u001b[1;32m 1124\u001b[0m \u001b[38;5;66;03m# (We will iterate through the generator there to check for slices)\u001b[39;00m\n\u001b[1;32m 1125\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_iterator(key):\n",
|
"Updated test CSV saved to 0.class_document/distilbert/2/test_p_c_r.csv\n",
|
||||||
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/series.py:1237\u001b[0m, in \u001b[0;36mSeries._get_value\u001b[0;34m(self, label, takeable)\u001b[0m\n\u001b[1;32m 1234\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_values[label]\n\u001b[1;32m 1236\u001b[0m \u001b[38;5;66;03m# Similar to Index.get_value, but we do not fall back to positional\u001b[39;00m\n\u001b[0;32m-> 1237\u001b[0m loc \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mindex\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlabel\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1239\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_integer(loc):\n\u001b[1;32m 1240\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_values[loc]\n",
|
"Refine CSV saved to 0.class_document/distilbert/2/refine.csv\n",
|
||||||
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/indexes/base.py:3812\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 3807\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(casted_key, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m (\n\u001b[1;32m 3808\u001b[0m \u001b[38;5;28misinstance\u001b[39m(casted_key, abc\u001b[38;5;241m.\u001b[39mIterable)\n\u001b[1;32m 3809\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28many\u001b[39m(\u001b[38;5;28misinstance\u001b[39m(x, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m casted_key)\n\u001b[1;32m 3810\u001b[0m ):\n\u001b[1;32m 3811\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m InvalidIndexError(key)\n\u001b[0;32m-> 3812\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(key) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01merr\u001b[39;00m\n\u001b[1;32m 3813\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 3814\u001b[0m \u001b[38;5;66;03m# If we have a listlike key, _check_indexing_error will raise\u001b[39;00m\n\u001b[1;32m 3815\u001b[0m \u001b[38;5;66;03m# InvalidIndexError. Otherwise we fall through and re-raise\u001b[39;00m\n\u001b[1;32m 3816\u001b[0m \u001b[38;5;66;03m# the TypeError.\u001b[39;00m\n\u001b[1;32m 3817\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_indexing_error(key)\n",
|
"Processing group 3...\n",
|
||||||
"\u001b[0;31mKeyError\u001b[0m: 'p_correct'"
|
"Total updates where p_correct is False and ctp_correct is True (group 3): 32\n",
|
||||||
|
"Number of rows with duplicates in the same ships_idx (group 3): 10\n",
|
||||||
|
"Number of rows without duplicates in the same ships_idx (group 3): 22\n",
|
||||||
|
"Number of updates made (group 3): 343\n",
|
||||||
|
"Updated test CSV saved to 0.class_document/distilbert/3/test_p_c_r.csv\n",
|
||||||
|
"Refine CSV saved to 0.class_document/distilbert/3/refine.csv\n",
|
||||||
|
"Processing group 4...\n",
|
||||||
|
"Total updates where p_correct is False and ctp_correct is True (group 4): 37\n",
|
||||||
|
"Number of rows with duplicates in the same ships_idx (group 4): 25\n",
|
||||||
|
"Number of rows without duplicates in the same ships_idx (group 4): 12\n",
|
||||||
|
"Number of updates made (group 4): 596\n",
|
||||||
|
"Updated test CSV saved to 0.class_document/distilbert/4/test_p_c_r.csv\n",
|
||||||
|
"Refine CSV saved to 0.class_document/distilbert/4/refine.csv\n",
|
||||||
|
"Processing group 5...\n",
|
||||||
|
"Total updates where p_correct is False and ctp_correct is True (group 5): 40\n",
|
||||||
|
"Number of rows with duplicates in the same ships_idx (group 5): 19\n",
|
||||||
|
"Number of rows without duplicates in the same ships_idx (group 5): 21\n",
|
||||||
|
"Number of updates made (group 5): 379\n",
|
||||||
|
"Updated test CSV saved to 0.class_document/distilbert/5/test_p_c_r.csv\n",
|
||||||
|
"Refine CSV saved to 0.class_document/distilbert/5/refine.csv\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -33,90 +52,78 @@
|
||||||
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
||||||
"from sklearn.metrics.pairwise import cosine_similarity\n",
|
"from sklearn.metrics.pairwise import cosine_similarity\n",
|
||||||
"from tqdm import tqdm\n",
|
"from tqdm import tqdm\n",
|
||||||
|
"import re\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Set the group number\n",
|
"model = \"distilbert\"\n",
|
||||||
"group_number = 1 # Change this to the desired group number\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Load the CSV files from the specified group\n",
|
"for group_number in range(1, 6): # Group 1 to 5\n",
|
||||||
"sdl_class_rdoc_path = f'0.class_document/{group_number}/sdl_class_rdoc.csv'\n",
|
" print(f\"Processing group {group_number}...\")\n",
|
||||||
"test_path = f'0.class_document/{group_number}/test_p_c.csv'\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"sdl_class_rdoc_csv = pd.read_csv(sdl_class_rdoc_path, low_memory=False)\n",
|
" # Load test CSV for the current group\n",
|
||||||
"test_csv = pd.read_csv(test_path, low_memory=False)\n",
|
" test_path = f'0.class_document/{model}/t5-tiny/{group_number}/test_p_c.csv'\n",
|
||||||
|
" test_csv = pd.read_csv(test_path, low_memory=False)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"update_count = 0\n",
|
" # Initialize counters\n",
|
||||||
"duplicate_count = 0\n",
|
" update_count = 0\n",
|
||||||
"non_duplicate_count = 0\n",
|
" duplicate_count = 0\n",
|
||||||
|
" non_duplicate_count = 0\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Assign c_thing, c_property to p_thing, p_property and set p_MDM to True if conditions are met\n",
|
" # Assign c_thing, c_property to p_thing, p_property and set p_MDM to True if conditions are met\n",
|
||||||
"for index, row in test_csv.iterrows():\n",
|
" for index, row in test_csv.iterrows():\n",
|
||||||
" if not row['p_correct'] and row['ctp_correct']:\n",
|
" if not row['p_correct'] and row['ctp_correct']:\n",
|
||||||
" update_count += 1 # Increment the counter\n",
|
" update_count += 1 # Increment the counter\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Check for duplicates within the same ships_idx\n",
|
" # Check for duplicates within the same ships_idx\n",
|
||||||
" same_idx_rows = test_csv[(test_csv['ships_idx'] == row['ships_idx']) &\n",
|
" same_idx_rows = test_csv[(test_csv['ships_idx'] == row['ships_idx']) &\n",
|
||||||
" (test_csv['p_thing'] == row['c_thing']) &\n",
|
" (test_csv['p_thing'] == row['c_thing']) &\n",
|
||||||
" (test_csv['p_property'] == row['c_property'])]\n",
|
" (test_csv['p_property'] == row['c_property'])]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" if len(same_idx_rows) > 0:\n",
|
" if len(same_idx_rows) > 0:\n",
|
||||||
" duplicate_count += 1\n",
|
" duplicate_count += 1\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" non_duplicate_count += 1\n",
|
" non_duplicate_count += 1\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Print the results\n",
|
" # Print the results for the current group\n",
|
||||||
"print(f\"Total updates where p_correct is False and ctp_correct is True: {update_count}\")\n",
|
" print(f\"Total updates where p_correct is False and ctp_correct is True (group {group_number}): {update_count}\")\n",
|
||||||
"print(f\"Number of rows with duplicates in the same ships_idx: {duplicate_count}\")\n",
|
" print(f\"Number of rows with duplicates in the same ships_idx (group {group_number}): {duplicate_count}\")\n",
|
||||||
"print(f\"Number of rows without duplicates in the same ships_idx: {non_duplicate_count}\")\n",
|
" print(f\"Number of rows without duplicates in the same ships_idx (group {group_number}): {non_duplicate_count}\")\n",
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Number of updates made: 45\n",
|
|
||||||
"Updated test CSV saved to 0.class_document/1/test_p_c_r.csv\n",
|
|
||||||
"Refine CSV saved to refine.csv\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"update_count = 0\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Initialize a list to hold rows that meet the conditions\n",
|
" # Initialize a list to hold rows that meet the conditions for refinement\n",
|
||||||
"refine_rows = []\n",
|
" refine_rows = []\n",
|
||||||
|
" update_count = 0\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Assign c_thing, c_property to p_thing, p_property and set p_MDM to True if conditions are met\n",
|
" # Assign c_thing, c_property to p_thing, p_property and set p_MDM to True if conditions are met\n",
|
||||||
"for index, row in test_csv.iterrows():\n",
|
" for index, row in test_csv.iterrows():\n",
|
||||||
" if (not row['p_MDM'] and row['c_score'] >= 0.9 and \n",
|
" if (not row['p_MDM'] and row['c_score'] >= 0.91 and \n",
|
||||||
" (row['p_thing'] != row['c_thing'] or row['p_property'] != row['c_property'])):\n",
|
" (row['p_thing'] != row['c_thing'] or row['p_property'] != row['c_property'])):\n",
|
||||||
" test_csv.at[index, 'p_thing'] = row['c_thing']\n",
|
|
||||||
" test_csv.at[index, 'p_property'] = row['c_property']\n",
|
|
||||||
" test_csv.at[index, 'p_MDM'] = True\n",
|
|
||||||
" update_count += 1 # Increment the counter\n",
|
|
||||||
" refine_rows.append(row) # Add the row to the refine list\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Convert the list of refine rows into a DataFrame\n",
|
" test_csv.at[index, 'p_thing'] = row['c_thing']\n",
|
||||||
"refine_df = pd.DataFrame(refine_rows)\n",
|
" test_csv.at[index, 'p_property'] = row['c_property']\n",
|
||||||
|
" test_csv.at[index, 'p_MDM'] = True\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Save the refine DataFrame to a CSV file\n",
|
" updated_p_thing = test_csv.at[index, 'p_thing']\n",
|
||||||
"refine_output_path = f'refine.csv'\n",
|
" updated_p_property = test_csv.at[index, 'p_property']\n",
|
||||||
"refine_df.to_csv(refine_output_path, index=False, encoding='utf-8-sig')\n",
|
" p_pattern = re.sub(r'\\d', '#', updated_p_thing) + \" \" + re.sub(r'\\d', '#', updated_p_property)\n",
|
||||||
|
" test_csv.at[index, 'p_pattern'] = p_pattern\n",
|
||||||
|
" update_count += 1 # Increment the counter\n",
|
||||||
|
" refine_rows.append(row) # Add the row to the refine list\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Print the number of updates made\n",
|
" # Convert the list of refine rows into a DataFrame\n",
|
||||||
"print(f\"Number of updates made: {update_count}\")\n",
|
" refine_df = pd.DataFrame(refine_rows)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Save the updated test CSV\n",
|
" # Save the refine DataFrame to a CSV file for the current group\n",
|
||||||
"output_file_path = f'0.class_document/{group_number}/test_p_c_r.csv'\n",
|
" refine_output_path = f'0.class_document/{model}/{group_number}/refine.csv'\n",
|
||||||
"test_csv.to_csv(output_file_path, index=False, encoding='utf-8-sig')\n",
|
" refine_df.to_csv(refine_output_path, index=False, encoding='utf-8-sig')\n",
|
||||||
" \n",
|
"\n",
|
||||||
"print(f\"Updated test CSV saved to {output_file_path}\")\n",
|
" # Print the number of updates made\n",
|
||||||
"print(f\"Refine CSV saved to {refine_output_path}\")\n"
|
" print(f\"Number of updates made (group {group_number}): {update_count}\")\n",
|
||||||
|
"\n",
|
||||||
|
" # Save the updated test CSV for the current group\n",
|
||||||
|
" output_file_path = f'0.class_document/{model}/{group_number}/test_p_c_r.csv'\n",
|
||||||
|
" test_csv.to_csv(output_file_path, index=False, encoding='utf-8-sig')\n",
|
||||||
|
"\n",
|
||||||
|
" print(f\"Updated test CSV saved to {output_file_path}\")\n",
|
||||||
|
" print(f\"Refine CSV saved to {refine_output_path}\")\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
|
@ -0,0 +1,84 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"The file with the updated p_dup and p_map columns has been saved: 0.class_document/knn_tfidf/1/test_p_c_r.csv\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"\n",
|
||||||
|
"group_number = 1\n",
|
||||||
|
"method_name='knn_tfidf'\n",
|
||||||
|
"# Read the test file\n",
|
||||||
|
"test_path = f'0.class_document/{method_name}/{group_number}/test_p_c_r.csv'\n",
|
||||||
|
"df = pd.read_csv(test_path)\n",
|
||||||
|
"\n",
|
||||||
|
"# Concatenate p_thing and p_property into p_tp in the test data\n",
|
||||||
|
"df['p_tp'] = df['p_thing'] + \" \" + df['p_property']\n",
|
||||||
|
"\n",
|
||||||
|
"# Read the train_all file\n",
|
||||||
|
"train_all_path = f'../../data_preprocess/dataset/{group_number}/train_all.csv'\n",
|
||||||
|
"train_all_df = pd.read_csv(train_all_path)\n",
|
||||||
|
"\n",
|
||||||
|
"# Concatenate thing and property into tp in the train_all data\n",
|
||||||
|
"train_all_df['tp'] = train_all_df['thing'] + \" \" + train_all_df['property']\n",
|
||||||
|
"\n",
|
||||||
|
"# Initialize the p_map column in the test data\n",
|
||||||
|
"df['p_map'] = 0\n",
|
||||||
|
"\n",
|
||||||
|
"# Group by ships_idx and then group p_tp within each ships_idx group\n",
|
||||||
|
"grouped = df.groupby('ships_idx')['p_tp']\n",
|
||||||
|
"\n",
|
||||||
|
"# Iterate through each ships_idx group\n",
|
||||||
|
"for ships_idx, group in grouped:\n",
|
||||||
|
" # Count the occurrences of each p_tp within the test group\n",
|
||||||
|
" p_tp_counts = group.value_counts()\n",
|
||||||
|
" \n",
|
||||||
|
" # Assign the count as an integer to p_dup for rows with the corresponding p_tp within the group\n",
|
||||||
|
" for p_tp, count in p_tp_counts.items():\n",
|
||||||
|
" # Update p_dup\n",
|
||||||
|
" df.loc[(df['ships_idx'] == ships_idx) & (df['p_tp'] == p_tp), 'p_dup'] = int(count)\n",
|
||||||
|
" \n",
|
||||||
|
" # Calculate p_map by counting matching tp in train_all_df\n",
|
||||||
|
" p_map_count = train_all_df['tp'].eq(p_tp).sum()\n",
|
||||||
|
" df.loc[(df['ships_idx'] == ships_idx) & (df['p_tp'] == p_tp), 'p_map'] = int(p_map_count)\n",
|
||||||
|
"\n",
|
||||||
|
"# Save the modified DataFrame\n",
|
||||||
|
"output_path = f'0.class_document/{method_name}/{group_number}/test_p_c_r.csv'\n",
|
||||||
|
"df.to_csv(output_path, index=False, encoding='utf-8-sig')\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"The file with the updated p_dup and p_map columns has been saved:\", output_path)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "torch",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.14"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
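The notebook above counts, per ships_idx, how often each predicted thing+property pair occurs (p_dup) and how many training rows already map to that same pair (p_map). A small sketch of the same counting with toy frames and hypothetical values:

import pandas as pd

df = pd.DataFrame({'ships_idx': [1001, 1001, 1001],
                   'p_tp': ['ME1 Temp', 'ME1 Temp', 'GE2 Press']})
train = pd.DataFrame({'tp': ['ME1 Temp', 'ME1 Temp', 'ME1 Temp']})

for ships_idx, group in df.groupby('ships_idx')['p_tp']:
    for p_tp, count in group.value_counts().items():
        mask = (df['ships_idx'] == ships_idx) & (df['p_tp'] == p_tp)
        df.loc[mask, 'p_dup'] = int(count)                          # duplicates within the ship
        df.loc[mask, 'p_map'] = int(train['tp'].eq(p_tp).sum())     # matching rows in train_all

print(df)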
|
|
@ -1,114 +0,0 @@
|
||||||
import pandas as pd
|
|
||||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
|
||||||
from tqdm import tqdm
|
|
||||||
import os
|
|
||||||
|
|
||||||
group_number = 1
|
|
||||||
# Load the CSV files
|
|
||||||
test_path = f'post_process/tfidf_class/0.class_document/{group_number}/test_p_c.csv'
|
|
||||||
test_path = f'post_process/tfidf_class/0.class_document/{group_number}/test_p_c_r.csv'
|
|
||||||
ship_data_list_reference_doc_file_path = f'post_process/tfidf_class/0.class_document/{group_number}/sdl_class_rdoc.csv'
|
|
||||||
|
|
||||||
test_csv = pd.read_csv(test_path, low_memory=False)
|
|
||||||
sdl_rdoc = pd.read_csv(ship_data_list_reference_doc_file_path)
|
|
||||||
|
|
||||||
# Initialize new columns in test_csv
|
|
||||||
test_csv['s_score'] = -1
|
|
||||||
test_csv['s_thing'] = ''
|
|
||||||
test_csv['s_property'] = ''
|
|
||||||
test_csv['s_correct'] = False
|
|
||||||
|
|
||||||
duplicate_filtered = test_csv[(test_csv['p_MDM'] == True)].copy()
|
|
||||||
|
|
||||||
# Create a mapping from thing/property to reference_doc
|
|
||||||
thing_property_to_reference_doc = sdl_rdoc.set_index(['thing', 'property'])['tag_description'].to_dict()
|
|
||||||
|
|
||||||
# Calculate s_score for duplicate rows
|
|
||||||
for ships_idx, group in tqdm(duplicate_filtered.groupby('ships_idx'), desc="Processing duplicates"):
|
|
||||||
for (p_thing, p_property), sub_group in group.groupby(['p_thing', 'p_property']):
|
|
||||||
sub_group = sub_group.copy()
|
|
||||||
tag_descriptions = sub_group['tag_description'].tolist()
|
|
||||||
|
|
||||||
# Get the reference document for the corresponding p_thing and p_property
|
|
||||||
reference_doc = thing_property_to_reference_doc.get((p_thing, p_property), '')
|
|
||||||
|
|
||||||
if reference_doc:
|
|
||||||
# Combine the tag_descriptions and the reference_doc for fit_transform
|
|
||||||
combined_descriptions = tag_descriptions + [reference_doc]
|
|
||||||
|
|
||||||
# Create a new TF-IDF Vectorizer for this specific group
|
|
||||||
vectorizer = TfidfVectorizer(
|
|
||||||
token_pattern=r'\S+',
|
|
||||||
norm='l2', # Use L2 normalization
|
|
||||||
ngram_range=(1, 7), # Use n-grams of length 1 through 7
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fit and transform the combined descriptions
|
|
||||||
tfidf_matrix = vectorizer.fit_transform(combined_descriptions)
|
|
||||||
|
|
||||||
# Separate the test_tfidf_matrix and reference_vector
|
|
||||||
test_tfidf_matrix = tfidf_matrix[:-1] # All but the last one
|
|
||||||
reference_vector = tfidf_matrix[-1] # The last one
|
|
||||||
|
|
||||||
# Calculate the cosine similarity between the test descriptions and the reference_doc
|
|
||||||
sub_group['s_score'] = cosine_similarity(test_tfidf_matrix, reference_vector).flatten()
|
|
||||||
else:
|
|
||||||
sub_group['s_score'] = 0
|
|
||||||
|
|
||||||
# Update the s_score values back into the original test_csv
|
|
||||||
duplicate_filtered.loc[sub_group.index, 's_score'] = sub_group['s_score']
|
|
||||||
|
|
||||||
for ships_idx, group in tqdm(duplicate_filtered.groupby('ships_idx'), desc="Processing duplicates"):
|
|
||||||
for (p_thing, p_property), sub_group in group.groupby(['p_thing', 'p_property']):
|
|
||||||
if (sub_group['s_score'] == -1).any():
|
|
||||||
best_index = sub_group.index.min()
|
|
||||||
else:
|
|
||||||
# Find the index of the row with the highest s_score
|
|
||||||
best_index = sub_group['s_score'].idxmax()
|
|
||||||
row_position = sub_group.index.get_loc(best_index)
|
|
||||||
|
|
||||||
# Assign s_thing and s_property only to the row with the highest s_score
|
|
||||||
duplicate_filtered.at[best_index, 's_thing'] = sub_group.at[best_index, 'p_thing']
|
|
||||||
duplicate_filtered.at[best_index, 's_property'] = sub_group.at[best_index, 'p_property']
|
|
||||||
|
|
||||||
# Now, update the original test_csv with the changes made in duplicate_filtered
|
|
||||||
test_csv.update(duplicate_filtered[['s_thing', 's_property', 's_score']])
|
|
||||||
|
|
||||||
# Calculate s_correct
|
|
||||||
test_csv['s_correct'] = ((test_csv['thing'] == test_csv['s_thing']) &
|
|
||||||
(test_csv['property'] == test_csv['s_property']) &
|
|
||||||
(test_csv['MDM']))
|
|
||||||
|
|
||||||
# Calculate the percentage of correct s_thing and s_property
|
|
||||||
mdm_true_count = test_csv['MDM'].sum()
|
|
||||||
s_correct_count = test_csv['s_correct'].sum()
|
|
||||||
s_correct_percentage = (s_correct_count / mdm_true_count) * 100
|
|
||||||
|
|
||||||
print(f"s_correct count: {s_correct_count}")
|
|
||||||
print(f"MDM true count: {mdm_true_count}")
|
|
||||||
print(f"s_correct percentage: {s_correct_percentage:.2f}%")
|
|
||||||
|
|
||||||
|
|
||||||
# Save the updated DataFrame to a new CSV file
|
|
||||||
output_path = test_path = f'post_process/0.result/{group_number}/test_s.csv'
|
|
||||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
|
||||||
test_csv.to_csv(output_path, index=False, encoding='utf-8-sig')
|
|
||||||
|
|
||||||
print(f"Updated data saved to {output_path}")
|
|
||||||
|
|
||||||
# Check for duplicates in s_thing and s_property within each ships_idx
|
|
||||||
print("\nShips_idx with duplicate s_thing and s_property:")
|
|
||||||
duplicate_ships_idx = []
|
|
||||||
|
|
||||||
for ships_idx, group in test_csv.groupby('ships_idx'):
|
|
||||||
# Exclude rows with empty s_thing or s_property
|
|
||||||
non_empty_group = group[(group['s_thing'] != '') & (group['s_property'] != '')]
|
|
||||||
duplicate_entries = non_empty_group[non_empty_group.duplicated(subset=['s_thing', 's_property'], keep=False)]
|
|
||||||
if not duplicate_entries.empty:
|
|
||||||
duplicate_ships_idx.append(ships_idx)
|
|
||||||
print(f"Ships_idx: {ships_idx}")
|
|
||||||
print(duplicate_entries[['s_thing', 's_property']])
|
|
||||||
|
|
||||||
if not duplicate_ships_idx:
|
|
||||||
print("No duplicates found.")
|
|
|
@ -0,0 +1,76 @@
|
||||||
|
import pandas as pd
|
||||||
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
|
from sklearn.metrics import pairwise_distances
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
k_accuracies = []
|
||||||
|
|
||||||
|
p_thing_str = 'c_thing'
|
||||||
|
p_property_str = 'c_property'
|
||||||
|
|
||||||
|
for k in range(5, 6):
|
||||||
|
recall_list = []
|
||||||
|
for group_number in range(1, 6):
|
||||||
|
test_csv = pd.read_csv(f'translation/0.result/{group_number}/test_p.csv', low_memory=False)
|
||||||
|
test_csv = pd.read_csv(f'post_process/tfidf_class/0.class_document/distilbert/{group_number}/test_p_c_r.csv', low_memory=False)
|
||||||
|
train_all_csv = pd.read_csv(f'data_preprocess/dataset/{group_number}/train_all.csv', low_memory=False)
|
||||||
|
|
||||||
|
test_csv['s_score'], test_csv['s_thing'], test_csv['s_property'], test_csv['s_correct'] = -1, '', '', False
|
||||||
|
duplicate_filtered = test_csv[test_csv['p_MDM']].copy()
|
||||||
|
train_all_csv['tag_description'] = train_all_csv['tag_description'].fillna('')
|
||||||
|
duplicate_filtered['tag_description'] = duplicate_filtered['tag_description'].fillna('')
|
||||||
|
|
||||||
|
for ships_idx, group in duplicate_filtered.groupby('ships_idx'):
|
||||||
|
for (p_thing, p_property), sub_group in group.groupby([p_thing_str, p_property_str]):
|
||||||
|
matching_train_data = train_all_csv[(train_all_csv['thing'] == p_thing) & (train_all_csv['property'] == p_property)]
|
||||||
|
if not matching_train_data.empty:
|
||||||
|
combined_descriptions = sub_group['tag_description'].tolist() + matching_train_data['tag_description'].tolist()
|
||||||
|
|
||||||
|
vectorizer = TfidfVectorizer(use_idf=True, token_pattern=r'\S+')
|
||||||
|
tfidf_matrix = vectorizer.fit_transform(combined_descriptions)
|
||||||
|
|
||||||
|
test_tfidf_matrix = tfidf_matrix[:len(sub_group)]
|
||||||
|
train_tfidf_matrix = tfidf_matrix[len(sub_group):]
|
||||||
|
|
||||||
|
distance_matrix = pairwise_distances(test_tfidf_matrix, train_tfidf_matrix, metric='cosine')
|
||||||
|
similarity_matrix = 1 - distance_matrix
|
||||||
|
|
||||||
|
for i, row in enumerate(similarity_matrix):
|
||||||
|
top_k_indices = np.argsort(row)[-k:]
|
||||||
|
sub_group.iloc[i, sub_group.columns.get_loc('s_score')] = row[top_k_indices].mean()
|
||||||
|
else:
|
||||||
|
sub_group['s_score'] = 0
|
||||||
|
|
||||||
|
duplicate_filtered.loc[sub_group.index, 's_score'] = sub_group['s_score']
|
||||||
|
|
||||||
|
for ships_idx, group in duplicate_filtered.groupby('ships_idx'):
|
||||||
|
for (p_thing, p_property), sub_group in group.groupby([p_thing_str, p_property_str]):
|
||||||
|
best_index = sub_group.index.min() if (sub_group['s_score'] == -1).any() else sub_group['s_score'].idxmax()
|
||||||
|
duplicate_filtered.at[best_index, 's_thing'] = sub_group.at[best_index, p_thing_str]
|
||||||
|
duplicate_filtered.at[best_index, 's_property'] = sub_group.at[best_index, p_property_str]
|
||||||
|
duplicate_filtered = duplicate_filtered.drop(sub_group.index.difference([best_index]))
|
||||||
|
|
||||||
|
test_csv.update(duplicate_filtered[['s_thing', 's_property', 's_score']])
|
||||||
|
test_csv['s_correct'] = ((test_csv['thing'] == test_csv['s_thing']) &
|
||||||
|
(test_csv['property'] == test_csv['s_property']) &
|
||||||
|
(test_csv['MDM']))
|
||||||
|
|
||||||
|
mdm_true_count = test_csv['MDM'].sum()
|
||||||
|
s_correct_count = test_csv['s_correct'].sum()
|
||||||
|
recall = s_correct_count / mdm_true_count * 100
|
||||||
|
recall_list.append(recall)
|
||||||
|
|
||||||
|
if k == 5:
|
||||||
|
output_path = f'post_process/0.result/{group_number}/test_s.csv'
|
||||||
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||||
|
test_csv.to_csv(output_path, index=False)
|
||||||
|
print(f"test_s.csv saved for Group {group_number} at {output_path}, mdm:{mdm_true_count}, correct:{s_correct_count}, recall:{recall:.2f}%")
|
||||||
|
|
||||||
|
average_recall = np.mean(recall_list)
|
||||||
|
k_accuracies.append(average_recall)
|
||||||
|
print(f"k={k}, Average s_correct percentage: {average_recall:.2f}%")
|
||||||
|
|
||||||
|
overall_average_accuracy = np.mean(k_accuracies)
|
||||||
|
print(f"Overall average s_correct percentage across all k values: {overall_average_accuracy:.2f}%")
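The script above scores each predicted row by the mean cosine similarity to its k most similar training descriptions that share the same thing/property, then keeps only the best-scoring row per group. A compact sketch of that top-k mean with made-up vectors:

import numpy as np
from sklearn.metrics import pairwise_distances

test_vec = np.array([[0.2, 0.8, 0.0]])             # one test description (hypothetical TF-IDF row)
train_vecs = np.array([[0.1, 0.9, 0.0],
                       [0.0, 0.2, 0.8],
                       [0.3, 0.7, 0.1]])           # candidate training descriptions

similarity = 1 - pairwise_distances(test_vec, train_vecs, metric='cosine')[0]
k = 2
s_score = np.sort(similarity)[-k:].mean()          # mean of the k highest similarities
print(similarity, s_score)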
|
||||||
|
|
|
@ -0,0 +1,111 @@
|
||||||
|
import pandas as pd
|
||||||
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
|
from sklearn.metrics import pairwise_distances
|
||||||
|
from tqdm import tqdm
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import numpy as np
|
||||||
|
import scipy.sparse as sp # newly added import
|
||||||
|
|
||||||
|
total_s_correct_count = 0
|
||||||
|
total_mdm_true_count = 0
|
||||||
|
|
||||||
|
# Modified TF-IDF Vectorizer to modify IDF behavior
|
||||||
|
class ModifiedTfidfVectorizer(TfidfVectorizer):
|
||||||
|
def _tfidf_transform(self, X, copy=True):
|
||||||
|
"""Apply TF-IDF weighting to a sparse matrix X."""
|
||||||
|
if not self.use_idf:
|
||||||
|
return X
|
||||||
|
df = np.bincount(X.indices, minlength=X.shape[1])
|
||||||
|
n_samples, n_features = X.shape
|
||||||
|
df += 1 # to smooth idf weights by adding 1 to document frequencies
|
||||||
|
# Custom IDF: Logarithm of document frequency (df), rewarding common terms
|
||||||
|
idf = np.log(df + 1) # Modified IDF: log(1 + df)
|
||||||
|
self._idf_diag = sp.diags(idf, offsets=0, shape=(n_features, n_features), format='csr')
|
||||||
|
return X * self._idf_diag
|
||||||
|
|
||||||
|
for group_number in range(1, 6):
|
||||||
|
|
||||||
|
test_path = f'translation/0.result/{group_number}/test_p.csv'
|
||||||
|
ship_data_list_reference_doc_file_path = f'post_process/tfidf_class/0.class_document/{group_number}/sdl_class_rdoc.csv'
|
||||||
|
|
||||||
|
test_csv = pd.read_csv(test_path, low_memory=False)
|
||||||
|
sdl_rdoc = pd.read_csv(ship_data_list_reference_doc_file_path)
|
||||||
|
|
||||||
|
test_csv['s_score'] = -1
|
||||||
|
test_csv['s_thing'] = ''
|
||||||
|
test_csv['s_property'] = ''
|
||||||
|
test_csv['s_correct'] = False
|
||||||
|
|
||||||
|
duplicate_filtered = test_csv[test_csv['p_MDM']].copy()
|
||||||
|
|
||||||
|
thing_property_to_reference_doc = sdl_rdoc.set_index(['thing', 'property'])['tag_description'].to_dict()
|
||||||
|
|
||||||
|
for ships_idx, group in tqdm(duplicate_filtered.groupby('ships_idx'), desc=f"Processing duplicates for group {group_number}"):
|
||||||
|
for (p_thing, p_property), sub_group in group.groupby(['p_thing', 'p_property']):
|
||||||
|
sub_group = sub_group.copy()
|
||||||
|
tag_descriptions = sub_group['tag_description'].tolist()
|
||||||
|
emtpy_ref = False
|
||||||
|
reference_doc = thing_property_to_reference_doc.get((p_thing, p_property), '')
|
||||||
|
if not reference_doc:
|
||||||
|
p_pattern = sub_group['p_pattern'].iloc[0]
|
||||||
|
sdl_match = sdl_rdoc[sdl_rdoc['pattern'] == p_pattern].sort_values(by='mapping_count', ascending=False).head(1)
|
||||||
|
emtpy_ref = True
|
||||||
|
if not sdl_match.empty:
|
||||||
|
reference_doc = sdl_match['tag_description'].iloc[0]
|
||||||
|
else:
|
||||||
|
sub_group['s_score'] = 0
|
||||||
|
print(f"Reference document is empty for p_thing: {p_thing}, p_property: {p_property}")
|
||||||
|
duplicate_filtered.update(sub_group)
|
||||||
|
continue
|
||||||
|
|
||||||
|
combined_descriptions = tag_descriptions + [reference_doc]
|
||||||
|
|
||||||
|
vectorizer = ModifiedTfidfVectorizer(use_idf=True, token_pattern=r'\S+', ngram_range=(1, 1))
|
||||||
|
tfidf_matrix = vectorizer.fit_transform(combined_descriptions)
|
||||||
|
|
||||||
|
test_tfidf_matrix = tfidf_matrix[:-1]
|
||||||
|
reference_vector = tfidf_matrix[-1]
|
||||||
|
|
||||||
|
distance_matrix = pairwise_distances(test_tfidf_matrix, reference_vector.reshape(1, -1), metric='euclidean')
|
||||||
|
similarity_matrix = 1 - distance_matrix
|
||||||
|
|
||||||
|
sub_group['s_score'] = similarity_matrix.flatten()
|
||||||
|
|
||||||
|
duplicate_filtered.loc[sub_group.index, 's_score'] = sub_group['s_score']
|
||||||
|
|
||||||
|
for ships_idx, group in tqdm(duplicate_filtered.groupby('ships_idx'), desc=f"Processing duplicates for group {group_number}"):
|
||||||
|
for (p_thing, p_property), sub_group in group.groupby(['p_thing', 'p_property']):
|
||||||
|
if (sub_group['s_score'] == -1).any():
|
||||||
|
best_index = sub_group.index.min()
|
||||||
|
else:
|
||||||
|
best_index = sub_group['s_score'].idxmax()
|
||||||
|
row_position = sub_group.index.get_loc(best_index)
|
||||||
|
|
||||||
|
duplicate_filtered.at[best_index, 's_thing'] = sub_group.at[best_index, 'p_thing']
|
||||||
|
duplicate_filtered.at[best_index, 's_property'] = sub_group.at[best_index, 'p_property']
|
||||||
|
|
||||||
|
test_csv.update(duplicate_filtered[['s_thing', 's_property', 's_score']])
|
||||||
|
|
||||||
|
test_csv['s_correct'] = ((test_csv['thing'] == test_csv['s_thing']) &
|
||||||
|
(test_csv['property'] == test_csv['s_property']) &
|
||||||
|
(test_csv['MDM']))
|
||||||
|
|
||||||
|
mdm_true_count = test_csv['MDM'].sum()
|
||||||
|
s_correct_count = test_csv['s_correct'].sum()
|
||||||
|
|
||||||
|
total_s_correct_count += s_correct_count
|
||||||
|
total_mdm_true_count += mdm_true_count
|
||||||
|
|
||||||
|
print(f"Group {group_number} - s_correct count: {s_correct_count}")
|
||||||
|
print(f"Group {group_number} - MDM true count: {mdm_true_count}")
|
||||||
|
print(f"Group {group_number} - s_correct percentage: {(s_correct_count / mdm_true_count) * 100:.2f}%")
|
||||||
|
|
||||||
|
output_path = f'post_process/0.result/tfidf/{group_number}/test_s.csv'
|
||||||
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||||
|
test_csv.to_csv(output_path, index=False, encoding='utf-8-sig')
|
||||||
|
|
||||||
|
average_s_correct_percentage = (total_s_correct_count / total_mdm_true_count) * 100
|
||||||
|
print(f"Total s_correct count: {total_s_correct_count}")
|
||||||
|
print(f"Total MDM true count: {total_mdm_true_count}")
|
||||||
|
print(f"Average s_correct percentage across all groups: {average_s_correct_percentage:.2f}%")
|
|
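
The class above is meant to swap the usual idf = log(N / df) for log(1 + df), so tokens shared by many descriptions are weighted up rather than down; note that scikit-learn's TfidfVectorizer delegates IDF computation to an internal TfidfTransformer, so whether an override named _tfidf_transform is actually invoked can depend on the installed version. A standalone sketch of the intended weighting, built directly on CountVectorizer with invented sentences, avoids that dependency:

import numpy as np
import scipy.sparse as sp
from sklearn.feature_extraction.text import CountVectorizer

docs = ["fo pump press", "fo pump temp", "cool fw press"]  # invented toy descriptions

vec = CountVectorizer(token_pattern=r'\S+')
counts = vec.fit_transform(docs)
df = np.bincount(counts.indices, minlength=counts.shape[1]) + 1  # smoothed document frequency per term
weights = np.log(df + 1)                                         # log(1 + df): frequent terms get larger weights
weighted = counts @ sp.diags(weights)                            # counts rescaled by the modified IDF

print(dict(zip(vec.get_feature_names_out(), weights.round(3))))  # per-term weight
print(weighted.toarray().round(3))                               # reweighted count matrix
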
@ -0,0 +1,74 @@
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import pairwise_distances
import os
import numpy as np

# Initialize overall accuracy results
k_accuracies = []

for k in range(1, 53):  # run k from 1 to 52
    total_s_correct_count = 0
    total_mdm_true_count = 0

    for group_number in range(1, 6):
        # test_csv = pd.read_csv(f'post_process/tfidf_class/0.class_document/{group_number}/test_p_c.csv', low_memory=False)
        test_csv = pd.read_csv(f'translation/0.result/{group_number}/test_p.csv', low_memory=False)
        train_all_csv = pd.read_csv(f'data_preprocess/dataset/{group_number}/train_all.csv', low_memory=False)

        test_csv['s_score'], test_csv['s_thing'], test_csv['s_property'], test_csv['s_correct'] = -1, '', '', False
        duplicate_filtered = test_csv[test_csv['p_MDM']].copy()
        train_all_csv['tag_description'] = train_all_csv['tag_description'].fillna('')
        duplicate_filtered['tag_description'] = duplicate_filtered['tag_description'].fillna('')

        for ships_idx, group in duplicate_filtered.groupby('ships_idx'):
            for (p_thing, p_property), sub_group in group.groupby(['p_thing', 'p_property']):
                matching_train_data = train_all_csv[(train_all_csv['thing'] == p_thing) & (train_all_csv['property'] == p_property)]
                if not matching_train_data.empty:
                    combined_descriptions = sub_group['tag_description'].tolist() + matching_train_data['tag_description'].tolist()

                    # Use CountVectorizer for BoW vectorization
                    vectorizer = CountVectorizer(token_pattern=r'\S+')
                    bow_matrix = vectorizer.fit_transform(combined_descriptions).toarray()

                    test_bow_matrix = bow_matrix[:len(sub_group)]
                    train_bow_matrix = bow_matrix[len(sub_group):]

                    # Compute pairwise (Euclidean) distances and convert to a similarity score (1 - distance)
                    distance_matrix = pairwise_distances(test_bow_matrix, train_bow_matrix, metric='euclidean')
                    similarity_matrix = 1 - distance_matrix

                    for i, row in enumerate(similarity_matrix):
                        top_k_indices = np.argsort(row)[-k:]  # indices of the k nearest training rows (highest similarity)
                        sub_group.iloc[i, sub_group.columns.get_loc('s_score')] = row[top_k_indices].mean()  # store the mean similarity as s_score
                else:
                    sub_group['s_score'] = 0

                duplicate_filtered.loc[sub_group.index, 's_score'] = sub_group['s_score']

        for ships_idx, group in duplicate_filtered.groupby('ships_idx'):
            for (p_thing, p_property), sub_group in group.groupby(['p_thing', 'p_property']):
                best_index = sub_group.index.min() if (sub_group['s_score'] == -1).any() else sub_group['s_score'].idxmax()
                duplicate_filtered.at[best_index, 's_thing'] = sub_group.at[best_index, 'p_thing']
                duplicate_filtered.at[best_index, 's_property'] = sub_group.at[best_index, 'p_property']
                duplicate_filtered = duplicate_filtered.drop(sub_group.index.difference([best_index]))

        test_csv.update(duplicate_filtered[['s_thing', 's_property', 's_score']])
        test_csv['s_correct'] = ((test_csv['thing'] == test_csv['s_thing']) &
                                 (test_csv['property'] == test_csv['s_property']) &
                                 (test_csv['MDM']))

        mdm_true_count = test_csv['MDM'].sum()
        s_correct_count = test_csv['s_correct'].sum()

        total_s_correct_count += s_correct_count
        total_mdm_true_count += mdm_true_count

    if total_mdm_true_count > 0:
        average_s_correct_percentage = (total_s_correct_count / total_mdm_true_count) * 100
        k_accuracies.append(average_s_correct_percentage)
        print(f"k={k}, Average s_correct percentage: {average_s_correct_percentage:.2f}%")

# Report the average accuracy across all k values
overall_average_accuracy = np.mean(k_accuracies)
print(f"Overall average s_correct percentage across all k values: {overall_average_accuracy:.2f}%")
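
The scoring loop averages the k largest similarities for each test description; because np.argsort is ascending, the last k positions are the most similar training rows. (With Euclidean distances, 1 - distance can dip below zero, but only the relative ordering and the mean matter here.) A toy illustration with an invented similarity row:

import numpy as np

row = np.array([0.10, 0.85, 0.40, 0.92, 0.05])  # invented similarities to five training rows
k = 3
top_k_indices = np.argsort(row)[-k:]             # indices of the k highest similarities
print(top_k_indices, row[top_k_indices].mean())  # [2 1 3] 0.7233...
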
@ -0,0 +1,76 @@
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import pairwise_distances
import os
import numpy as np

# Initialize overall accuracy results
k_accuracies = []

for k in range(1, 53):  # run k from 1 to 52
    total_s_correct_count = 0
    total_mdm_true_count = 0

    for group_number in range(1, 6):
        test_csv = pd.read_csv(f'translation/0.result/{group_number}/test_p.csv', low_memory=False)
        train_all_csv = pd.read_csv(f'data_preprocess/dataset/{group_number}/train_all.csv', low_memory=False)

        test_csv['s_score'], test_csv['s_thing'], test_csv['s_property'], test_csv['s_correct'] = -1, '', '', False
        duplicate_filtered = test_csv[test_csv['p_MDM']].copy()
        train_all_csv['tag_description'] = train_all_csv['tag_description'].fillna('')
        duplicate_filtered['tag_description'] = duplicate_filtered['tag_description'].fillna('')

        for ships_idx, group in duplicate_filtered.groupby('ships_idx'):
            for (p_thing, p_property), sub_group in group.groupby(['p_thing', 'p_property']):
                matching_train_data = train_all_csv[(train_all_csv['thing'] == p_thing) & (train_all_csv['property'] == p_property)]
                if not matching_train_data.empty:
                    combined_descriptions = sub_group['tag_description'].tolist() + matching_train_data['tag_description'].tolist()

                    # Use CountVectorizer for BoW vectorization (with binary=True)
                    vectorizer = CountVectorizer(binary=True, token_pattern=r'\S+')
                    bow_matrix = vectorizer.fit_transform(combined_descriptions).toarray()

                    # Convert the BoW matrix to Boolean
                    bow_matrix = bow_matrix.astype(bool)

                    test_bow_matrix = bow_matrix[:len(sub_group)]
                    train_bow_matrix = bow_matrix[len(sub_group):]

                    # Compute pairwise (Euclidean) distances and convert to a similarity score (1 - distance)
                    distance_matrix = pairwise_distances(test_bow_matrix, train_bow_matrix, metric='euclidean')
                    similarity_matrix = 1 - distance_matrix

                    for i, row in enumerate(similarity_matrix):
                        top_k_indices = np.argsort(row)[-k:]  # indices of the k nearest training rows (highest similarity)
                        sub_group.iloc[i, sub_group.columns.get_loc('s_score')] = row[top_k_indices].mean()  # store the mean similarity as s_score
                else:
                    sub_group['s_score'] = 0

                duplicate_filtered.loc[sub_group.index, 's_score'] = sub_group['s_score']

        for ships_idx, group in duplicate_filtered.groupby('ships_idx'):
            for (p_thing, p_property), sub_group in group.groupby(['p_thing', 'p_property']):
                best_index = sub_group.index.min() if (sub_group['s_score'] == -1).any() else sub_group['s_score'].idxmax()
                duplicate_filtered.at[best_index, 's_thing'] = sub_group.at[best_index, 'p_thing']
                duplicate_filtered.at[best_index, 's_property'] = sub_group.at[best_index, 'p_property']
                duplicate_filtered = duplicate_filtered.drop(sub_group.index.difference([best_index]))

        test_csv.update(duplicate_filtered[['s_thing', 's_property', 's_score']])
        test_csv['s_correct'] = ((test_csv['thing'] == test_csv['s_thing']) &
                                 (test_csv['property'] == test_csv['s_property']) &
                                 (test_csv['MDM']))

        mdm_true_count = test_csv['MDM'].sum()
        s_correct_count = test_csv['s_correct'].sum()

        total_s_correct_count += s_correct_count
        total_mdm_true_count += mdm_true_count

    if total_mdm_true_count > 0:
        average_s_correct_percentage = (total_s_correct_count / total_mdm_true_count) * 100
        k_accuracies.append(average_s_correct_percentage)
        print(f"k={k}, Average s_correct percentage: {average_s_correct_percentage:.2f}%")

# Report the average accuracy across all k values
overall_average_accuracy = np.mean(k_accuracies)
print(f"Overall average s_correct percentage across all k values: {overall_average_accuracy:.2f}%")
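
The only change from the previous script is binary=True (plus the astype(bool) cast), which records token presence rather than counts. A quick sketch of the difference on an invented description with a repeated token:

from sklearn.feature_extraction.text import CountVectorizer

doc = ["press press temp"]  # invented example with a repeated token
counts = CountVectorizer(token_pattern=r'\S+').fit_transform(doc).toarray()
binary = CountVectorizer(binary=True, token_pattern=r'\S+').fit_transform(doc).toarray()
print(counts)  # [[2 1]] -> 'press' counted twice
print(binary)  # [[1 1]] -> presence only
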
@ -0,0 +1,79 @@
import pandas as pd
from gensim.models import Word2Vec
from sklearn.metrics import pairwise_distances
import os
import numpy as np

# Function to compute the average Word2Vec vector for a sentence
def compute_sentence_vector(sentence, model, vector_size):
    words = sentence.split()
    word_vectors = [model.wv[word] for word in words if word in model.wv]
    if len(word_vectors) > 0:
        return np.mean(word_vectors, axis=0)
    else:
        return np.zeros(vector_size)

k_accuracies = []

for k in range(1, 53):  # run k from 1 to 52
    total_s_correct_count = 0
    total_mdm_true_count = 0

    for group_number in range(1, 6):
        test_csv = pd.read_csv(f'translation/0.result/{group_number}/test_p.csv', low_memory=False)
        train_all_csv = pd.read_csv(f'data_preprocess/dataset/{group_number}/train_all.csv', low_memory=False)

        test_csv['s_score'], test_csv['s_thing'], test_csv['s_property'], test_csv['s_correct'] = -1, '', '', False
        duplicate_filtered = test_csv[test_csv['p_MDM']].copy()
        train_all_csv['tag_description'] = train_all_csv['tag_description'].fillna('')
        duplicate_filtered['tag_description'] = duplicate_filtered['tag_description'].fillna('')

        combined_tag_descriptions = train_all_csv['tag_description'].tolist() + duplicate_filtered['tag_description'].tolist()
        sentences = [desc.split() for desc in combined_tag_descriptions]
        vector_size = 20  # embedding dimensionality
        model = Word2Vec(sentences, vector_size=vector_size, window=3, min_count=1, workers=4)

        for ships_idx, group in duplicate_filtered.groupby('ships_idx'):
            for (p_thing, p_property), sub_group in group.groupby(['p_thing', 'p_property']):
                matching_train_data = train_all_csv[(train_all_csv['thing'] == p_thing) & (train_all_csv['property'] == p_property)]
                if not matching_train_data.empty:
                    test_vectors = np.array([compute_sentence_vector(desc, model, vector_size) for desc in sub_group['tag_description']])
                    train_vectors = np.array([compute_sentence_vector(desc, model, vector_size) for desc in matching_train_data['tag_description']])

                    distance_matrix = pairwise_distances(test_vectors, train_vectors, metric='euclidean')
                    similarity_matrix = 1 - distance_matrix

                    for i, row in enumerate(similarity_matrix):
                        top_k_indices = np.argsort(row)[-k:]
                        sub_group.iloc[i, sub_group.columns.get_loc('s_score')] = float(row[top_k_indices].mean())
                else:
                    sub_group['s_score'] = 0

                duplicate_filtered.loc[sub_group.index, 's_score'] = sub_group['s_score']

        for ships_idx, group in duplicate_filtered.groupby('ships_idx'):
            for (p_thing, p_property), sub_group in group.groupby(['p_thing', 'p_property']):
                best_index = sub_group.index.min() if (sub_group['s_score'] == -1).any() else sub_group['s_score'].idxmax()
                duplicate_filtered.at[best_index, 's_thing'] = sub_group.at[best_index, 'p_thing']
                duplicate_filtered.at[best_index, 's_property'] = sub_group.at[best_index, 'p_property']
                duplicate_filtered = duplicate_filtered.drop(sub_group.index.difference([best_index]))

        test_csv.update(duplicate_filtered[['s_thing', 's_property', 's_score']])
        test_csv['s_correct'] = ((test_csv['thing'] == test_csv['s_thing']) &
                                 (test_csv['property'] == test_csv['s_property']) &
                                 (test_csv['MDM']))

        mdm_true_count = test_csv['MDM'].sum()
        s_correct_count = test_csv['s_correct'].sum()

        total_s_correct_count += s_correct_count
        total_mdm_true_count += mdm_true_count

    if total_mdm_true_count > 0:
        average_s_correct_percentage = (total_s_correct_count / total_mdm_true_count) * 100
        k_accuracies.append(average_s_correct_percentage)
        print(f"k={k}, Average s_correct percentage: {average_s_correct_percentage:.2f}%")

# Report the average accuracy across all k values
overall_average_accuracy = np.mean(k_accuracies)
print(f"Overall average s_correct percentage across all k values: {overall_average_accuracy:.2f}%")
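
compute_sentence_vector above averages the Word2Vec vectors of a description's tokens, skipping out-of-vocabulary tokens and falling back to an all-zero vector when nothing matches. A minimal usage sketch on an invented two-sentence corpus:

import numpy as np
from gensim.models import Word2Vec

sentences = [["fo", "pump", "press"], ["cool", "fw", "press"]]  # invented tokenised descriptions
model = Word2Vec(sentences, vector_size=20, window=3, min_count=1, workers=1)

def avg_vector(tokens, model, size):
    vecs = [model.wv[w] for w in tokens if w in model.wv]
    return np.mean(vecs, axis=0) if vecs else np.zeros(size)

print(avg_vector(["fo", "press", "unknown_token"], model, 20).shape)  # (20,) - the unknown token is ignored
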
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,105 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"ename": "KeyError",
|
||||||
|
"evalue": "'p_map'",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
|
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
|
||||||
|
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/indexes/base.py:3805\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 3804\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3805\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcasted_key\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3806\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
|
||||||
|
"File \u001b[0;32mindex.pyx:167\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
|
||||||
|
"File \u001b[0;32mindex.pyx:196\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
|
||||||
|
"File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:7081\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
|
||||||
|
"File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:7089\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
|
||||||
|
"\u001b[0;31mKeyError\u001b[0m: 'p_map'",
|
||||||
|
"\nThe above exception was the direct cause of the following exception:\n",
|
||||||
|
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
|
||||||
|
"Cell \u001b[0;32mIn[1], line 21\u001b[0m\n\u001b[1;32m 18\u001b[0m combined_data \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mconcat(all_data, ignore_index\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# -1인 s_score 값을 제외하고 s_thing이 null이 아닌 데이터만 필터링\u001b[39;00m\n\u001b[0;32m---> 21\u001b[0m filtered_data \u001b[38;5;241m=\u001b[39m combined_data[(combined_data[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ms_thing\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mnotna() \u001b[38;5;241m&\u001b[39m (\u001b[43mcombined_data\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mp_map\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m))]\n\u001b[1;32m 23\u001b[0m \u001b[38;5;66;03m# s_correct가 True인 경우와 False인 경우로 나눔\u001b[39;00m\n\u001b[1;32m 24\u001b[0m true_data \u001b[38;5;241m=\u001b[39m filtered_data[filtered_data[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ms_correct\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m]\n",
|
||||||
|
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/frame.py:4102\u001b[0m, in \u001b[0;36mDataFrame.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 4100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcolumns\u001b[38;5;241m.\u001b[39mnlevels \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 4101\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_getitem_multilevel(key)\n\u001b[0;32m-> 4102\u001b[0m indexer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcolumns\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4103\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_integer(indexer):\n\u001b[1;32m 4104\u001b[0m indexer \u001b[38;5;241m=\u001b[39m [indexer]\n",
|
||||||
|
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/indexes/base.py:3812\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 3807\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(casted_key, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m (\n\u001b[1;32m 3808\u001b[0m \u001b[38;5;28misinstance\u001b[39m(casted_key, abc\u001b[38;5;241m.\u001b[39mIterable)\n\u001b[1;32m 3809\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28many\u001b[39m(\u001b[38;5;28misinstance\u001b[39m(x, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m casted_key)\n\u001b[1;32m 3810\u001b[0m ):\n\u001b[1;32m 3811\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m InvalidIndexError(key)\n\u001b[0;32m-> 3812\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(key) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01merr\u001b[39;00m\n\u001b[1;32m 3813\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 3814\u001b[0m \u001b[38;5;66;03m# If we have a listlike key, _check_indexing_error will raise\u001b[39;00m\n\u001b[1;32m 3815\u001b[0m \u001b[38;5;66;03m# InvalidIndexError. Otherwise we fall through and re-raise\u001b[39;00m\n\u001b[1;32m 3816\u001b[0m \u001b[38;5;66;03m# the TypeError.\u001b[39;00m\n\u001b[1;32m 3817\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_indexing_error(key)\n",
|
||||||
|
"\u001b[0;31mKeyError\u001b[0m: 'p_map'"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
   ],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Set the list of group numbers\n",
    "group_numbers = [1]\n",
    "\n",
    "# Initialize a list to hold the data\n",
    "all_data = []\n",
    "\n",
    "# Read and combine each group's data\n",
    "for group_number in group_numbers:\n",
    "    file_path = f'../0.result/tfidf/{group_number}/test_s.csv'\n",
    "    data = pd.read_csv(file_path)\n",
    "    all_data.append(data)\n",
    "\n",
    "# Concatenate all group data into a single DataFrame\n",
    "combined_data = pd.concat(all_data, ignore_index=True)\n",
    "\n",
    "# Keep only rows where s_thing is not null, excluding s_score values of -1\n",
    "filtered_data = combined_data[(combined_data['s_thing'].notna() & (combined_data['p_map'] > 0))]\n",
    "\n",
    "# Split into the s_correct == True and s_correct == False cases\n",
    "true_data = filtered_data[filtered_data['s_correct'] == True]\n",
    "false_data = filtered_data[filtered_data['s_correct'] == False]\n",
    "\n",
    "# Shared bins\n",
    "bins = np.linspace(0, 1, 31)  # split 0 to 1 into 30 intervals\n",
    "\n",
    "# Draw the histograms\n",
    "plt.figure(figsize=(14, 7))\n",
    "\n",
    "# s_correct == True\n",
    "plt.hist(true_data['s_score'], bins=bins, color='green', edgecolor='black', alpha=0.5, label='s_correct=True')\n",
    "\n",
    "# s_correct == False\n",
    "plt.hist(false_data['s_score'], bins=bins, color='red', edgecolor='black', alpha=0.5, label='s_correct=False')\n",
    "\n",
    "# Title and axis labels\n",
    "plt.title('Distribution of s_score by s_correct (s_thing is not null)', fontsize=20)\n",
    "plt.xlabel('s_score', fontsize=16)\n",
    "plt.ylabel('Frequency', fontsize=16)\n",
    "plt.xticks(fontsize=14)\n",
    "plt.yticks(fontsize=14)\n",
    "\n",
    "# Add the legend\n",
    "plt.legend(fontsize=14)\n",
    "\n",
    "# Show the plot\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

@ -0,0 +1,148 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "30a1ff83e388495ab06f4b8177746d4b",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Saving the dataset (0/1 shards): 0%| | 0/6260 [00:00<?, ? examples/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "3d009461ac044864b674dc59898160b2",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Saving the dataset (0/1 shards): 0%| | 0/12969 [00:00<?, ? examples/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "ce28a831723d4b4698e6ce4a216c56db",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Saving the dataset (0/1 shards): 0%| | 0/2087 [00:00<?, ? examples/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Dataset saved to 'combined_data'\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import os\n",
|
||||||
|
"import json\n",
|
||||||
|
"from datasets import Dataset, DatasetDict\n",
|
||||||
|
"\n",
|
||||||
|
"group_number = 5\n",
|
||||||
|
"mode = 'td_unit'\n",
|
||||||
|
"\n",
|
||||||
|
"def load_group_data(group_number):\n",
|
||||||
|
" group_folder = os.path.join('../../data_preprocess/dataset', str(group_number))\n",
|
||||||
|
" train_file_path = os.path.join(group_folder, 'train.csv')\n",
|
||||||
|
" valid_file_path = os.path.join(group_folder, 'valid.csv')\n",
|
||||||
|
" test_file_path = os.path.join(group_folder, 'test.csv')\n",
|
||||||
|
" \n",
|
||||||
|
" if not all(os.path.exists(f) for f in [train_file_path, valid_file_path, test_file_path]):\n",
|
||||||
|
" raise FileNotFoundError(f\"Files for group {group_number} do not exist.\")\n",
|
||||||
|
" \n",
|
||||||
|
" return pd.read_csv(train_file_path), pd.read_csv(valid_file_path), pd.read_csv(test_file_path)\n",
|
||||||
|
"\n",
|
||||||
|
"train_data, valid_data, test_data = load_group_data(group_number)\n",
|
||||||
|
"\n",
|
||||||
|
"def process_df(df, mode='only_td'):\n",
|
||||||
|
" output_list = []\n",
|
||||||
|
" for idx, row in df.iterrows():\n",
|
||||||
|
" try:\n",
|
||||||
|
" if mode == 'only_td':\n",
|
||||||
|
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END>\"\n",
|
||||||
|
" elif mode == 'tn_td':\n",
|
||||||
|
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END>\"\n",
|
||||||
|
" elif mode == 'tn_td_min_max':\n",
|
||||||
|
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\"\n",
|
||||||
|
" elif mode == 'td_min_max':\n",
|
||||||
|
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\"\n",
|
||||||
|
" elif mode == 'td_unit':\n",
|
||||||
|
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\"\n",
|
||||||
|
" elif mode == 'tn_td_unit':\n",
|
||||||
|
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\"\n",
|
||||||
|
" elif mode == 'td_min_max_unit':\n",
|
||||||
|
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END><UNIT_START>{str(row['unit'])}<UNIT_END>\"\n",
|
||||||
|
" else:\n",
|
||||||
|
" raise ValueError(\"Invalid mode specified\")\n",
|
||||||
|
" \n",
|
||||||
|
" output_list.append({\n",
|
||||||
|
" 'translation': {\n",
|
||||||
|
" 'ships_idx': row['ships_idx'],\n",
|
||||||
|
" 'input': input_str,\n",
|
||||||
|
" 'thing_property': f\"<THING_START>{str(row['thing'])}<THING_END><PROPERTY_START>{str(row['property'])}<PROPERTY_END>\",\n",
|
||||||
|
" 'answer': f\"{str(row['thing'])} {str(row['property'])}\",\n",
|
||||||
|
" }\n",
|
||||||
|
" })\n",
|
||||||
|
" except Exception as e:\n",
|
||||||
|
" print(f\"Error processing row at index {idx}: {e}\")\n",
|
||||||
|
" return output_list\n",
|
||||||
|
"\n",
|
||||||
|
"combined_dict = {\"mode\": mode, \"fold_group\": group_number}\n",
|
||||||
|
"with open(\"mode.json\", \"w\") as json_file:\n",
|
||||||
|
" json.dump(combined_dict, json_file)\n",
|
||||||
|
"\n",
|
||||||
|
"combined_data = DatasetDict({\n",
|
||||||
|
" 'train': Dataset.from_list(process_df(train_data, mode=mode)),\n",
|
||||||
|
" 'test': Dataset.from_list(process_df(test_data, mode=mode)),\n",
|
||||||
|
" 'validation': Dataset.from_list(process_df(valid_data, mode=mode)),\n",
|
||||||
|
"})\n",
|
||||||
|
"combined_data.save_to_disk(f\"combined_data/{mode}/{group_number}\")\n",
|
||||||
|
"print(\"Dataset saved to 'combined_data'\")\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.14"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
|
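
The process_df function in the notebook above serialises each row into a tagged input/target pair according to mode. A small standalone sketch of what the td_unit mode produces for one invented row (the values are made up; only the tag format follows the notebook):

row = {'tag_description': 'NO.1 G/E FO INLET PRESS', 'unit': 'bar',
       'thing': 'GeneratorEngine1FuelOil', 'property': 'InletPressure'}  # invented example values
input_str = f"<TD_START>{row['tag_description']}<TD_END><UNIT_START>{row['unit']}<UNIT_END>"
target_str = f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>"
print(input_str)   # what the model sees
print(target_str)  # what the model is trained to generate
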
@ -0,0 +1,359 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# t5 training for combined concatenated outputs (thing + property) \n",
|
||||||
|
"\n",
|
||||||
|
"refer to `t5_train_tp.py` and `guide_for_tp.md` for faster training workflow"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "90f850a9e8324109808e45e40f0eea47",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Map: 0%| | 0/6260 [00:00<?, ? examples/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "34e221d3425d414a9fb749a3ee28ad81",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Map: 0%| | 0/12969 [00:00<?, ? examples/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "7c5504c54cba4520aa34d5a6a078a31d",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Map: 0%| | 0/2087 [00:00<?, ? examples/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"\n",
|
||||||
|
" <div>\n",
|
||||||
|
" \n",
|
||||||
|
" <progress value='1800' max='3920' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
||||||
|
" [1800/3920 13:48 < 16:16, 2.17 it/s, Epoch 36/80]\n",
|
||||||
|
" </div>\n",
|
||||||
|
" <table border=\"1\" class=\"dataframe\">\n",
|
||||||
|
" <thead>\n",
|
||||||
|
" <tr style=\"text-align: left;\">\n",
|
||||||
|
" <th>Step</th>\n",
|
||||||
|
" <th>Training Loss</th>\n",
|
||||||
|
" <th>Validation Loss</th>\n",
|
||||||
|
" <th>Bleu</th>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </thead>\n",
|
||||||
|
" <tbody>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>200</td>\n",
|
||||||
|
" <td>2.654300</td>\n",
|
||||||
|
" <td>0.112380</td>\n",
|
||||||
|
" <td>26.397731</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>400</td>\n",
|
||||||
|
" <td>0.106600</td>\n",
|
||||||
|
" <td>0.035335</td>\n",
|
||||||
|
" <td>87.137364</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>600</td>\n",
|
||||||
|
" <td>0.044600</td>\n",
|
||||||
|
" <td>0.022964</td>\n",
|
||||||
|
" <td>89.884682</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>800</td>\n",
|
||||||
|
" <td>0.026300</td>\n",
|
||||||
|
" <td>0.018220</td>\n",
|
||||||
|
" <td>86.274312</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>1000</td>\n",
|
||||||
|
" <td>0.017300</td>\n",
|
||||||
|
" <td>0.016252</td>\n",
|
||||||
|
" <td>86.389477</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>1200</td>\n",
|
||||||
|
" <td>0.012400</td>\n",
|
||||||
|
" <td>0.015651</td>\n",
|
||||||
|
" <td>94.416285</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>1400</td>\n",
|
||||||
|
" <td>0.011500</td>\n",
|
||||||
|
" <td>0.014833</td>\n",
|
||||||
|
" <td>91.596509</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>1600</td>\n",
|
||||||
|
" <td>0.008800</td>\n",
|
||||||
|
" <td>0.015168</td>\n",
|
||||||
|
" <td>91.629519</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>1800</td>\n",
|
||||||
|
" <td>0.006900</td>\n",
|
||||||
|
" <td>0.015042</td>\n",
|
||||||
|
" <td>95.375351</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </tbody>\n",
|
||||||
|
"</table><p>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
"<IPython.core.display.HTML object>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) instead. This warning will be raised to an exception in v4.41.\n",
|
||||||
|
"Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) instead. This warning will be raised to an exception in v4.41.\n",
|
||||||
|
"Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) instead. This warning will be raised to an exception in v4.41.\n",
|
||||||
|
"Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) instead. This warning will be raised to an exception in v4.41.\n",
|
||||||
|
"Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) instead. This warning will be raised to an exception in v4.41.\n",
|
||||||
|
"Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) instead. This warning will be raised to an exception in v4.41.\n",
|
||||||
|
"Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) instead. This warning will be raised to an exception in v4.41.\n",
|
||||||
|
"Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) instead. This warning will be raised to an exception in v4.41.\n",
|
||||||
|
"Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) instead. This warning will be raised to an exception in v4.41.\n",
|
||||||
|
"Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}\n",
|
||||||
|
"There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ename": "",
|
||||||
|
"evalue": "",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
|
||||||
|
"\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
|
||||||
|
"\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
|
||||||
|
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from datasets import load_from_disk\n",
|
||||||
|
"import json\n",
|
||||||
|
"from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback\n",
|
||||||
|
"import evaluate\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"model_name = \"facebook/bart-base\"\n",
|
||||||
|
"train_epochs = 80\n",
|
||||||
|
"\n",
|
||||||
|
"# Load mode configuration\n",
|
||||||
|
"with open(\"mode.json\", \"r\") as json_file:\n",
|
||||||
|
" mode_dict = json.load(json_file)\n",
|
||||||
|
"\n",
|
||||||
|
"mode_dict.update({\"model\": model_name, \"train_epochs\": train_epochs})\n",
|
||||||
|
"fold_group = mode_dict.get(\"fold_group\")\n",
|
||||||
|
"\n",
|
||||||
|
"with open(\"mode.json\", \"w\") as json_file:\n",
|
||||||
|
" json.dump(mode_dict, json_file)\n",
|
||||||
|
"\n",
|
||||||
|
"mode = mode_dict.get(\"mode\", \"default_value\")\n",
|
||||||
|
"file_path = f'combined_data/{mode}/{fold_group}'\n",
|
||||||
|
"split_datasets = load_from_disk(file_path)\n",
|
||||||
|
"\n",
|
||||||
|
"# Load tokenizer and add special tokens\n",
|
||||||
|
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
||||||
|
"additional_special_tokens = [\n",
|
||||||
|
" \"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\",\n",
|
||||||
|
" \"<TN_START>\", \"<TN_END>\", \"<TD_START>\", \"<TD_END>\", \n",
|
||||||
|
" \"<MIN_START>\", \"<MIN_END>\", \"<MAX_START>\", \"<MAX_END>\",\n",
|
||||||
|
" \"<UNIT_START>\", \"<UNIT_END>\"\n",
|
||||||
|
"]\n",
|
||||||
|
"tokenizer.add_special_tokens({\"additional_special_tokens\": additional_special_tokens})\n",
|
||||||
|
"\n",
|
||||||
|
"# Preprocess function for tokenization\n",
|
||||||
|
"def preprocess_function(examples):\n",
|
||||||
|
" inputs = [ex[\"input\"] for ex in examples['translation']]\n",
|
||||||
|
" targets = [ex[\"thing_property\"] for ex in examples['translation']]\n",
|
||||||
|
" return tokenizer(inputs, text_target=targets, max_length=64, truncation=True)\n",
|
||||||
|
"\n",
|
||||||
|
"tokenized_datasets = split_datasets.map(\n",
|
||||||
|
" preprocess_function, batched=True, remove_columns=split_datasets[\"train\"].column_names\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"# Load model and resize token embeddings\n",
|
||||||
|
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
|
||||||
|
"model.resize_token_embeddings(len(tokenizer))\n",
|
||||||
|
"\n",
|
||||||
|
"# Data collator for padding and batching\n",
|
||||||
|
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
|
||||||
|
"\n",
|
||||||
|
"# Load evaluation metric\n",
|
||||||
|
"metric = evaluate.load(\"sacrebleu\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Compute metrics function\n",
|
||||||
|
"def compute_metrics(eval_preds):\n",
|
||||||
|
" preds, labels = eval_preds\n",
|
||||||
|
" preds = preds[0] if isinstance(preds, tuple) else preds\n",
|
||||||
|
" \n",
|
||||||
|
" # Decode predictions and labels\n",
|
||||||
|
" decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
|
||||||
|
" labels = np.where(labels != -100, labels, tokenizer.pad_token_id) # Replace padding tokens\n",
|
||||||
|
" decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
|
||||||
|
" \n",
|
||||||
|
" # Post-process decoding\n",
|
||||||
|
" decoded_preds = [pred.strip() for pred in decoded_preds]\n",
|
||||||
|
" decoded_labels = [[label.strip()] for label in decoded_labels]\n",
|
||||||
|
" \n",
|
||||||
|
" result = metric.compute(predictions=decoded_preds, references=decoded_labels)\n",
|
||||||
|
" return {\"bleu\": result[\"score\"]}\n",
|
||||||
|
"\n",
|
||||||
|
"args = Seq2SeqTrainingArguments(\n",
|
||||||
|
" f\"train_{fold_group}_{model_name}_{mode}_{train_epochs}\",\n",
|
||||||
|
" save_strategy=\"steps\",\n",
|
||||||
|
" learning_rate=1e-5,\n",
|
||||||
|
" per_device_train_batch_size=32,\n",
|
||||||
|
" per_device_eval_batch_size=64,\n",
|
||||||
|
" auto_find_batch_size=True,\n",
|
||||||
|
" ddp_find_unused_parameters=False,\n",
|
||||||
|
" weight_decay=0.01,\n",
|
||||||
|
" save_total_limit=1,\n",
|
||||||
|
" num_train_epochs=train_epochs,\n",
|
||||||
|
" predict_with_generate=True,\n",
|
||||||
|
" bf16=True,\n",
|
||||||
|
" push_to_hub=False,\n",
|
||||||
|
" evaluation_strategy=\"steps\",\n",
|
||||||
|
" eval_steps=200,\n",
|
||||||
|
" save_steps=200, \n",
|
||||||
|
" logging_steps=200, \n",
|
||||||
|
" load_best_model_at_end=True, \n",
|
||||||
|
" lr_scheduler_type=\"linear\",\n",
|
||||||
|
" warmup_steps=100,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"# Define the EarlyStoppingCallback\n",
|
||||||
|
"early_stopping_callback = EarlyStoppingCallback(\n",
|
||||||
|
" early_stopping_patience=2\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"trainer = Seq2SeqTrainer(\n",
|
||||||
|
" model,\n",
|
||||||
|
" args,\n",
|
||||||
|
" train_dataset=tokenized_datasets[\"train\"],\n",
|
||||||
|
" eval_dataset=tokenized_datasets[\"validation\"],\n",
|
||||||
|
" data_collator=data_collator,\n",
|
||||||
|
" tokenizer=tokenizer,\n",
|
||||||
|
" compute_metrics=compute_metrics,\n",
|
||||||
|
" callbacks=[early_stopping_callback] \n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"trainer.train()\n",
|
||||||
|
"os._exit(0)\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "base",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.14"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
|
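
One detail of the compute_metrics function in the training notebook above: label positions that the collator padded with -100 are swapped back to the tokenizer's pad token id before decoding, since -100 is not a valid token id to decode. A numpy-only sketch of that masking step (pad_token_id=1 is a stand-in; the real value comes from the tokenizer):

import numpy as np

labels = np.array([[2345, 678, 9, -100, -100],
                   [31, 42, -100, -100, -100]])  # invented label ids padded with -100
pad_token_id = 1  # placeholder; use tokenizer.pad_token_id in practice
labels = np.where(labels != -100, labels, pad_token_id)
print(labels)  # -100 entries replaced so decoding will not fail
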
@ -0,0 +1,316 @@
|
||||||
|
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Goal: end to end inference and evaluation\n",
"\n",
"given a csv, make predictions and evaluate predictions, then return results in a csv"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The test_dataset contains 12938 items.\n",
"Making inference on test set\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"12938it [02:37, 82.28it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Inference done.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"import pandas as pd\n",
"import os\n",
"import json\n",
"\n",
"with open(\"mode.json\", \"r\") as json_file:\n",
" mode_dict = json.load(json_file)\n",
"\n",
"mode = mode_dict.get(\"mode\", \"none\")\n",
"model_name = mode_dict.get(\"model\", \"none\")\n",
"train_epochs = mode_dict.get(\"train_epochs\", \"none\")\n",
"fold_group = mode_dict.get(\"fold_group\", \"none\")\n",
"\n",
"base_dir = f\"train_{fold_group}_{model_name}_{mode}_{train_epochs}\"\n",
"checkpoints = [d for d in os.listdir(base_dir) if d.startswith(\"checkpoint-\")]\n",
"\n",
"model_checkpoint = os.path.join(base_dir, checkpoints[0]) if checkpoints else None\n",
"\n",
"data_path = f\"../../data_preprocess/dataset/{fold_group}/test.csv\"\n",
"\n",
"try:\n",
" df = pd.read_csv(data_path)\n",
"except UnicodeDecodeError:\n",
" df = pd.read_csv(data_path, encoding='ISO-8859-1')\n",
"\n",
"df = df.dropna(subset=['tag_description']).reset_index(drop=True)\n",
"\n",
"df_org = df.copy()\n",
"df[['thing', 'property', 'tag_description', 'min', 'max', 'MDM', 'pattern']] = df[['thing', 'property', 'tag_description', 'min', 'max', 'MDM', 'pattern']].astype(\"string\")\n",
"\n",
"from datasets import Dataset\n",
"\n",
"def process_df(df, mode='only_td'):\n",
" output_list = []\n",
" for _, row in df.iterrows():\n",
" try:\n",
" if mode == 'only_td':\n",
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END>\"\n",
" elif mode == 'tn_td':\n",
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END>\"\n",
" elif mode == 'tn_td_min_max':\n",
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\"\n",
" elif mode == 'td_min_max':\n",
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\"\n",
" elif mode == 'td_unit':\n",
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\"\n",
" elif mode == 'tn_td_unit':\n",
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\"\n",
" elif mode == 'td_min_max_unit':\n",
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END><UNIT_START>{str(row['unit'])}<UNIT_END>\"\n",
" else:\n",
" raise ValueError(\"Invalid mode specified\")\n",
"\n",
" output_list.append({\n",
" 'translation': {\n",
" 'ships_idx': row['ships_idx'],\n",
" 'input': input_str,\n",
" 'thing_property': f\"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>\",\n",
" 'answer_thing': row['thing'],\n",
" 'answer_property': row['property'],\n",
" 'MDM': row['MDM'],\n",
" }\n",
" })\n",
" except Exception as e:\n",
" print(f\"Error processing row: {e}\")\n",
" return output_list\n",
"\n",
"processed_data = process_df(df, mode=mode)\n",
"test_dataset = Dataset.from_list(processed_data)\n",
"print(f\"The test_dataset contains {len(test_dataset)} items.\")\n",
"\n",
"from transformers.pipelines.pt_utils import KeyDataset\n",
"from transformers import pipeline, BartTokenizer\n",
"from tqdm import tqdm\n",
"\n",
"# Use BartTokenizer for BART inference\n",
"tokenizer = BartTokenizer.from_pretrained(model_name, return_tensors=\"pt\")\n",
"additional_special_tokens = [\n",
" \"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\", \n",
" \"<TN_START>\", \"<TN_END>\", \"<TD_START>\", \"<TD_END>\", \n",
" \"<MIN_START>\", \"<MIN_END>\", \"<MAX_START>\", \"<MAX_END>\", \n",
" \"<UNIT_START>\", \"<UNIT_END>\"\n",
"]\n",
"tokenizer.add_special_tokens({\"additional_special_tokens\": additional_special_tokens})\n",
"\n",
"# Use BART model for inference\n",
"pipe = pipeline(\"text2text-generation\", model=model_checkpoint, tokenizer=tokenizer, return_tensors=True, max_length=128, device=0)\n",
"\n",
"# Check what token-ids the special tokens are\n",
"thing_start_id = tokenizer.convert_tokens_to_ids(\"<THING_START>\")\n",
"thing_end_id = tokenizer.convert_tokens_to_ids(\"<THING_END>\")\n",
"property_start_id = tokenizer.convert_tokens_to_ids(\"<PROPERTY_START>\")\n",
"property_end_id = tokenizer.convert_tokens_to_ids(\"<PROPERTY_END>\")\n",
"\n",
"def extract_seq(tokens, start_value, end_value):\n",
" if start_value in tokens and end_value in tokens:\n",
" return tokens[tokens.index(start_value)+1:tokens.index(end_value)]\n",
" return None\n",
"\n",
"def extract_seq_from_output(output):\n",
" tokens = output[0][\"generated_token_ids\"].tolist()\n",
" p_thing = tokenizer.decode(extract_seq(tokens, thing_start_id, thing_end_id)) if thing_start_id in tokens and thing_end_id in tokens else None\n",
" p_property = tokenizer.decode(extract_seq(tokens, property_start_id, property_end_id)) if property_start_id in tokens and property_end_id in tokens else None\n",
" return p_thing, p_property\n",
"\n",
"# Inference and storing predictions\n",
"p_thing_list = []\n",
"p_property_list = []\n",
"print(\"Making inference on test set\")\n",
"\n",
"# Process the test set through the pipeline and generate predictions\n",
"for out in tqdm(pipe(KeyDataset(test_dataset[\"translation\"], \"input\"), batch_size=256)):\n",
" p_thing, p_property = extract_seq_from_output(out)\n",
" p_thing_list.append(p_thing)\n",
" p_property_list.append(p_property)\n",
"\n",
"print(\"Inference done.\")\n"
]
},
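As a quick illustration of the span-extraction step in the cell above, here is a self-contained toy example of pulling the tokens between a start and an end marker id; the ids and token list below are made up for illustration, not produced by the notebook's tokenizer.

def extract_seq(tokens, start_value, end_value):
    # Return the token ids strictly between the first start and first end marker.
    if start_value in tokens and end_value in tokens:
        return tokens[tokens.index(start_value) + 1:tokens.index(end_value)]
    return None

# Toy ids: 50265 = <THING_START>, 50266 = <THING_END> (illustrative values only).
generated = [0, 50265, 1234, 5678, 50266, 2]
print(extract_seq(generated, 50265, 50266))  # -> [1234, 5678]
print(extract_seq(generated, 50267, 50268))  # -> None (markers absent)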
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thing prediction accuracy: 0.9793861658268438\n",
"Correct thing predictions: 2138, Incorrect thing predictions: 45\n",
"Property prediction accuracy: 0.9752633989922126\n",
"Correct property predictions: 2129, Incorrect property predictions: 10809\n",
"total accuracy: 0.9601465872652314\n",
"Correct total predictions: 2096, Incorrect total predictions: 87\n"
]
}
],
"source": [
"answer_thing = [item['answer_thing'] for item in test_dataset[\"translation\"]]\n",
|
||||||
|
"answer_property = [item['answer_property'] for item in test_dataset[\"translation\"]]\n",
|
||||||
|
"mdm_list = [item['MDM'] for item in test_dataset[\"translation\"]]\n",
|
||||||
|
"\n",
|
||||||
|
"mdm_count = 0\n",
|
||||||
|
"for i in range(len(mdm_list)):\n",
|
||||||
|
" if(mdm_list[i] == \"True\"):mdm_count = mdm_count + 1 \n",
|
||||||
|
"\n",
|
||||||
|
"def correctness_test(input, reference, mdm_list):\n",
|
||||||
|
" assert(len(input) == len(reference))\n",
|
||||||
|
" correctness_list = []\n",
|
||||||
|
" for i in range(len(input)):\n",
|
||||||
|
" if(mdm_list[i] == \"True\"):\n",
|
||||||
|
" correctness_list.append(input[i] == reference[i])\n",
|
||||||
|
" else:correctness_list.append(False)\n",
|
||||||
|
" return correctness_list\n",
|
||||||
|
"\n",
|
||||||
|
"# Compare with answer to evaluate correctness\n",
|
||||||
|
"thing_correctness = correctness_test(p_thing_list, answer_thing, mdm_list)\n",
|
||||||
|
"property_correctness = correctness_test(p_property_list, answer_property, mdm_list)\n",
|
||||||
|
"\n",
|
||||||
|
"correctness_mdm = []\n",
|
||||||
|
"for i in range(len(mdm_list)):\n",
|
||||||
|
" if(thing_correctness[i] & property_correctness[i]):\n",
|
||||||
|
" correctness_mdm.append(True)\n",
|
||||||
|
" else: \n",
|
||||||
|
" correctness_mdm.append(False)\n",
|
||||||
|
" \n",
|
||||||
|
" \n",
|
||||||
|
"# Calculate accuracy\n",
|
||||||
|
"thing_accuracy = sum(thing_correctness) / mdm_count\n",
|
||||||
|
"property_accuracy = sum(property_correctness) / mdm_count\n",
|
||||||
|
"total_accuracy = sum(correctness_mdm) / mdm_count\n",
|
||||||
|
"\n",
|
||||||
|
"# Count True/False values\n",
|
||||||
|
"thing_true_count = thing_correctness.count(True)\n",
|
||||||
|
"thing_false_count = 0\n",
|
||||||
|
"for i in range(len(thing_correctness)):\n",
|
||||||
|
" if mdm_list[i] == \"True\" and thing_correctness[i] == False:\n",
|
||||||
|
" thing_false_count += 1\n",
|
||||||
|
"\n",
|
||||||
|
"property_true_count = property_correctness.count(True)\n",
|
||||||
|
"property_false_count = property_correctness.count(False)\n",
|
||||||
|
"total_true_count = correctness_mdm.count(True)\n",
|
||||||
|
"total_false_count = mdm_count - correctness_mdm.count(True)\n",
|
||||||
|
"\n",
|
||||||
|
"# Print results\n",
|
||||||
|
"print(\"Thing prediction accuracy:\", thing_accuracy)\n",
|
||||||
|
"print(f\"Correct thing predictions: {thing_true_count}, Incorrect thing predictions: {thing_false_count}\")\n",
|
||||||
|
"print(\"Property prediction accuracy:\", property_accuracy)\n",
|
||||||
|
"print(f\"Correct property predictions: {property_true_count}, Incorrect property predictions: {property_false_count}\")\n",
|
||||||
|
"print(\"total accuracy:\", total_accuracy)\n",
|
||||||
|
"print(f\"Correct total predictions: {total_true_count}, Incorrect total predictions: {total_false_count}\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Create a DataFrame with the results\n",
|
||||||
|
"dict = {\n",
|
||||||
|
" 'p_thing': p_thing_list,\n",
|
||||||
|
" 'p_property': p_property_list,\n",
|
||||||
|
" 'p_thing_correct': thing_correctness,\n",
|
||||||
|
" 'p_property_correct': property_correctness\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"df_pred = pd.DataFrame(dict)\n",
|
||||||
|
"\n",
|
||||||
|
"# Read the mode from the JSON file\n",
|
||||||
|
"with open(\"mode.json\", \"r\") as json_file:\n",
|
||||||
|
" mode_dict = json.load(json_file)\n",
|
||||||
|
"\n",
|
||||||
|
"# Add the model key to the dictionary\n",
|
||||||
|
"mode_dict[\"model\"] = model_name\n",
|
||||||
|
"mode_dict[\"train_epochs\"] = train_epochs\n",
|
||||||
|
"\n",
|
||||||
|
"# Save the updated dictionary back to the JSON file\n",
|
||||||
|
"with open(\"mode.json\", \"w\") as json_file:\n",
|
||||||
|
" json.dump(mode_dict, json_file)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Check if the file exists and is not empty\n",
|
||||||
|
"if os.path.exists(\"results.json\") and os.path.getsize(\"results.json\") > 0:\n",
|
||||||
|
" # Read the existing results.json file\n",
|
||||||
|
" with open(\"results.json\", \"r\") as json_file:\n",
|
||||||
|
" try:\n",
|
||||||
|
" results_dict = json.load(json_file)\n",
|
||||||
|
" except json.JSONDecodeError:\n",
|
||||||
|
" results_dict = {}\n",
|
||||||
|
"else:\n",
|
||||||
|
" results_dict = {}\n",
|
||||||
|
"\n",
|
||||||
|
"# Add the new model_checkpoint key with the accuracy values as an object\n",
|
||||||
|
"\n",
|
||||||
|
"model_key = model_checkpoint \n",
|
||||||
|
"\n",
|
||||||
|
"results_dict[model_key] = {\n",
|
||||||
|
" \"thing_accuracy\": thing_accuracy,\n",
|
||||||
|
" \"thing_true\": thing_true_count,\n",
|
||||||
|
" \"thing_false\": thing_false_count,\n",
|
||||||
|
" \"property_accuracy\": property_accuracy,\n",
|
||||||
|
" \"property_true\": property_true_count,\n",
|
||||||
|
" \"property_false\": property_false_count,\n",
|
||||||
|
" \"total_accuracy\": total_accuracy,\n",
|
||||||
|
" \"total_true\": total_true_count,\n",
|
||||||
|
" \"total_false\": total_false_count \n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"# Save the updated dictionary back to the results.json file\n",
|
||||||
|
"with open(\"results.json\", \"w\") as json_file:\n",
|
||||||
|
" json.dump(results_dict, json_file, indent=4)"
|
||||||
|
]
|
||||||
|
}
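A compact, self-contained rehearsal of the accuracy bookkeeping above on made-up predictions (the lists below are illustrative, not data from this run): only rows flagged MDM count toward the denominator, and a row is fully correct only when both thing and property match.

preds_thing   = ["pump", "fan",  "valve"]
answers_thing = ["pump", "fan",  "pump"]
preds_prop    = ["speed", "rpm", "flow"]
answers_prop  = ["speed", "rpm", "flow"]
mdm_flags     = ["True", "True", "False"]   # only MDM rows are scored

def scored(pred, ref, flags):
    # A prediction only counts as correct on MDM rows; everything else is False.
    return [f == "True" and p == r for p, r, f in zip(pred, ref, flags)]

thing_ok = scored(preds_thing, answers_thing, mdm_flags)
prop_ok  = scored(preds_prop, answers_prop, mdm_flags)
both_ok  = [t and p for t, p in zip(thing_ok, prop_ok)]

mdm_count = mdm_flags.count("True")
print(sum(thing_ok) / mdm_count, sum(prop_ok) / mdm_count, sum(both_ok) / mdm_count)
# -> 1.0 1.0 1.0 (both MDM rows are fully correct in this toy example)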
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@@ -2,74 +2,18 @@
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 8,
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Loaded data for group 1:\n",
|
|
||||||
"Train data shape: (6125, 16)\n",
|
|
||||||
"Valid data shape: (2042, 16)\n",
|
|
||||||
"Test data shape: (14719, 15)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import pandas as pd\n",
|
|
||||||
"import os\n",
|
|
||||||
"# Example usage:1\n",
|
|
||||||
"group_number = 1 # You can change this to any group number you want to load (1, 2, 3, 4, or 5)\n",
|
|
||||||
"\n",
|
|
||||||
"# Select the mode for processing\n",
|
|
||||||
"mode = 'tn_td_unit' # Change this to 'only_td', 'tn_td', etc., as needed\n",
|
|
||||||
"\n",
|
|
||||||
"def load_group_data(group_number):\n",
|
|
||||||
" # Define the folder path based on the group number\n",
|
|
||||||
" group_folder = os.path.join('../../data_preprocess/dataset', str(group_number))\n",
|
|
||||||
" \n",
|
|
||||||
" # Define file paths for train, valid, and test datasets\n",
|
|
||||||
" train_file_path = os.path.join(group_folder, 'train.csv')\n",
|
|
||||||
" valid_file_path = os.path.join(group_folder, 'valid.csv')\n",
|
|
||||||
" test_file_path = os.path.join(group_folder, 'test.csv')\n",
|
|
||||||
" \n",
|
|
||||||
" # Check if the files exist\n",
|
|
||||||
" if not os.path.exists(train_file_path) or not os.path.exists(valid_file_path) or not os.path.exists(test_file_path):\n",
|
|
||||||
" raise FileNotFoundError(f\"One or more files for group {group_number} do not exist.\")\n",
|
|
||||||
" \n",
|
|
||||||
" # Load the CSV files into DataFrames\n",
|
|
||||||
" train_data = pd.read_csv(train_file_path)\n",
|
|
||||||
" valid_data = pd.read_csv(valid_file_path)\n",
|
|
||||||
" test_data = pd.read_csv(test_file_path)\n",
|
|
||||||
" \n",
|
|
||||||
" return train_data, valid_data, test_data\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"try:\n",
|
|
||||||
" train_data, valid_data, test_data = load_group_data(group_number)\n",
|
|
||||||
" print(f\"Loaded data for group {group_number}:\")\n",
|
|
||||||
" print(f\"Train data shape: {train_data.shape}\")\n",
|
|
||||||
" print(f\"Valid data shape: {valid_data.shape}\")\n",
|
|
||||||
" print(f\"Test data shape: {test_data.shape}\")\n",
|
|
||||||
"except FileNotFoundError as e:\n",
|
|
||||||
" print(e)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
"model_id": "313f98ef12eb442bac319282e5ffe5d6",
|
"model_id": "7d3d34e404f94388a89f0c9b1aa814e6",
|
||||||
"version_major": 2,
|
"version_major": 2,
|
||||||
"version_minor": 0
|
"version_minor": 0
|
||||||
},
|
},
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"Saving the dataset (0/1 shards): 0%| | 0/6125 [00:00<?, ? examples/s]"
|
"Saving the dataset (0/1 shards): 0%| | 0/6260 [00:00<?, ? examples/s]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
@ -78,12 +22,12 @@
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
"model_id": "0c1834a4e7264a969085ad609320fdd6",
|
"model_id": "7b49ec520b674b39b34a8c28ff480716",
|
||||||
"version_major": 2,
|
"version_major": 2,
|
||||||
"version_minor": 0
|
"version_minor": 0
|
||||||
},
|
},
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"Saving the dataset (0/1 shards): 0%| | 0/14719 [00:00<?, ? examples/s]"
|
"Saving the dataset (0/1 shards): 0%| | 0/12969 [00:00<?, ? examples/s]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
@ -92,12 +36,12 @@
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
"model_id": "464f88daab334658aac93305ea6dac71",
|
"model_id": "c06c7ee55f174bb5b030983c52adbace",
|
||||||
"version_major": 2,
|
"version_major": 2,
|
||||||
"version_minor": 0
|
"version_minor": 0
|
||||||
},
|
},
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"Saving the dataset (0/1 shards): 0%| | 0/2042 [00:00<?, ? examples/s]"
|
"Saving the dataset (0/1 shards): 0%| | 0/2087 [00:00<?, ? examples/s]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
@ -112,26 +56,43 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import os\n",
|
||||||
"import json\n",
|
"import json\n",
|
||||||
"from datasets import Dataset, DatasetDict\n",
|
"from datasets import Dataset, DatasetDict\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Function to process DataFrame based on mode\n",
|
"group_number = 5\n",
|
||||||
|
"mode = 'td_unit'\n",
|
||||||
|
"\n",
|
||||||
|
"def load_group_data(group_number):\n",
|
||||||
|
" group_folder = os.path.join('../../data_preprocess/dataset', str(group_number))\n",
|
||||||
|
" train_file_path = os.path.join(group_folder, 'train.csv')\n",
|
||||||
|
" valid_file_path = os.path.join(group_folder, 'valid.csv')\n",
|
||||||
|
" test_file_path = os.path.join(group_folder, 'test.csv')\n",
|
||||||
|
" \n",
|
||||||
|
" if not os.path.exists(train_file_path) or not os.path.exists(valid_file_path) or not os.path.exists(test_file_path):\n",
|
||||||
|
" raise FileNotFoundError(f\"Files for group {group_number} not found.\")\n",
|
||||||
|
" \n",
|
||||||
|
" return pd.read_csv(train_file_path), pd.read_csv(valid_file_path), pd.read_csv(test_file_path)\n",
|
||||||
|
"\n",
|
||||||
"def process_df(df, mode='only_td'):\n",
|
"def process_df(df, mode='only_td'):\n",
|
||||||
" output_list = []\n",
|
" output_list = []\n",
|
||||||
" for idx, row in df.iterrows():\n",
|
" for idx, row in df.iterrows():\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" if mode == 'only_td':\n",
|
" if mode == 'only_td':\n",
|
||||||
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END>\"\n",
|
" input_str = f\"<TD_START>{row['tag_description']}<TD_END>\"\n",
|
||||||
" elif mode == 'tn_td':\n",
|
" elif mode == 'tn_td':\n",
|
||||||
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END>\"\n",
|
" input_str = f\"<TN_START>{row['tag_name']}<TN_END><TD_START>{row['tag_description']}<TD_END>\"\n",
|
||||||
" elif mode == 'tn_td_min_max':\n",
|
" elif mode == 'tn_td_min_max':\n",
|
||||||
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\"\n",
|
" input_str = f\"<TN_START>{row['tag_name']}<TN_END><TD_START>{row['tag_description']}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\"\n",
|
||||||
" elif mode == 'td_min_max':\n",
|
" elif mode == 'td_min_max':\n",
|
||||||
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\" \n",
|
" input_str = f\"<TD_START>{row['tag_description']}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\"\n",
|
||||||
" elif mode == 'td_unit':\n",
|
" elif mode == 'td_unit':\n",
|
||||||
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\" \n",
|
" input_str = f\"<TD_START>{row['tag_description']}<TD_END><UNIT_START>{row['unit']}<UNIT_END>\"\n",
|
||||||
" elif mode == 'tn_td_unit':\n",
|
" elif mode == 'tn_td_unit':\n",
|
||||||
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\" \n",
|
" input_str = f\"<TN_START>{row['tag_name']}<TN_END><TD_START>{row['tag_description']}<TD_END><UNIT_START>{row['unit']}<UNIT_END>\"\n",
|
||||||
|
" elif mode == 'td_min_max_unit':\n",
|
||||||
|
" input_str = f\"<TD_START>{row['tag_description']}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END><UNIT_START>{row['unit']}<UNIT_END>\"\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" raise ValueError(\"Invalid mode specified\")\n",
|
" raise ValueError(\"Invalid mode specified\")\n",
|
||||||
" \n",
|
" \n",
|
||||||
|
@ -139,38 +100,27 @@
|
||||||
" 'translation': {\n",
|
" 'translation': {\n",
|
||||||
" 'ships_idx': row['ships_idx'],\n",
|
" 'ships_idx': row['ships_idx'],\n",
|
||||||
" 'input': input_str,\n",
|
" 'input': input_str,\n",
|
||||||
" 'thing_property': f\"<THING_START>{str(row['thing'])}<THING_END><PROPERTY_START>{str(row['property'])}<PROPERTY_END>\",\n",
|
" 'thing_property': f\"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>\",\n",
|
||||||
" 'answer': f\"{str(row['thing'])} {str(row['property'])}\",\n",
|
" 'answer': f\"{row['thing']} {row['property']}\",\n",
|
||||||
" }\n",
|
" }\n",
|
||||||
" })\n",
|
" })\n",
|
||||||
" except Exception as e:\n",
|
" except Exception as e:\n",
|
||||||
" print(f\"Error processing row at index {idx}: {row}\")\n",
|
" print(f\"Error processing row at index {idx}: {e}\")\n",
|
||||||
" print(f\"Exception: {e}\")\n",
|
|
||||||
" return output_list\n",
|
" return output_list\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"train_data, valid_data, test_data = load_group_data(group_number)\n",
|
||||||
"# Combine the mode and group information into a single dictionary\n",
|
"combined_dict = {\"mode\": mode, \"fold_group\": group_number}\n",
|
||||||
"combined_dict = {\n",
|
|
||||||
" \"mode\": mode,\n",
|
|
||||||
" \"fold_group\": group_number\n",
|
|
||||||
"}\n",
|
|
||||||
"\n",
|
|
||||||
"# Save the combined dictionary to a JSON file\n",
|
|
||||||
"with open(\"mode.json\", \"w\") as json_file:\n",
|
"with open(\"mode.json\", \"w\") as json_file:\n",
|
||||||
" json.dump(combined_dict, json_file)\n",
|
" json.dump(combined_dict, json_file)\n",
|
||||||
" \n",
|
"\n",
|
||||||
"try:\n",
|
"combined_data = DatasetDict({\n",
|
||||||
" # Process the data and create a DatasetDict\n",
|
" 'train': Dataset.from_list(process_df(train_data, mode=mode)),\n",
|
||||||
" combined_data = DatasetDict({\n",
|
" 'test': Dataset.from_list(process_df(test_data, mode=mode)),\n",
|
||||||
" 'train': Dataset.from_list(process_df(train_data, mode=mode)),\n",
|
" 'validation': Dataset.from_list(process_df(valid_data, mode=mode)),\n",
|
||||||
" 'test': Dataset.from_list(process_df(test_data, mode=mode)),\n",
|
"})\n",
|
||||||
" 'validation': Dataset.from_list(process_df(valid_data, mode=mode)),\n",
|
"\n",
|
||||||
" })\n",
|
"combined_data.save_to_disk(f\"combined_data/{mode}/{group_number}\")\n",
|
||||||
" # Save the DatasetDict to disk\n",
|
"print(\"Dataset saved to 'combined_data'\")\n"
|
||||||
" combined_data.save_to_disk(f\"combined_data/{mode}/{group_number}\")\n",
|
|
||||||
" print(\"Dataset saved to 'combined_data'\")\n",
|
|
||||||
"except Exception as e:\n",
|
|
||||||
" print(f\"Error creating DatasetDict: {e}\")"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
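As a pointer to how the saved splits are consumed later, a minimal self-contained sketch of the same pattern: turn per-row dicts into a `datasets.DatasetDict` and persist it with `save_to_disk`. The toy row and the `sketch_data` path below are illustrative only, not values from this commit.

from datasets import Dataset, DatasetDict, load_from_disk

# Toy rows shaped like the 'translation' records built by process_df above.
rows = [
    {"translation": {"ships_idx": 1001,
                     "input": "<TD_START>FO PUMP<TD_END>",
                     "thing_property": "<THING_START>Pump<THING_END><PROPERTY_START>Flow<PROPERTY_END>",
                     "answer": "Pump Flow"}},
]

combined = DatasetDict({
    "train": Dataset.from_list(rows),
    "validation": Dataset.from_list(rows),
    "test": Dataset.from_list(rows),
})
combined.save_to_disk("sketch_data")      # illustrative path
reloaded = load_from_disk("sketch_data")  # later cells reload it the same way
print(reloaded)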
File diff suppressed because one or more lines are too long
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -13,124 +13,6 @@
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 1,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"The mode has been set to: tn_td_unit\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "d8d70681f4594917b7af4583a4237168",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"Map: 0%| | 0/6125 [00:00<?, ? examples/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "106e0cefe50c40f0a83371693cf48cf7",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"Map: 0%| | 0/14719 [00:00<?, ? examples/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "952f8ec73df0418490cb43beaaf5a7df",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"Map: 0%| | 0/2042 [00:00<?, ? examples/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"# import data and load dataset\n",
|
|
||||||
"from datasets import load_from_disk\n",
|
|
||||||
"import json\n",
|
|
||||||
"from transformers import AutoTokenizer\n",
|
|
||||||
"\n",
|
|
||||||
"model_name = \"t5-base\"\n",
|
|
||||||
"train_epochs = 80\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"# Read the mode from the JSON file\n",
|
|
||||||
"with open(\"mode.json\", \"r\") as json_file:\n",
|
|
||||||
" mode_dict = json.load(json_file)\n",
|
|
||||||
"\n",
|
|
||||||
"# Add the model key to the dictionary\n",
|
|
||||||
"mode_dict[\"model\"] = model_name\n",
|
|
||||||
"mode_dict[\"train_epochs\"] = train_epochs\n",
|
|
||||||
"\n",
|
|
||||||
"# Access the fold_group value\n",
|
|
||||||
"fold_group = mode_dict.get(\"fold_group\")\n",
|
|
||||||
"\n",
|
|
||||||
"# Save the updated dictionary back to the JSON file\n",
|
|
||||||
"with open(\"mode.json\", \"w\") as json_file:\n",
|
|
||||||
" json.dump(mode_dict, json_file)\n",
|
|
||||||
"\n",
|
|
||||||
"# Set the mode variable from the JSON content\n",
|
|
||||||
"mode = mode_dict.get(\"mode\", \"default_value\") # 'default_value' is a fallback if 'mode' is not found\n",
|
|
||||||
"\n",
|
|
||||||
"print(f\"The mode has been set to: {mode}\")\n",
|
|
||||||
"\n",
|
|
||||||
"# Path to saved combined_dataset\n",
|
|
||||||
"file_path = f'combined_data/{mode}/{fold_group}'\n",
|
|
||||||
"split_datasets = load_from_disk(file_path)\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
" \n",
|
|
||||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
|
||||||
"# Define additional special tokens\n",
|
|
||||||
"# additional_special_tokens = [\"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\"]\n",
|
|
||||||
"additional_special_tokens = [\"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\", \"<TN_START>\", \"<TN_END>\", \"<TD_START>\", \"<TD_END>\", \"<MIN_START>\", \"<MIN_END>\", \"<MAX_START>\", \"<MAX_END>\", \"<UNIT_START>\", \"<UNIT_END>\"]\n",
|
|
||||||
"# Add the additional special tokens to the tokenizer\n",
|
|
||||||
"tokenizer.add_special_tokens({\"additional_special_tokens\": additional_special_tokens})\n",
|
|
||||||
"\n",
|
|
||||||
"max_length = 64\n",
|
|
||||||
"\n",
|
|
||||||
"def preprocess_function(examples):\n",
|
|
||||||
" inputs = [ex[\"input\"] for ex in examples['translation']]\n",
|
|
||||||
" targets = [ex[\"thing_property\"] for ex in examples['translation']]\n",
|
|
||||||
" # text_target sets the corresponding label to inputs\n",
|
|
||||||
" # there is no need to create a separate 'labels'\n",
|
|
||||||
" model_inputs = tokenizer(\n",
|
|
||||||
" inputs, text_target=targets, max_length=max_length, truncation=True\n",
|
|
||||||
" )\n",
|
|
||||||
" return model_inputs\n",
|
|
||||||
"\n",
|
|
||||||
"# map method maps preprocess_function to [train, valid, test] datasets of the datasetDict\n",
|
|
||||||
"tokenized_datasets = split_datasets.map(\n",
|
|
||||||
" preprocess_function,\n",
|
|
||||||
" batched=True,\n",
|
|
||||||
" remove_columns=split_datasets[\"train\"].column_names,\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
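For reference, a minimal runnable sketch of the preprocessing idea used in the removed cell above: tokenize the input and the target in one call via `text_target`, so the labels come back alongside the input ids. The strings and the reduced special-token list below are made-up examples, not rows or settings from this commit.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")
tokenizer.add_special_tokens({"additional_special_tokens": [
    "<TD_START>", "<TD_END>", "<THING_START>", "<THING_END>"]})

inputs = ["<TD_START>ME1 FO PUMP<TD_END>"]        # made-up input
targets = ["<THING_START>Pump<THING_END>"]        # made-up target

# text_target tokenizes the target in the same call and returns it as 'labels'.
batch = tokenizer(inputs, text_target=targets, max_length=64, truncation=True)
print(batch["input_ids"][0][:5], batch["labels"][0][:5])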
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
@@ -146,44 +28,204 @@
"\n",
|
"\n",
|
||||||
" <div>\n",
|
" <div>\n",
|
||||||
" \n",
|
" \n",
|
||||||
" <progress value='3840' max='3840' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
" <progress value='3140' max='3920' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
||||||
" [3840/3840 42:37, Epoch 80/80]\n",
|
" [3140/3920 05:42 < 01:25, 9.17 it/s, Epoch 64.06/80]\n",
|
||||||
" </div>\n",
|
" </div>\n",
|
||||||
" <table border=\"1\" class=\"dataframe\">\n",
|
" <table border=\"1\" class=\"dataframe\">\n",
|
||||||
" <thead>\n",
|
" <thead>\n",
|
||||||
" <tr style=\"text-align: left;\">\n",
|
" <tr style=\"text-align: left;\">\n",
|
||||||
" <th>Step</th>\n",
|
" <th>Step</th>\n",
|
||||||
" <th>Training Loss</th>\n",
|
" <th>Training Loss</th>\n",
|
||||||
|
" <th>Validation Loss</th>\n",
|
||||||
|
" <th>Bleu</th>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" </thead>\n",
|
" </thead>\n",
|
||||||
" <tbody>\n",
|
" <tbody>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
|
" <td>100</td>\n",
|
||||||
|
" <td>9.068100</td>\n",
|
||||||
|
" <td>1.485702</td>\n",
|
||||||
|
" <td>0.000000</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>200</td>\n",
|
||||||
|
" <td>0.886400</td>\n",
|
||||||
|
" <td>0.219002</td>\n",
|
||||||
|
" <td>20.999970</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>300</td>\n",
|
||||||
|
" <td>0.302500</td>\n",
|
||||||
|
" <td>0.100100</td>\n",
|
||||||
|
" <td>50.318311</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>400</td>\n",
|
||||||
|
" <td>0.168400</td>\n",
|
||||||
|
" <td>0.053922</td>\n",
|
||||||
|
" <td>52.052581</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
" <td>500</td>\n",
|
" <td>500</td>\n",
|
||||||
" <td>2.812300</td>\n",
|
" <td>0.113800</td>\n",
|
||||||
|
" <td>0.046394</td>\n",
|
||||||
|
" <td>53.469249</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>600</td>\n",
|
||||||
|
" <td>0.084500</td>\n",
|
||||||
|
" <td>0.040225</td>\n",
|
||||||
|
" <td>53.980484</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>700</td>\n",
|
||||||
|
" <td>0.066900</td>\n",
|
||||||
|
" <td>0.026786</td>\n",
|
||||||
|
" <td>58.959618</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>800</td>\n",
|
||||||
|
" <td>0.053300</td>\n",
|
||||||
|
" <td>0.025612</td>\n",
|
||||||
|
" <td>52.672595</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>900</td>\n",
|
||||||
|
" <td>0.042600</td>\n",
|
||||||
|
" <td>0.019917</td>\n",
|
||||||
|
" <td>58.475230</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <td>1000</td>\n",
|
" <td>1000</td>\n",
|
||||||
" <td>0.699300</td>\n",
|
" <td>0.038200</td>\n",
|
||||||
|
" <td>0.021234</td>\n",
|
||||||
|
" <td>52.335545</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>1100</td>\n",
|
||||||
|
" <td>0.032500</td>\n",
|
||||||
|
" <td>0.021687</td>\n",
|
||||||
|
" <td>52.400191</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>1200</td>\n",
|
||||||
|
" <td>0.030100</td>\n",
|
||||||
|
" <td>0.022106</td>\n",
|
||||||
|
" <td>59.836717</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>1300</td>\n",
|
||||||
|
" <td>0.026800</td>\n",
|
||||||
|
" <td>0.020341</td>\n",
|
||||||
|
" <td>55.878989</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>1400</td>\n",
|
||||||
|
" <td>0.023200</td>\n",
|
||||||
|
" <td>0.019192</td>\n",
|
||||||
|
" <td>53.356706</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <td>1500</td>\n",
|
" <td>1500</td>\n",
|
||||||
" <td>0.440900</td>\n",
|
" <td>0.022500</td>\n",
|
||||||
|
" <td>0.018187</td>\n",
|
||||||
|
" <td>59.718873</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>1600</td>\n",
|
||||||
|
" <td>0.020900</td>\n",
|
||||||
|
" <td>0.017806</td>\n",
|
||||||
|
" <td>62.848480</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>1700</td>\n",
|
||||||
|
" <td>0.017200</td>\n",
|
||||||
|
" <td>0.018625</td>\n",
|
||||||
|
" <td>62.796542</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>1800</td>\n",
|
||||||
|
" <td>0.015500</td>\n",
|
||||||
|
" <td>0.020747</td>\n",
|
||||||
|
" <td>62.920445</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>1900</td>\n",
|
||||||
|
" <td>0.013800</td>\n",
|
||||||
|
" <td>0.027109</td>\n",
|
||||||
|
" <td>68.566983</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <td>2000</td>\n",
|
" <td>2000</td>\n",
|
||||||
" <td>0.332100</td>\n",
|
" <td>0.013900</td>\n",
|
||||||
|
" <td>0.024757</td>\n",
|
||||||
|
" <td>65.792365</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>2100</td>\n",
|
||||||
|
" <td>0.011600</td>\n",
|
||||||
|
" <td>0.021626</td>\n",
|
||||||
|
" <td>68.714757</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>2200</td>\n",
|
||||||
|
" <td>0.011800</td>\n",
|
||||||
|
" <td>0.025541</td>\n",
|
||||||
|
" <td>73.793641</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>2300</td>\n",
|
||||||
|
" <td>0.011000</td>\n",
|
||||||
|
" <td>0.017915</td>\n",
|
||||||
|
" <td>71.351766</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>2400</td>\n",
|
||||||
|
" <td>0.010500</td>\n",
|
||||||
|
" <td>0.020459</td>\n",
|
||||||
|
" <td>76.285575</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <td>2500</td>\n",
|
" <td>2500</td>\n",
|
||||||
" <td>0.276500</td>\n",
|
" <td>0.009700</td>\n",
|
||||||
|
" <td>0.019714</td>\n",
|
||||||
|
" <td>78.722420</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>2600</td>\n",
|
||||||
|
" <td>0.008700</td>\n",
|
||||||
|
" <td>0.026323</td>\n",
|
||||||
|
" <td>73.858894</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>2700</td>\n",
|
||||||
|
" <td>0.008600</td>\n",
|
||||||
|
" <td>0.023967</td>\n",
|
||||||
|
" <td>78.752238</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>2800</td>\n",
|
||||||
|
" <td>0.008500</td>\n",
|
||||||
|
" <td>0.025074</td>\n",
|
||||||
|
" <td>78.772012</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <td>2900</td>\n",
|
||||||
|
" <td>0.008400</td>\n",
|
||||||
|
" <td>0.022061</td>\n",
|
||||||
|
" <td>83.261974</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <td>3000</td>\n",
|
" <td>3000</td>\n",
|
||||||
" <td>0.245900</td>\n",
|
" <td>0.008800</td>\n",
|
||||||
|
" <td>0.022081</td>\n",
|
||||||
|
" <td>80.992463</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <td>3500</td>\n",
|
" <td>3100</td>\n",
|
||||||
" <td>0.229300</td>\n",
|
" <td>0.007100</td>\n",
|
||||||
|
" <td>0.024494</td>\n",
|
||||||
|
" <td>81.058833</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" </tbody>\n",
|
" </tbody>\n",
|
||||||
"</table><p>"
|
"</table><p>"
|
||||||
|
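The added Validation Loss and Bleu columns come from the trainer's `compute_metrics` hook. Below is a minimal sketch of a BLEU-style metric function using the `evaluate` library and assuming `tokenizer` is defined as in the preceding cell; this illustrates the pattern only and is not the exact function committed here.

import evaluate
import numpy as np

sacrebleu = evaluate.load("sacrebleu")

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # Replace label padding (-100) before decoding.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # sacrebleu expects a list of reference lists per prediction.
    result = sacrebleu.compute(predictions=decoded_preds,
                               references=[[label] for label in decoded_labels])
    return {"bleu": result["score"]}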
@@ -199,231 +241,228 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
" warnings.warn(\n",
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||||
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
" warnings.warn(\n",
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
|
||||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"ename": "KeyboardInterrupt",
|
||||||
"text/plain": [
|
"evalue": "",
|
||||||
"TrainOutput(global_step=3840, training_loss=0.6754856963952383, metrics={'train_runtime': 2559.4201, 'train_samples_per_second': 191.45, 'train_steps_per_second': 1.5, 'total_flos': 3.156037495934976e+16, 'train_loss': 0.6754856963952383, 'epoch': 80.0})"
|
"output_type": "error",
|
||||||
]
|
"traceback": [
|
||||||
},
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
"execution_count": 2,
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||||||
"metadata": {},
|
"Cell \u001b[0;32mIn[1], line 113\u001b[0m\n\u001b[1;32m 97\u001b[0m early_stopping_callback \u001b[38;5;241m=\u001b[39m EarlyStoppingCallback(\n\u001b[1;32m 98\u001b[0m early_stopping_patience\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m,\n\u001b[1;32m 99\u001b[0m \n\u001b[1;32m 100\u001b[0m )\n\u001b[1;32m 102\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Seq2SeqTrainer(\n\u001b[1;32m 103\u001b[0m model,\n\u001b[1;32m 104\u001b[0m args,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 110\u001b[0m callbacks\u001b[38;5;241m=\u001b[39m[early_stopping_callback] \n\u001b[1;32m 111\u001b[0m )\n\u001b[0;32m--> 113\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 114\u001b[0m os\u001b[38;5;241m.\u001b[39m_exit(\u001b[38;5;241m0\u001b[39m)\n",
|
||||||
"output_type": "execute_result"
|
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/trainer.py:1859\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1857\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1858\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1859\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1860\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1861\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1862\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1863\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1864\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
|
||||||
|
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/accelerate/utils/memory.py:142\u001b[0m, in \u001b[0;36mfind_executable_batch_size.<locals>.decorator\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo executable batch size found, reached zero.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 142\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 144\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m should_reduce_batch_size(e):\n",
|
||||||
|
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/trainer.py:2203\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 2202\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 2203\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2205\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 2206\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 2207\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m 2208\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 2209\u001b[0m ):\n\u001b[1;32m 2210\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 2211\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
|
||||||
|
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/trainer.py:3147\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 3145\u001b[0m scaled_loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m 3146\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 3147\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3149\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mdetach() \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mgradient_accumulation_steps\n",
|
||||||
|
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/accelerate/accelerator.py:2013\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[0;34m(self, loss, **kwargs)\u001b[0m\n\u001b[1;32m 2011\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler\u001b[38;5;241m.\u001b[39mscale(loss)\u001b[38;5;241m.\u001b[39mbackward(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 2012\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2013\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||||
|
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/torch/_tensor.py:525\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 515\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 517\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 518\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 523\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 524\u001b[0m )\n\u001b[0;32m--> 525\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 526\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 527\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||||
|
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/torch/autograd/__init__.py:267\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 262\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 264\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 266\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 267\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 268\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||||
|
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/torch/autograd/graph.py:744\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 742\u001b[0m unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[1;32m 743\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 744\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 745\u001b[0m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 746\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m 747\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 748\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
|
||||||
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import torch\n",
|
"from datasets import load_from_disk\n",
|
||||||
"import os\n",
|
|
||||||
"import json\n",
|
"import json\n",
|
||||||
"\n",
|
"from transformers import AutoTokenizer\n",
|
||||||
"# we use the pre-trained t5-base model\n",
|
"import os\n",
|
||||||
"from transformers import AutoModelForSeq2SeqLM\n",
|
"from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback\n",
|
||||||
"model_checkpoint = model_name\n",
|
|
||||||
"model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)\n",
|
|
||||||
"\n",
|
|
||||||
"# data collator\n",
|
|
||||||
"from transformers import DataCollatorForSeq2Seq\n",
|
|
||||||
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
|
|
||||||
"\n",
|
|
||||||
"# evaluation \n",
|
|
||||||
"import evaluate\n",
|
"import evaluate\n",
|
||||||
"metric = evaluate.load(\"sacrebleu\")\n",
|
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"model_name = \"google/t5-efficient-tiny\"\n",
|
||||||
|
"# google/t5-efficient-tiny\n",
|
||||||
|
"# google/t5-efficient-mini\n",
|
||||||
|
"# t5-small\n",
|
||||||
|
"# t5-base\n",
|
||||||
|
"\n",
|
||||||
|
"train_epochs = 80\n",
|
||||||
|
"\n",
|
||||||
|
"with open(\"mode.json\", \"r\") as json_file:\n",
|
||||||
|
" mode_dict = json.load(json_file)\n",
|
||||||
|
"\n",
|
||||||
|
"mode_dict.update({\"model\": model_name, \"train_epochs\": train_epochs})\n",
|
||||||
|
"fold_group = mode_dict.get(\"fold_group\")\n",
|
||||||
|
"\n",
|
||||||
|
"with open(\"mode.json\", \"w\") as json_file:\n",
|
||||||
|
" json.dump(mode_dict, json_file)\n",
|
||||||
|
"\n",
|
||||||
|
"mode = mode_dict.get(\"mode\", \"default_value\")\n",
|
||||||
|
"file_path = f'combined_data/{mode}/{fold_group}'\n",
|
||||||
|
"split_datasets = load_from_disk(file_path)\n",
|
||||||
|
"\n",
|
||||||
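Editor's note: the cell above persists the run configuration in mode.json before loading the prepared dataset. A minimal stand-alone sketch of that round trip is shown below; the key values are examples inferred from the printed checkpoint path (fold_group 1, mode tn_td_unit, 80 epochs), not new facts from the diff.

import json

# Example configuration; the real values come from earlier preprocessing steps.
mode_dict = {"mode": "tn_td_unit", "fold_group": 1}
mode_dict.update({"model": "t5-base", "train_epochs": 80})

with open("mode.json", "w") as json_file:
    json.dump(mode_dict, json_file)

with open("mode.json", "r") as json_file:
    print(json.load(json_file))  # {'mode': 'tn_td_unit', 'fold_group': 1, 'model': 't5-base', 'train_epochs': 80}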
|
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
||||||
|
"additional_special_tokens = [\"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\", \n",
|
||||||
|
" \"<TN_START>\", \"<TN_END>\", \"<TD_START>\", \"<TD_END>\", \n",
|
||||||
|
" \"<MIN_START>\", \"<MIN_END>\", \"<MAX_START>\", \"<MAX_END>\", \n",
|
||||||
|
" \"<UNIT_START>\", \"<UNIT_END>\"]\n",
|
||||||
|
"tokenizer.add_special_tokens({\"additional_special_tokens\": additional_special_tokens})\n",
|
||||||
|
"\n",
|
||||||
|
"max_length = 64\n",
|
||||||
|
"\n",
|
||||||
|
"def preprocess_function(examples):\n",
|
||||||
|
" inputs = [ex[\"input\"] for ex in examples['translation']]\n",
|
||||||
|
" targets = [ex[\"thing_property\"] for ex in examples['translation']]\n",
|
||||||
|
" return tokenizer(inputs, text_target=targets, max_length=max_length, truncation=True)\n",
|
||||||
|
"\n",
|
||||||
|
"tokenized_datasets = split_datasets.map(\n",
|
||||||
|
" preprocess_function,\n",
|
||||||
|
" batched=True,\n",
|
||||||
|
" remove_columns=split_datasets[\"train\"].column_names,\n",
|
||||||
|
")\n",
|
||||||
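Editor's note: to make the preprocessing step concrete, here is a small sketch of what a single record looks like when it passes through the tokenizer. The record contents are invented for illustration; only the special-token scheme, max_length=64, and truncation mirror the cell above.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/t5-efficient-tiny")
tokenizer.add_special_tokens({"additional_special_tokens": [
    "<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>",
    "<TD_START>", "<TD_END>", "<UNIT_START>", "<UNIT_END>"]})

# Invented example in the same shape as the 'translation' records being mapped.
example = {
    "input": "<TD_START>M/E L.O. INLET PRESS<TD_END><UNIT_START>bar<UNIT_END>",
    "thing_property": "<THING_START>ME1<THING_END><PROPERTY_START>LOInletPress<PROPERTY_END>",
}
enc = tokenizer(example["input"], text_target=example["thing_property"],
                max_length=64, truncation=True)
print(len(enc["input_ids"]), len(enc["labels"]))  # token counts for input and target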
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
|
||||||
|
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
|
||||||
|
"metric = evaluate.load(\"sacrebleu\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def compute_metrics(eval_preds):\n",
|
"def compute_metrics(eval_preds):\n",
|
||||||
" preds, labels = eval_preds\n",
|
" preds, labels = eval_preds\n",
|
||||||
" # In case the model returns more than the prediction logits\n",
|
|
||||||
" if isinstance(preds, tuple):\n",
|
" if isinstance(preds, tuple):\n",
|
||||||
" preds = preds[0]\n",
|
" preds = preds[0]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
|
" decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
|
||||||
"\n",
|
|
||||||
" # Replace -100s in the labels as we can't decode them\n",
|
|
||||||
" labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
|
" labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
|
||||||
" decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
|
" decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
|
||||||
"\n",
|
|
||||||
" # Some simple post-processing\n",
|
|
||||||
" decoded_preds = [pred.strip() for pred in decoded_preds]\n",
|
" decoded_preds = [pred.strip() for pred in decoded_preds]\n",
|
||||||
" decoded_labels = [[label.strip()] for label in decoded_labels]\n",
|
" decoded_labels = [[label.strip()] for label in decoded_labels]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" result = metric.compute(predictions=decoded_preds, references=decoded_labels)\n",
|
" result = metric.compute(predictions=decoded_preds, references=decoded_labels)\n",
|
||||||
" return {\"bleu\": result[\"score\"]}\n",
|
" return {\"bleu\": result[\"score\"]}\n",
|
||||||
"\n",
|
"\n",
|
||||||
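Editor's note: a quick way to sanity-check compute_metrics is to call the sacrebleu metric directly on a toy pair. Predictions are plain strings and each reference is wrapped in a list, exactly as the post-processing above produces them; the strings here are invented.

import evaluate

metric = evaluate.load("sacrebleu")
result = metric.compute(
    predictions=["<THING_START> ME1 <THING_END> <PROPERTY_START> LOInletPress <PROPERTY_END>"],
    references=[["<THING_START> ME1 <THING_END> <PROPERTY_START> LOInletPress <PROPERTY_END>"]],
)
print(result["score"])  # an exact multi-token match should score 100.0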
"from transformers import Seq2SeqTrainingArguments\n",
|
|
||||||
"\n",
|
|
||||||
"# load environment variables to disable GPU p2p mode for multi-gpu training without p2p mode\n",
|
|
||||||
"# not required for single-gpu training\n",
|
|
||||||
"import os\n",
|
|
||||||
"os.environ['NCCL_P2P_DISABLE'] = '1'\n",
|
"os.environ['NCCL_P2P_DISABLE'] = '1'\n",
|
||||||
"os.environ['NCCL_IB_DISABLE'] = '1'\n",
|
"os.environ['NCCL_IB_DISABLE'] = '1'\n",
|
||||||
"\n",
|
"\n",
|
||||||
"args = Seq2SeqTrainingArguments(\n",
|
"args = Seq2SeqTrainingArguments(\n",
|
||||||
" f\"train_{fold_group}_{model_name}_{mode}_{train_epochs}\",\n",
|
" f\"train_{fold_group}_{model_name}_{mode}_{train_epochs}\",\n",
|
||||||
" evaluation_strategy=\"no\",\n",
|
" save_strategy=\"steps\",\n",
|
||||||
" # logging_dir=\"tensorboard-log\",\n",
|
" learning_rate=1e-3,\n",
|
||||||
" # logging_strategy=\"epoch\",\n",
|
|
||||||
" save_strategy=\"epoch\",\n",
|
|
||||||
" learning_rate=2e-5,\n",
|
|
||||||
" per_device_train_batch_size=32,\n",
|
" per_device_train_batch_size=32,\n",
|
||||||
" per_device_eval_batch_size=64,\n",
|
" per_device_eval_batch_size=64,\n",
|
||||||
" auto_find_batch_size=True,\n",
|
" auto_find_batch_size=True,\n",
|
||||||
|
@ -434,9 +473,21 @@
|
||||||
" predict_with_generate=True,\n",
|
" predict_with_generate=True,\n",
|
||||||
" bf16=True,\n",
|
" bf16=True,\n",
|
||||||
" push_to_hub=False,\n",
|
" push_to_hub=False,\n",
|
||||||
|
" evaluation_strategy=\"steps\",\n",
|
||||||
|
" eval_steps=100,\n",
|
||||||
|
" save_steps=100, \n",
|
||||||
|
" logging_steps=100, \n",
|
||||||
|
" load_best_model_at_end=True, \n",
|
||||||
|
" metric_for_best_model=\"bleu\",\n",
|
||||||
|
" lr_scheduler_type=\"linear\",\n",
|
||||||
|
" warmup_steps=100,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from transformers import Seq2SeqTrainer\n",
|
"# Define the EarlyStoppingCallback\n",
|
||||||
|
"early_stopping_callback = EarlyStoppingCallback(\n",
|
||||||
|
" early_stopping_patience=5,\n",
|
||||||
|
"\n",
|
||||||
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"trainer = Seq2SeqTrainer(\n",
|
"trainer = Seq2SeqTrainer(\n",
|
||||||
" model,\n",
|
" model,\n",
|
||||||
|
@ -446,10 +497,11 @@
|
||||||
" data_collator=data_collator,\n",
|
" data_collator=data_collator,\n",
|
||||||
" tokenizer=tokenizer,\n",
|
" tokenizer=tokenizer,\n",
|
||||||
" compute_metrics=compute_metrics,\n",
|
" compute_metrics=compute_metrics,\n",
|
||||||
|
" callbacks=[early_stopping_callback] \n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Train the model\n",
|
"trainer.train()\n",
|
||||||
"trainer.train()"
|
"os._exit(0)\n"
|
||||||
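Editor's note: as a reading aid, the pieces that make early stopping work in this setup are collected in the sketch below. It restates how the new training arguments and the callback fit together; the output directory name is illustrative and this is not additional code from the diff.

from transformers import EarlyStoppingCallback, Seq2SeqTrainingArguments

# Early stopping only registers "no improvement" when evaluations actually run,
# so it is paired here with a step-based evaluation schedule, a best-model metric
# ("bleu"), and load_best_model_at_end=True.
args_sketch = Seq2SeqTrainingArguments(
    "early-stopping-sketch",       # illustrative output dir
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,                # kept aligned with eval_steps
    load_best_model_at_end=True,
    metric_for_best_model="bleu",
    greater_is_better=True,        # BLEU: higher is better
)
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=5)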
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
|
@ -15,13 +15,10 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"The mode has been set to: tn_td_unit t5-base\n",
|
"12938it [00:07, 1674.63it/s] \n"
|
||||||
"Using model checkpoint: train_1_t5-base_tn_td_unit_80/checkpoint-3840\n",
|
|
||||||
"Columns in df_org:\n",
|
|
||||||
"['thing', 'property', 'ships_idx', 'tag_name', 'tag_description', 'signal_type', 'min', 'max', 'unit', 'data_type', 'thing_pattern', 'property_pattern', 'pattern', 'MDM', 'org_tag_description']\n"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -29,76 +26,35 @@
|
||||||
"import pandas as pd\n",
|
"import pandas as pd\n",
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"import json\n",
|
"import json\n",
|
||||||
|
"from transformers.pipelines.pt_utils import KeyDataset\n",
|
||||||
|
"from transformers import pipeline\n",
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"from transformers import AutoTokenizer\n",
|
||||||
|
"from datasets import Dataset\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Read the mode from the JSON file\n",
|
|
||||||
"with open(\"mode.json\", \"r\") as json_file:\n",
|
"with open(\"mode.json\", \"r\") as json_file:\n",
|
||||||
" mode_dict = json.load(json_file)\n",
|
" mode_dict = json.load(json_file)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"mode = mode_dict.get(\"mode\", \"none\")\n",
|
||||||
|
"model_name = mode_dict.get(\"model\", \"none\")\n",
|
||||||
|
"train_epochs = mode_dict.get(\"train_epochs\", \"none\")\n",
|
||||||
|
"fold_group = mode_dict.get(\"fold_group\", \"none\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Set the mode variable from the JSON content\n",
|
|
||||||
"mode = mode_dict.get(\"mode\", \"none\") # 'default_value' is a fallback if 'mode' is not found\n",
|
|
||||||
"model_name = mode_dict.get(\"model\", \"none\") # 'default_value' is a fallback if 'mode' is not found\n",
|
|
||||||
"train_epochs = mode_dict.get(\"train_epochs\", \"none\") # 'default_value' is a fallback if 'mode' is not found\n",
|
|
||||||
"fold_group = mode_dict.get(\"fold_group\", \"none\") # 'default_value' is a fallback if 'mode' is not found\n",
|
|
||||||
"\n",
|
|
||||||
"print(f\"The mode has been set to: {mode} {model_name}\")\n",
|
|
||||||
"\n",
|
|
||||||
"# Define the base directory where checkpoints are stored\n",
|
|
||||||
"base_dir = f\"train_{fold_group}_{model_name}_{mode}_{train_epochs}\"\n",
|
"base_dir = f\"train_{fold_group}_{model_name}_{mode}_{train_epochs}\"\n",
|
||||||
|
"checkpoints = [d for d in os.listdir(base_dir) if d.startswith(\"checkpoint-\")]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# List all subdirectories in the base directory\n",
|
"model_checkpoint = os.path.join(base_dir, checkpoints[0]) if checkpoints else None\n",
|
||||||
"subdirectories = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Filter for checkpoint directories that match the pattern \"checkpoint-\"\n",
|
"data_path = f\"../../data_preprocess/dataset/{fold_group}/test.csv\"\n",
|
||||||
"checkpoints = [d for d in subdirectories if d.startswith(\"checkpoint-\")]\n",
|
|
||||||
"\n",
|
|
||||||
"# Select the latest checkpoint (the one with the highest number)\n",
|
|
||||||
"if checkpoints:\n",
|
|
||||||
" latest_checkpoint = checkpoints[0]\n",
|
|
||||||
" model_checkpoint = os.path.join(base_dir, latest_checkpoint)\n",
|
|
||||||
" print(f\"Using model checkpoint: {model_checkpoint}\")\n",
|
|
||||||
"else:\n",
|
|
||||||
" print(\"No checkpoints were found.\")\n",
|
|
||||||
" model_checkpoint = None # Handle this case as needed\n",
|
|
||||||
"\n",
|
|
||||||
"# Load the data\n",
|
|
||||||
"data_path = f\"../../data_preprocess/dataset/{fold_group}/test.csv\" # Adjust the CSV file path as necessary\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
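Editor's note: one caveat with the checkpoint discovery above is that os.listdir() returns entries in arbitrary order, so checkpoints[0] is not guaranteed to be the most recent save. A small variant that sorts by step number, assuming the standard checkpoint-<step> naming, would be:

import os

checkpoints = [d for d in os.listdir(base_dir) if d.startswith("checkpoint-")]
# Pick the highest step numerically, so "checkpoint-3840" beats "checkpoint-400".
latest = max(checkpoints, key=lambda d: int(d.split("-")[-1])) if checkpoints else None
model_checkpoint = os.path.join(base_dir, latest) if latest else None
print(f"Using model checkpoint: {model_checkpoint}")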
"try:\n",
|
"try:\n",
|
||||||
" df = pd.read_csv(data_path)\n",
|
" df = pd.read_csv(data_path)\n",
|
||||||
"except UnicodeDecodeError:\n",
|
"except UnicodeDecodeError:\n",
|
||||||
" df = pd.read_csv(data_path, encoding='ISO-8859-1')\n",
|
" df = pd.read_csv(data_path, encoding='ISO-8859-1')\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
|
||||||
"# Drop rows where 'tag_description' is NaN and reset the index\n",
|
|
||||||
"df = df.dropna(subset=['tag_description']).reset_index(drop=True)\n",
|
"df = df.dropna(subset=['tag_description']).reset_index(drop=True)\n",
|
||||||
"\n",
|
|
||||||
"# Preserve df_org\n",
|
|
||||||
"df_org = df.copy()\n",
|
"df_org = df.copy()\n",
|
||||||
"\n",
|
"df[['thing', 'property', 'tag_description', 'min', 'max', 'MDM', 'pattern']] = df[['thing', 'property', 'tag_description', 'min', 'max', 'MDM', 'pattern']].astype(\"string\")\n",
|
||||||
"# Print the column names of df_org\n",
|
|
||||||
"print(\"Columns in df_org:\")\n",
|
|
||||||
"print(df_org.columns.tolist())\n",
|
|
||||||
"\n",
|
|
||||||
"selected_columns = ['thing', 'property', 'tag_description', 'min', 'max', 'MDM', 'pattern']\n",
|
|
||||||
"df[selected_columns] = df[selected_columns].astype(\"string\")\n"
|
|
||||||
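Editor's note: the encoding fallback above can be captured in a tiny reusable helper; this is only a restatement of the existing try/except in function form, and the helper name is illustrative.

import pandas as pd

def read_csv_any_encoding(path):
    # Try the default UTF-8 first, then fall back to ISO-8859-1 for legacy exports.
    try:
        return pd.read_csv(path)
    except UnicodeDecodeError:
        return pd.read_csv(path, encoding="ISO-8859-1")

df = read_csv_any_encoding(data_path)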
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"The test_dataset contains 14718 items.\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"from datasets import Dataset\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"def process_df(df, mode='only_td'):\n",
|
"def process_df(df, mode='only_td'):\n",
|
||||||
" output_list = []\n",
|
" output_list = []\n",
|
||||||
|
@ -111,11 +67,13 @@
|
||||||
" elif mode == 'tn_td_min_max':\n",
|
" elif mode == 'tn_td_min_max':\n",
|
||||||
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\"\n",
|
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\"\n",
|
||||||
" elif mode == 'td_min_max':\n",
|
" elif mode == 'td_min_max':\n",
|
||||||
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\" \n",
|
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\"\n",
|
||||||
" elif mode == 'td_unit':\n",
|
" elif mode == 'td_unit':\n",
|
||||||
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\" \n",
|
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\"\n",
|
||||||
" elif mode == 'tn_td_unit':\n",
|
" elif mode == 'tn_td_unit':\n",
|
||||||
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\" \n",
|
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\"\n",
|
||||||
|
" elif mode == 'td_min_max_unit':\n",
|
||||||
|
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END><UNIT_START>{str(row['unit'])}<UNIT_END>\"\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" raise ValueError(\"Invalid mode specified\")\n",
|
" raise ValueError(\"Invalid mode specified\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -124,136 +82,64 @@
|
||||||
" 'ships_idx': row['ships_idx'],\n",
|
" 'ships_idx': row['ships_idx'],\n",
|
||||||
" 'input': input_str,\n",
|
" 'input': input_str,\n",
|
||||||
" 'thing_property': f\"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>\",\n",
|
" 'thing_property': f\"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>\",\n",
|
||||||
" 'answer_thing': f\"{row['thing']}\",\n",
|
" 'answer_thing': row['thing'],\n",
|
||||||
" 'answer_property': f\"{row['property']}\",\n",
|
" 'answer_property': row['property'],\n",
|
||||||
" 'MDM': f\"{row['MDM']}\",\n",
|
" 'MDM': row['MDM'],\n",
|
||||||
" }\n",
|
" }\n",
|
||||||
" })\n",
|
" })\n",
|
||||||
" except Exception as e:\n",
|
" except Exception as e:\n",
|
||||||
" print(f\"Error processing row: {row}\")\n",
|
" print(f\"Error processing row: {e}\")\n",
|
||||||
" print(f\"Exception: {e}\")\n",
|
|
||||||
" return output_list\n",
|
" return output_list\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
|
||||||
"# Process the DataFrame\n",
|
|
||||||
"processed_data = process_df(df, mode=mode)\n",
|
"processed_data = process_df(df, mode=mode)\n",
|
||||||
"\n",
|
|
||||||
"# Create a Dataset object\n",
|
|
||||||
"test_dataset = Dataset.from_list(processed_data)\n",
|
"test_dataset = Dataset.from_list(processed_data)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Print the number of items in the dataset\n",
|
|
||||||
"print(f\"The test_dataset contains {len(test_dataset)} items.\")\n"
|
|
||||||
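Editor's note: for the mode used in this run (tn_td_unit, per the printed output), each processed record pairs a tag-name/description/unit input with a thing/property target. A worked example with invented field values, mirroring the templates in the 'tn_td_unit' branch above:

# Invented row values; only the templates come from the cell above.
row = {"tag_name": "ME1_LO_IN_PRESS", "tag_description": "M/E L.O. INLET PRESS",
       "unit": "bar", "thing": "ME1", "property": "LOInletPress"}

input_str = (f"<TN_START>{row['tag_name']}<TN_END>"
             f"<TD_START>{row['tag_description']}<TD_END>"
             f"<UNIT_START>{row['unit']}<UNIT_END>")
thing_property = (f"<THING_START>{row['thing']}<THING_END>"
                  f"<PROPERTY_START>{row['property']}<PROPERTY_END>")
print(input_str)
print(thing_property)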
]
|
|
||||||
-},
-{
-"cell_type": "code",
-"execution_count": 3,
-"metadata": {},
-"outputs": [],
-"source": [
-"from transformers.pipelines.pt_utils import KeyDataset\n",
-"from transformers import pipeline\n",
-"from tqdm import tqdm\n",
-"import os\n",
-"from transformers import AutoTokenizer\n",
-"\n",
 "tokenizer = AutoTokenizer.from_pretrained(model_name, return_tensors=\"pt\")\n",
-"# Define additional special tokens\n",
-"# additional_special_tokens = [\"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\"]\n",
 "additional_special_tokens = [\"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\", \"<TN_START>\", \"<TN_END>\", \"<TD_START>\", \"<TD_END>\", \"<MIN_START>\", \"<MIN_END>\", \"<MAX_START>\", \"<MAX_END>\", \"<UNIT_START>\", \"<UNIT_END>\"]\n",
-"\n",
-"# Add the additional special tokens to the tokenizer\n",
 "tokenizer.add_special_tokens({\"additional_special_tokens\": additional_special_tokens})\n",
-"# tokenizer.add_special_tokens({'sep_token': \"<SEP>\"})\n",
-"\n",
 "\n",
 "pipe = pipeline(\"translation_XX_to_YY\", model=model_checkpoint, tokenizer=tokenizer, return_tensors=True, max_length=128, device=0)\n",
 "\n",
-"# check what token-ids the special tokens are\n",
+"thing_start_id = tokenizer.convert_tokens_to_ids(\"<THING_START>\")\n",
-"# tokenizer.encode(\"<THING_START><THING_END><PROPERTY_START><PROPERTY_END>\")\n",
+"thing_end_id = tokenizer.convert_tokens_to_ids(\"<THING_END>\")\n",
+"property_start_id = tokenizer.convert_tokens_to_ids(\"<PROPERTY_START>\")\n",
+"property_end_id = tokenizer.convert_tokens_to_ids(\"<PROPERTY_END>\")\n",
 "\n",
 "def extract_seq(tokens, start_value, end_value):\n",
-"    if start_value not in tokens or end_value not in tokens:\n",
+"    if start_value in tokens and end_value in tokens:\n",
-"        return None  # Or handle this case according to your requirements\n",
+"        return tokens[tokens.index(start_value)+1:tokens.index(end_value)]\n",
-"    start_id = tokens.index(start_value)\n",
+"    return None\n",
-"    end_id = tokens.index(end_value)\n",
 "\n",
-"    return tokens[start_id+1:end_id]\n",
+"def extract_seq_from_output(output):\n",
+"    tokens = output[0][\"translation_token_ids\"].tolist()\n",
+"    p_thing = tokenizer.decode(extract_seq(tokens, thing_start_id, thing_end_id)) if thing_start_id in tokens and thing_end_id in tokens else None\n",
+"    p_property = tokenizer.decode(extract_seq(tokens, property_start_id, property_end_id)) if property_start_id in tokens and property_end_id in tokens else None\n",
+"    return p_thing, p_property\n",
 "\n",
-"# problem, what if end tokens are not in?\n",
-"def process_tensor_output(output):\n",
-"    tokens = output[0]['translation_token_ids'].tolist()\n",
-"    thing_seq = extract_seq(tokens, 32100, 32101)  # 32100 = <THING_START>, 32101 = <THING_END>\n",
-"    property_seq = extract_seq(tokens, 32102, 32103)  # 32102 = <PROPERTY_START>, 32103 = <PROPERTY_END>\n",
-"    p_thing = None\n",
-"    p_property = None\n",
-"    if (thing_seq is not None):\n",
-"        p_thing = tokenizer.decode(thing_seq)\n",
-"    if (property_seq is not None):\n",
-"        p_property = tokenizer.decode(property_seq)\n",
-"    return p_thing, p_property"
-]
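As a quick sanity check on the hunk above: the rewritten extract_seq simply slices the generated token list between a start sentinel and an end sentinel, and returns None when either sentinel is missing. A minimal standalone sketch (with made-up integer IDs standing in for the real special-token IDs) behaves the same way:

def extract_seq(tokens, start_value, end_value):
    # Slice only when both sentinels are present; otherwise signal "not found".
    if start_value in tokens and end_value in tokens:
        return tokens[tokens.index(start_value) + 1:tokens.index(end_value)]
    return None

THING_START, THING_END = 900, 901       # hypothetical sentinel IDs, not the real ones
generated = [900, 17, 42, 901, 5]       # <THING_START> ... <THING_END> ...
print(extract_seq(generated, THING_START, THING_END))   # [17, 42]
print(extract_seq(generated, 902, 903))                 # None (sentinels absent)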
-},
-{
-"cell_type": "code",
-"execution_count": 4,
-"metadata": {},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"making inference on test set\n"
-]
-},
-{
-"name": "stderr",
-"output_type": "stream",
-"text": [
-"14718it [00:44, 330.24it/s] "
-]
-},
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"inference done\n"
-]
-},
-{
-"name": "stderr",
-"output_type": "stream",
-"text": [
-"\n"
-]
-}
-],
-"source": [
 "p_thing_list = []\n",
 "p_property_list = []\n",
-"print(\"making inference on test set\")\n",
+"\n",
 "for out in tqdm(pipe(KeyDataset(test_dataset[\"translation\"], \"input\"), batch_size=256)):\n",
-"    p_thing, p_property = process_tensor_output(out)\n",
+"    p_thing, p_property = extract_seq_from_output(out)\n",
 "    p_thing_list.append(p_thing)\n",
-"    p_property_list.append(p_property)\n",
+"    p_property_list.append(p_property)\n"
-"print(\"inference done\")"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 2,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Thing prediction accuracy: 0.9895314057826521\n",
+"Thing prediction accuracy: 0.9793861658268438\n",
-"Correct thing predictions: 1985, Incorrect thing predictions: 21\n",
+"Correct thing predictions: 2138, Incorrect thing predictions: 10800\n",
-"Property prediction accuracy: 0.9661016949152542\n",
+"Property prediction accuracy: 0.9651855245075585\n",
-"Correct property predictions: 1938, Incorrect property predictions: 12780\n",
+"Correct property predictions: 2107, Incorrect property predictions: 10831\n",
-"total accuracy: 0.9596211365902293\n",
+"Total accuracy: 0.9496106275767293\n",
-"Correct total predictions: 1925, Incorrect total predictions: 81\n"
+"Correct total predictions: 2073, Incorrect total predictions: 110\n"
 ]
 }
 ],
@@ -262,82 +148,54 @@
 "answer_property = [item['answer_property'] for item in test_dataset[\"translation\"]]\n",
 "mdm_list = [item['MDM'] for item in test_dataset[\"translation\"]]\n",
 "\n",
-"mdm_count = 0\n",
+"mdm_count = sum([1 for mdm in mdm_list if mdm == \"True\"])\n",
-"for i in range(len(mdm_list)):\n",
-"    if(mdm_list[i] == \"True\"):mdm_count = mdm_count + 1 \n",
 "\n",
 "def correctness_test(input, reference, mdm_list):\n",
-"    assert(len(input) == len(reference))\n",
+"    assert len(input) == len(reference)\n",
-"    correctness_list = []\n",
+"    return [input[i] == reference[i] if mdm_list[i] == \"True\" else False for i in range(len(input))]\n",
-"    for i in range(len(input)):\n",
-"        if(mdm_list[i] == \"True\"):\n",
-"            correctness_list.append(input[i] == reference[i])\n",
-"        else:correctness_list.append(False)\n",
-"    return correctness_list\n",
 "\n",
-"# Compare with answer to evaluate correctness\n",
 "thing_correctness = correctness_test(p_thing_list, answer_thing, mdm_list)\n",
 "property_correctness = correctness_test(p_property_list, answer_property, mdm_list)\n",
 "\n",
-"correctness_mdm = []\n",
+"correctness_mdm = [thing_correctness[i] & property_correctness[i] for i in range(len(mdm_list))]\n",
-"for i in range(len(mdm_list)):\n",
+"\n",
-"    if(thing_correctness[i] & property_correctness[i]):\n",
-"        correctness_mdm.append(True)\n",
-"    else: \n",
-"        correctness_mdm.append(False)\n",
-"    \n",
-"    \n",
-"# Calculate accuracy\n",
 "thing_accuracy = sum(thing_correctness) / mdm_count\n",
 "property_accuracy = sum(property_correctness) / mdm_count\n",
 "total_accuracy = sum(correctness_mdm) / mdm_count\n",
 "\n",
-"# Count True/False values\n",
 "thing_true_count = thing_correctness.count(True)\n",
-"thing_false_count = 0\n",
+"thing_false_count = thing_correctness.count(False)\n",
-"for i in range(len(thing_correctness)):\n",
-"    if mdm_list[i] == \"True\" and thing_correctness[i] == False:\n",
-"        thing_false_count += 1\n",
 "\n",
 "property_true_count = property_correctness.count(True)\n",
 "property_false_count = property_correctness.count(False)\n",
-"total_true_count = correctness_mdm.count(True)\n",
-"total_false_count = mdm_count - correctness_mdm.count(True)\n",
 "\n",
-"# Print results\n",
+"total_true_count = correctness_mdm.count(True)\n",
+"total_false_count = mdm_count - total_true_count\n",
+"\n",
 "print(\"Thing prediction accuracy:\", thing_accuracy)\n",
 "print(f\"Correct thing predictions: {thing_true_count}, Incorrect thing predictions: {thing_false_count}\")\n",
 "print(\"Property prediction accuracy:\", property_accuracy)\n",
 "print(f\"Correct property predictions: {property_true_count}, Incorrect property predictions: {property_false_count}\")\n",
-"print(\"total accuracy:\", total_accuracy)\n",
+"print(\"Total accuracy:\", total_accuracy)\n",
 "print(f\"Correct total predictions: {total_true_count}, Incorrect total predictions: {total_false_count}\")\n",
 "\n",
-"# Create a DataFrame with the results\n",
+"df_pred = pd.DataFrame({\n",
-"dict = {\n",
 "    'p_thing': p_thing_list,\n",
 "    'p_property': p_property_list,\n",
 "    'p_thing_correct': thing_correctness,\n",
 "    'p_property_correct': property_correctness\n",
-"}\n",
+"})\n",
 "\n",
-"df_pred = pd.DataFrame(dict)\n",
-"\n",
-"# Read the mode from the JSON file\n",
 "with open(\"mode.json\", \"r\") as json_file:\n",
 "    mode_dict = json.load(json_file)\n",
 "\n",
-"# Add the model key to the dictionary\n",
 "mode_dict[\"model\"] = model_name\n",
 "mode_dict[\"train_epochs\"] = train_epochs\n",
 "\n",
-"# Save the updated dictionary back to the JSON file\n",
 "with open(\"mode.json\", \"w\") as json_file:\n",
 "    json.dump(mode_dict, json_file)\n",
 "\n",
-"\n",
-"# Check if the file exists and is not empty\n",
 "if os.path.exists(\"results.json\") and os.path.getsize(\"results.json\") > 0:\n",
-"    # Read the existing results.json file\n",
 "    with open(\"results.json\", \"r\") as json_file:\n",
 "        try:\n",
 "            results_dict = json.load(json_file)\n",
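The bookkeeping in this hunk hinges on one detail: every accuracy is divided by mdm_count, the number of rows whose MDM flag is the string "True", and non-MDM rows are forced to False before counting. A toy run with made-up lists makes the denominator explicit:

p_thing_list = ["pump", "valve", "fan"]
answer_thing = ["pump", "valve", "motor"]
mdm_list     = ["True", "True", "False"]   # only MDM rows count toward accuracy

mdm_count = sum([1 for mdm in mdm_list if mdm == "True"])
thing_correctness = [p_thing_list[i] == answer_thing[i] if mdm_list[i] == "True" else False
                     for i in range(len(p_thing_list))]

print(mdm_count)                           # 2
print(thing_correctness)                   # [True, True, False]
print(sum(thing_correctness) / mdm_count)  # 1.0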
@@ -346,9 +204,7 @@
 "else:\n",
 "    results_dict = {}\n",
 "\n",
-"# Add the new model_checkpoint key with the accuracy values as an object\n",
+"model_key = model_checkpoint\n",
-"\n",
-"model_key = model_checkpoint \n",
 "\n",
 "results_dict[model_key] = {\n",
 "    \"thing_accuracy\": thing_accuracy,\n",
@@ -359,31 +215,30 @@
 "    \"property_false\": property_false_count,\n",
 "    \"total_accuracy\": total_accuracy,\n",
 "    \"total_true\": total_true_count,\n",
-"    \"total_false\": total_false_count \n",
+"    \"total_false\": total_false_count\n",
 "}\n",
 "\n",
-"# Save the updated dictionary back to the results.json file\n",
 "with open(\"results.json\", \"w\") as json_file:\n",
-"    json.dump(results_dict, json_file, indent=4)"
+"    json.dump(results_dict, json_file, indent=4)\n"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 3,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Updated data saved to ../0.result/1/test_p.csv\n"
+"Updated data saved to ../0.result/5/test_p.csv\n",
+"Updated data saved to 0.dresult/td_unit/google/t5-efficient-tiny/5/test_p.csv\n"
 ]
 }
 ],
 "source": [
 "import os\n",
 "\n",
-"# Create a DataFrame with the results\n",
 "df_pred = pd.DataFrame({\n",
 "    'p_thing': p_thing_list,\n",
 "    'p_property': p_property_list,\n",
@@ -391,7 +246,6 @@
 "    'p_property_correct': property_correctness,\n",
 "})\n",
 "\n",
-"# Merge predictions with the original DataFrame (df_org)\n",
 "df_org['p_thing'] = df_pred['p_thing']\n",
 "df_org['p_property'] = df_pred['p_property']\n",
 "df_org['p_thing_correct'] = df_pred['p_thing_correct']\n",
@@ -404,22 +258,20 @@
 "df_org['p_pattern'] = df_org['p_thing'].str.replace(r'\\d', '#', regex=True) + \" \" + df_org['p_property'].str.replace(r'\\d', '#', regex=True)\n",
 "df_master['master_pattern'] = df_master['thing'] + \" \" + df_master['property']\n",
 "\n",
-"# Create a set of unique patterns from master for fast lookup\n",
 "master_patterns = set(df_master['master_pattern'])\n",
 "df_org['p_MDM'] = df_org['p_pattern'].apply(lambda x: x in master_patterns)\n",
 "\n",
-"\n",
 "output_path = f\"../0.result/{fold_group}/test_p.csv\"\n",
-"debug_output_path = f\"0.dresult/{fold_group}/test_p.csv\"\n",
+"debug_output_path = f\"0.dresult/{mode}/{model_name}/{fold_group}/test_p.csv\"\n",
 "\n",
-"# Create the folder if it does not exist\n",
 "os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
 "df_org.to_csv(output_path, index=False, encoding='utf-8-sig')\n",
 "\n",
 "os.makedirs(os.path.dirname(debug_output_path), exist_ok=True)\n",
 "df_org.to_csv(debug_output_path, index=False, encoding='utf-8-sig')\n",
 "\n",
-"print(f\"Updated data saved to {output_path}\")"
+"print(f\"Updated data saved to {output_path}\")\n",
+"print(f\"Updated data saved to {debug_output_path}\")"
 ]
 }
 ],
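The p_MDM column in the hunk above is computed by masking digits with '#' and looking the combined "thing property" pattern up in a set built from the master list. A small self-contained example with made-up frames (df_org and df_master here are toy stand-ins, not the notebook's real data):

import pandas as pd

df_org = pd.DataFrame({"p_thing": ["Pump1", "Fan2"], "p_property": ["Flow3", "Speed"]})
df_master = pd.DataFrame({"thing": ["Pump#"], "property": ["Flow#"]})

# Mask digits with '#' and join thing/property into a single lookup pattern.
df_org["p_pattern"] = (df_org["p_thing"].str.replace(r"\d", "#", regex=True)
                       + " " + df_org["p_property"].str.replace(r"\d", "#", regex=True))
master_patterns = set(df_master["thing"] + " " + df_master["property"])
df_org["p_MDM"] = df_org["p_pattern"].apply(lambda x: x in master_patterns)

print(df_org[["p_pattern", "p_MDM"]])   # "Pump# Flow#" -> True, "Fan# Speed" -> False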
@@ -0,0 +1,86 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Group 1 Recall: 0.947941\n",
+      "Group 2 Recall: 0.902804\n",
+      "Group 3 Recall: 0.970884\n",
+      "Group 4 Recall: 0.965271\n",
+      "Group 5 Recall: 0.949611\n"
+     ]
+    }
+   ],
+   "source": [
+    "import pandas as pd\n",
+    "\n",
+    "# Set mode, model_name, fold_group\n",
+    "mode = 'td_unit'  # set the desired mode\n",
+    "model_name = 'google/t5-efficient-tiny'  # set the model name\n",
+    "recall_by_group = {}\n",
+    "\n",
+    "# Process groups 1 through 5\n",
+    "for group in range(1, 6):\n",
+    "    # Set the CSV file path (including model_name)\n",
+    "    debug_output_path = f\"0.dresult/{mode}/{model_name}/{group}/test_p.csv\"\n",
+    "    \n",
+    "    # Load the CSV file\n",
+    "    try:\n",
+    "        df = pd.read_csv(debug_output_path)\n",
+    "    except FileNotFoundError:\n",
+    "        print(f\"File not found: {debug_output_path}\")\n",
+    "        continue\n",
+    "\n",
+    "    # 1. Keep only rows where MDM is True\n",
+    "    filtered_df = df[df['MDM'] == True].copy()\n",
+    "\n",
+    "    # 2. Mark a row as TP when p_thing and p_property match thing and property (using loc)\n",
+    "    filtered_df.loc[:, 'TP'] = (filtered_df['p_thing'] == filtered_df['thing']) & (filtered_df['p_property'] == filtered_df['property'])\n",
+    "\n",
+    "    # 3. Compute recall from the TP count and the total MDM count\n",
+    "    tp_count = filtered_df['TP'].sum()\n",
+    "    total_count = len(filtered_df)\n",
+    "\n",
+    "    # Compute recall\n",
+    "    if total_count > 0:\n",
+    "        recall = tp_count / total_count\n",
+    "    else:\n",
+    "        recall = 0\n",
+    "\n",
+    "    # Store recall per group\n",
+    "    recall_by_group[group] = recall\n",
+    "\n",
+    "# Print recall per group\n",
+    "for group, recall in recall_by_group.items():\n",
+    "    print(f\"Group {group} Recall: {recall:.6f}\")\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "torch",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.14"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
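The per-group recall in this new notebook reduces to a single ratio: exact matches on both thing and property, counted only over rows where MDM is True, divided by the number of those rows. With made-up counts:

tp_count, total_count = 19, 20                 # hypothetical TP and MDM-row counts
recall = tp_count / total_count if total_count > 0 else 0
print(f"Recall: {recall:.6f}")                 # Recall: 0.950000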