448 lines
17 KiB
Plaintext
448 lines
17 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
"# Goal: end-to-end inference and evaluation\n",
|
|
"\n",
|
"Given a CSV file, generate predictions, evaluate them against the reference answers, and write the results to a CSV file."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"The mode has been set to: tn_td_unit t5-base\n",
|
|
"Using model checkpoint: train_1_t5-base_tn_td_unit_80/checkpoint-3840\n",
|
|
"Columns in df_org:\n",
|
|
"['thing', 'property', 'ships_idx', 'tag_name', 'tag_description', 'signal_type', 'min', 'max', 'unit', 'data_type', 'thing_pattern', 'property_pattern', 'pattern', 'MDM', 'org_tag_description']\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"import os\n",
|
|
"import json\n",
|
|
"\n",
|
"# Read the run configuration from the JSON file\n",
"with open(\"mode.json\", \"r\") as json_file:\n",
"    mode_dict = json.load(json_file)\n",
"\n",
"# Pull each setting from the configuration; a missing key falls back to the string \"none\"\n",
"mode = mode_dict.get(\"mode\", \"none\")\n",
"model_name = mode_dict.get(\"model\", \"none\")\n",
"train_epochs = mode_dict.get(\"train_epochs\", \"none\")\n",
"fold_group = mode_dict.get(\"fold_group\", \"none\")\n",
"\n",
"print(f\"The mode has been set to: {mode} {model_name}\")\n",
"\n",
"# Define the base directory where checkpoints are stored\n",
"base_dir = f\"train_{fold_group}_{model_name}_{mode}_{train_epochs}\"\n",
"\n",
"# List all subdirectories in the base directory\n",
"subdirectories = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]\n",
"\n",
"# Keep only the directories named \"checkpoint-<step>\"\n",
"checkpoint_prefix = \"checkpoint-\"\n",
"checkpoints = [d for d in subdirectories if d.startswith(checkpoint_prefix)]\n",
"\n",
"\n",
"def _checkpoint_step(dirname):\n",
"    \"\"\"Return the step number encoded in a checkpoint directory name, or -1 if it has none.\"\"\"\n",
"    suffix = dirname[len(checkpoint_prefix):]\n",
"    return int(suffix) if suffix.isdecimal() else -1\n",
"\n",
"\n",
"# Select the latest checkpoint (the one with the highest step number).\n",
"# os.listdir returns entries in arbitrary order, so the first entry is not\n",
"# necessarily the latest one; compare the step numbers explicitly.\n",
"if checkpoints:\n",
"    latest_checkpoint = max(checkpoints, key=_checkpoint_step)\n",
"    model_checkpoint = os.path.join(base_dir, latest_checkpoint)\n",
"    print(f\"Using model checkpoint: {model_checkpoint}\")\n",
"else:\n",
"    print(\"No checkpoints were found.\")\n",
"    # NOTE(review): later cells pass model_checkpoint to pipeline(); confirm how a missing checkpoint should be handled\n",
"    model_checkpoint = None\n",
|
|
"\n",
|
"# Location of the test split for the configured fold group\n",
"data_path = f\"../../data_preprocess/dataset/{fold_group}/test.csv\"\n",
"\n",
"# Try pandas' default decoding first and fall back to ISO-8859-1 if the bytes do not decode\n",
"try:\n",
"    df = pd.read_csv(data_path)\n",
"except UnicodeDecodeError:\n",
"    df = pd.read_csv(data_path, encoding='ISO-8859-1')\n",
"\n",
"# Rows without a tag description are dropped, then the index is renumbered from 0\n",
"df = df.dropna(subset=['tag_description'])\n",
"df = df.reset_index(drop=True)\n",
"\n",
"# Keep an untouched copy, taken before the dtype conversion below\n",
"df_org = df.copy()\n",
"\n",
"# Show which columns the CSV provided\n",
"print(\"Columns in df_org:\")\n",
"print(df_org.columns.tolist())\n",
"\n",
"# Convert the columns listed here to the pandas string dtype\n",
"selected_columns = ['thing', 'property', 'tag_description', 'min', 'max', 'MDM', 'pattern']\n",
"string_values = df[selected_columns].astype(\"string\")\n",
"df[selected_columns] = string_values\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"The test_dataset contains 14718 items.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from datasets import Dataset\n",
|
|
"\n",
|
"def process_df(df, mode='only_td'):\n",
"    \"\"\"Convert each DataFrame row into a translation record.\n",
"\n",
"    Args:\n",
"        df: DataFrame with one row per tag. Each row is read for 'ships_idx',\n",
"            'thing', 'property' and 'MDM', plus the columns the chosen mode uses.\n",
"        mode: Layout of the input string; one of 'only_td', 'tn_td',\n",
"            'tn_td_min_max', 'td_min_max', 'td_unit' or 'tn_td_unit'.\n",
"\n",
"    Returns:\n",
"        A list of {'translation': {...}} dicts, one per row that was processed\n",
"        without error. Rows that fail are reported with print() and skipped.\n",
"\n",
"    Raises:\n",
"        ValueError: If mode is not one of the supported layouts.\n",
"    \"\"\"\n",
"    # Ordered (marker, column) pairs that make up the input string of each mode\n",
"    mode_fields = {\n",
"        'only_td': [('TD', 'tag_description')],\n",
"        'tn_td': [('TN', 'tag_name'), ('TD', 'tag_description')],\n",
"        'tn_td_min_max': [('TN', 'tag_name'), ('TD', 'tag_description'), ('MIN', 'min'), ('MAX', 'max')],\n",
"        'td_min_max': [('TD', 'tag_description'), ('MIN', 'min'), ('MAX', 'max')],\n",
"        'td_unit': [('TD', 'tag_description'), ('UNIT', 'unit')],\n",
"        'tn_td_unit': [('TN', 'tag_name'), ('TD', 'tag_description'), ('UNIT', 'unit')],\n",
"    }\n",
"\n",
"    # Validate the mode once, outside the per-row try block. Raising inside the\n",
"    # loop would be caught by the handler below, printing an error for every\n",
"    # row and returning an empty list instead of reporting the bad mode.\n",
"    if mode not in mode_fields:\n",
"        raise ValueError(\"Invalid mode specified\")\n",
"    fields = mode_fields[mode]\n",
"\n",
"    output_list = []\n",
"    for _, row in df.iterrows():\n",
"        try:\n",
"            input_str = \"\".join(\n",
"                f\"<{marker}_START>{str(row[column])}<{marker}_END>\"\n",
"                for marker, column in fields\n",
"            )\n",
"\n",
"            output_list.append({\n",
"                'translation': {\n",
"                    'ships_idx': row['ships_idx'],\n",
"                    'input': input_str,\n",
"                    'thing_property': f\"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>\",\n",
"                    'answer_thing': f\"{row['thing']}\",\n",
"                    'answer_property': f\"{row['property']}\",\n",
"                    'MDM': f\"{row['MDM']}\",\n",
"                }\n",
"            })\n",
"        except Exception as e:\n",
"            # Best effort: report the failing row and keep going. A skipped row\n",
"            # makes the returned list shorter than df.\n",
"            print(f\"Error processing row: {row}\")\n",
"            print(f\"Exception: {e}\")\n",
"    return output_list\n",
|
|
"\n",
|
|
"\n",
|
"# Build one translation record per row, using the input layout selected by `mode`\n",
"processed_data = process_df(df, mode=mode)\n",
"\n",
"# Wrap the records in a Dataset object\n",
"test_dataset = Dataset.from_list(processed_data)\n",
"\n",
"# Report how many items ended up in the dataset\n",
"n_items = len(test_dataset)\n",
"print(f\"The test_dataset contains {n_items} items.\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from transformers.pipelines.pt_utils import KeyDataset\n",
|
|
"from transformers import pipeline\n",
|
|
"from tqdm import tqdm\n",
|
|
"import os\n",
|
|
"from transformers import AutoTokenizer\n",
|
|
"\n",
|
"# Load the tokenizer of the base model named in mode.json\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name, return_tensors=\"pt\")\n",
"# Define additional special tokens: the output markers (THING/PROPERTY) followed by the input markers.\n",
"# NOTE(review): process_tensor_output refers to the first four markers by the ids 32100-32103,\n",
"# which presumably depend on the order of this list - confirm before reordering.\n",
"# additional_special_tokens = [\"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\"]\n",
"additional_special_tokens = [\"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\", \"<TN_START>\", \"<TN_END>\", \"<TD_START>\", \"<TD_END>\", \"<MIN_START>\", \"<MIN_END>\", \"<MAX_START>\", \"<MAX_END>\", \"<UNIT_START>\", \"<UNIT_END>\"]\n",
"\n",
"# Add the additional special tokens to the tokenizer\n",
"tokenizer.add_special_tokens({\"additional_special_tokens\": additional_special_tokens})\n",
"# tokenizer.add_special_tokens({'sep_token': \"<SEP>\"})\n",
"\n",
"\n",
"# Build the generation pipeline from the selected checkpoint. return_tensors=True makes each\n",
"# result carry 'translation_token_ids', which process_tensor_output reads.\n",
"# NOTE(review): device=0 presumably selects the first GPU - confirm for the target machine.\n",
"pipe = pipeline(\"translation_XX_to_YY\", model=model_checkpoint, tokenizer=tokenizer, return_tensors=True, max_length=128, device=0)\n",
"\n",
"# check what token-ids the special tokens are\n",
"# tokenizer.encode(\"<THING_START><THING_END><PROPERTY_START><PROPERTY_END>\")\n",
|
|
"\n",
|
"def extract_seq(tokens, start_value, end_value):\n",
"    \"\"\"Return the tokens strictly between a start marker and the end marker that follows it.\n",
"\n",
"    Args:\n",
"        tokens: List of token ids to search.\n",
"        start_value: Id of the start marker.\n",
"        end_value: Id of the end marker.\n",
"\n",
"    Returns:\n",
"        The list of ids between the first start marker and the first end marker\n",
"        after it (possibly empty), or None when the start marker is missing or\n",
"        no end marker follows it.\n",
"    \"\"\"\n",
"    try:\n",
"        start_id = tokens.index(start_value)\n",
"        # Look for the end marker only after the start marker, so a stray end\n",
"        # marker earlier in the sequence cannot produce an empty or wrong slice.\n",
"        end_id = tokens.index(end_value, start_id + 1)\n",
"    except ValueError:\n",
"        return None\n",
"\n",
"    return tokens[start_id + 1:end_id]\n",
|
|
"\n",
|
"# Caveat: when a marker pair is missing from the generated tokens, the corresponding field is None\n",
"def process_tensor_output(output):\n",
"    \"\"\"Decode the predicted thing and property from one pipeline result.\n",
"\n",
"    Args:\n",
"        output: One pipeline result; output[0]['translation_token_ids'] is read\n",
"            and converted to a list of token ids.\n",
"\n",
"    Returns:\n",
"        A (p_thing, p_property) tuple of decoded strings. Either element is None\n",
"        when its start/end markers could not be found in the output.\n",
"    \"\"\"\n",
"    tokens = output[0]['translation_token_ids'].tolist()\n",
"\n",
"    # Look the marker ids up in the tokenizer instead of hard-coding them\n",
"    # (32100-32103 in the original code), so the ids stay correct when\n",
"    # model_name selects a tokenizer with a different vocabulary size.\n",
"    thing_start, thing_end, property_start, property_end = tokenizer.convert_tokens_to_ids(\n",
"        [\"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\"]\n",
"    )\n",
"\n",
"    thing_seq = extract_seq(tokens, thing_start, thing_end)\n",
"    property_seq = extract_seq(tokens, property_start, property_end)\n",
"\n",
"    p_thing = tokenizer.decode(thing_seq) if thing_seq is not None else None\n",
"    p_property = tokenizer.decode(property_seq) if property_seq is not None else None\n",
"    return p_thing, p_property"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"making inference on test set\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"14718it [00:44, 330.24it/s] "
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"inference done\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"print(\"making inference on test set\")\n",
"\n",
"# Feed the 'input' string of every record to the pipeline in batches and decode each result as it arrives\n",
"input_dataset = KeyDataset(test_dataset[\"translation\"], \"input\")\n",
"predictions = [process_tensor_output(out) for out in tqdm(pipe(input_dataset, batch_size=256))]\n",
"\n",
"# Split the (thing, property) pairs into two parallel lists\n",
"p_thing_list = [thing for thing, _ in predictions]\n",
"p_property_list = [prop for _, prop in predictions]\n",
"print(\"inference done\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Thing prediction accuracy: 0.9895314057826521\n",
|
|
"Correct thing predictions: 1985, Incorrect thing predictions: 21\n",
|
|
"Property prediction accuracy: 0.9661016949152542\n",
|
|
"Correct property predictions: 1938, Incorrect property predictions: 12780\n",
|
|
"total accuracy: 0.9596211365902293\n",
|
|
"Correct total predictions: 1925, Incorrect total predictions: 81\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"# Reference answers and MDM flags, in dataset order\n",
"translations = test_dataset[\"translation\"]\n",
"answer_thing = [record['answer_thing'] for record in translations]\n",
"answer_property = [record['answer_property'] for record in translations]\n",
"mdm_list = [record['MDM'] for record in translations]\n",
"\n",
"# Number of rows whose MDM flag is the string \"True\"\n",
"mdm_count = mdm_list.count(\"True\")\n",
|
|
"\n",
|
"def correctness_test(input, reference, mdm_list):\n",
"    \"\"\"Mark each prediction as correct or not, counting MDM rows only.\n",
"\n",
"    Args:\n",
"        input: Predicted values.\n",
"        reference: Expected values; must have the same length as input.\n",
"        mdm_list: Per-row MDM flags; only rows flagged with the string \"True\" are compared.\n",
"\n",
"    Returns:\n",
"        A list with one entry per row: the result of input == reference for\n",
"        MDM rows, and False for every other row.\n",
"    \"\"\"\n",
"    assert len(input) == len(reference)\n",
"    return [\n",
"        (input[idx] == reference[idx]) if mdm_list[idx] == \"True\" else False\n",
"        for idx in range(len(input))\n",
"    ]\n",
|
|
"\n",
|
"# Compare with answer to evaluate correctness\n",
"thing_correctness = correctness_test(p_thing_list, answer_thing, mdm_list)\n",
"property_correctness = correctness_test(p_property_list, answer_property, mdm_list)\n",
"\n",
"# A row is fully correct only when both its thing and its property are correct\n",
"correctness_mdm = [\n",
"    bool(thing_ok and property_ok)\n",
"    for thing_ok, property_ok in zip(thing_correctness, property_correctness)\n",
"]\n",
"\n",
"\n",
"def _mdm_accuracy(correct_flags):\n",
"    \"\"\"Return the share of MDM rows flagged correct, or 0.0 when there are no MDM rows.\"\"\"\n",
"    return sum(correct_flags) / mdm_count if mdm_count else 0.0\n",
"\n",
"\n",
"# Calculate accuracy (over MDM rows only)\n",
"thing_accuracy = _mdm_accuracy(thing_correctness)\n",
"property_accuracy = _mdm_accuracy(property_correctness)\n",
"total_accuracy = _mdm_accuracy(correctness_mdm)\n",
"\n",
"# Count True/False values. Non-MDM rows are always marked False by\n",
"# correctness_test, so counting False entries directly would include them;\n",
"# subtract from mdm_count so that every 'false' count covers MDM rows only.\n",
"thing_true_count = thing_correctness.count(True)\n",
"thing_false_count = mdm_count - thing_true_count\n",
"property_true_count = property_correctness.count(True)\n",
"property_false_count = mdm_count - property_true_count\n",
"total_true_count = correctness_mdm.count(True)\n",
"total_false_count = mdm_count - total_true_count\n",
"\n",
"# Print results\n",
"print(\"Thing prediction accuracy:\", thing_accuracy)\n",
"print(f\"Correct thing predictions: {thing_true_count}, Incorrect thing predictions: {thing_false_count}\")\n",
"print(\"Property prediction accuracy:\", property_accuracy)\n",
"print(f\"Correct property predictions: {property_true_count}, Incorrect property predictions: {property_false_count}\")\n",
"print(\"total accuracy:\", total_accuracy)\n",
"print(f\"Correct total predictions: {total_true_count}, Incorrect total predictions: {total_false_count}\")\n",
|
|
"\n",
|
"# Create a DataFrame with the results.\n",
"# The column mapping is not called `dict`: rebinding that name would shadow\n",
"# the built-in dict type for every cell run afterwards in this kernel.\n",
"pred_columns = {\n",
"    'p_thing': p_thing_list,\n",
"    'p_property': p_property_list,\n",
"    'p_thing_correct': thing_correctness,\n",
"    'p_property_correct': property_correctness\n",
"}\n",
"\n",
"df_pred = pd.DataFrame(pred_columns)\n",
|
|
"\n",
|
"# Reload the run configuration from disk\n",
"with open(\"mode.json\", \"r\") as json_file:\n",
"    mode_dict = json.load(json_file)\n",
"\n",
"# Record the model name and the epoch count used by this run\n",
"for key, value in ((\"model\", model_name), (\"train_epochs\", train_epochs)):\n",
"    mode_dict[key] = value\n",
"\n",
"# Write the configuration back to mode.json\n",
"with open(\"mode.json\", \"w\") as json_file:\n",
"    json.dump(mode_dict, json_file)\n",
|
|
"\n",
|
|
"\n",
|
"# Start from the stored results when results.json exists, is non-empty and parses; otherwise start empty\n",
"results_dict = {}\n",
"if os.path.exists(\"results.json\") and os.path.getsize(\"results.json\") > 0:\n",
"    with open(\"results.json\", \"r\") as json_file:\n",
"        try:\n",
"            results_dict = json.load(json_file)\n",
"        except json.JSONDecodeError:\n",
"            results_dict = {}\n",
"\n",
"# Store this run's metrics under the checkpoint path\n",
"model_key = model_checkpoint\n",
"run_metrics = {\n",
"    \"thing_accuracy\": thing_accuracy,\n",
"    \"thing_true\": thing_true_count,\n",
"    \"thing_false\": thing_false_count,\n",
"    \"property_accuracy\": property_accuracy,\n",
"    \"property_true\": property_true_count,\n",
"    \"property_false\": property_false_count,\n",
"    \"total_accuracy\": total_accuracy,\n",
"    \"total_true\": total_true_count,\n",
"    \"total_false\": total_false_count\n",
"}\n",
"results_dict[model_key] = run_metrics\n",
"\n",
"# Write the merged results back to results.json\n",
"with open(\"results.json\", \"w\") as json_file:\n",
"    json.dump(results_dict, json_file, indent=4)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Updated data saved to ../0.result/1/test_p.csv\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import os\n",
|
|
"\n",
|
"# Create a DataFrame with the results\n",
"df_pred = pd.DataFrame({\n",
"    'p_thing': p_thing_list,\n",
"    'p_property': p_property_list,\n",
"    'p_thing_correct': thing_correctness,\n",
"    'p_property_correct': property_correctness,\n",
"})\n",
"\n",
"# The predictions are attached to df_org by row index. process_df skips rows it\n",
"# cannot process, which would shift every later prediction onto the wrong row,\n",
"# so stop here if the two frames do not have the same number of rows.\n",
"if len(df_pred) != len(df_org):\n",
"    raise ValueError(\n",
"        f\"Prediction count ({len(df_pred)}) does not match the number of rows in df_org ({len(df_org)})\"\n",
"    )\n",
"\n",
"# Merge predictions with the original DataFrame (df_org)\n",
"for column in ['p_thing', 'p_property', 'p_thing_correct', 'p_property_correct']:\n",
"    df_org[column] = df_pred[column]\n",
"df_org['p_correct'] = df_pred['p_thing_correct'] & df_org['p_property_correct']\n",
"\n",
"df_master = pd.read_csv('../../data_import/data_model_master_export.csv')\n",
"\n",
"# Replace every digit with '#' and join thing and property with a space to form the pattern strings\n",
"df_org['pattern'] = df_org['thing'].str.replace(r'\\d', '#', regex=True) + \" \" + df_org['property'].str.replace(r'\\d', '#', regex=True)\n",
"df_org['p_pattern'] = df_org['p_thing'].str.replace(r'\\d', '#', regex=True) + \" \" + df_org['p_property'].str.replace(r'\\d', '#', regex=True)\n",
"df_master['master_pattern'] = df_master['thing'] + \" \" + df_master['property']\n",
"\n",
"# Create a set of unique patterns from master for fast lookup\n",
"master_patterns = set(df_master['master_pattern'])\n",
"df_org['p_MDM'] = df_org['p_pattern'].apply(lambda x: x in master_patterns)\n",
"\n",
"output_path = f\"../0.result/{fold_group}/test_p.csv\"\n",
"debug_output_path = f\"0.dresult/{fold_group}/test_p.csv\"\n",
"\n",
"# Create each output folder if it does not exist, then write the CSV there\n",
"for csv_path in (output_path, debug_output_path):\n",
"    os.makedirs(os.path.dirname(csv_path), exist_ok=True)\n",
"    df_org.to_csv(csv_path, index=False, encoding='utf-8-sig')\n",
"\n",
"print(f\"Updated data saved to {output_path}\")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.14"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|