[TASK] init

parent 3841867c4b
commit 3d2266cf65

@ -0,0 +1,59 @@
import psycopg2
import pandas as pd

# Function to read the db connection info
def read_db_connection_info(filename="db_connection_info.txt"):
    connection_info = {}
    try:
        with open(filename, 'r') as file:
            for line in file:
                key, value = line.strip().split('=')
                connection_info[key] = value
    except Exception as e:
        print(f"Failed to read database connection info: {e}")
        raise
    return connection_info

# Load the connection info
connection_info = read_db_connection_info()

conn = None
try:
    # Connect to the database
    conn = psycopg2.connect(
        host=connection_info["host"],
        user=connection_info["user"],
        password=connection_info["password"],
        dbname=connection_info["database"],
        port=connection_info["port"]
    )
    # Using the connection and cursor as context managers ensures that resources are cleaned up properly
    with conn:
        with conn.cursor() as cursor:
            # Export data_mapping table
            query_mapping = """
                SELECT * FROM data_mapping
                WHERE ships_idx BETWEEN 1000 AND 1999
            """
            cursor.execute(query_mapping)
            results_mapping = cursor.fetchall()
            columns_mapping = [desc[0] for desc in cursor.description]
            df_mapping = pd.DataFrame(results_mapping, columns=columns_mapping)
            df_mapping.to_csv('data_import/data_mapping.csv', index=False, encoding='utf-8-sig')

            # Export data_model_master table
            query_master = """
                SELECT * FROM data_model_master
            """
            cursor.execute(query_master)
            results_master = cursor.fetchall()
            columns_master = [desc[0] for desc in cursor.description]
            df_master = pd.DataFrame(results_master, columns=columns_master)
            df_master.to_csv('data_import/data_model_master_export.csv', index=False, encoding='utf-8-sig')

            print("Data exported successfully to 'data_import/data_mapping.csv' and 'data_import/data_model_master_export.csv'")

except (Exception, psycopg2.DatabaseError) as error:
    print(f"An error occurred: {error}")
finally:
    if conn is not None:
        conn.close()

@ -0,0 +1,38 @@
import pandas as pd
import re

# Load the data_mapping CSV file
data_mapping_file_path = 'data_import/data_mapping.csv'  # Adjust this path to your actual file location
data_mapping = pd.read_csv(data_mapping_file_path, dtype=str)
df_master = pd.read_csv('data_import/data_model_master_export.csv')

# Generate patterns
data_mapping['thing_pattern'] = data_mapping['thing'].str.replace(r'\d', '#', regex=True)
data_mapping['property_pattern'] = data_mapping['property'].str.replace(r'\d', '#', regex=True)
data_mapping['pattern'] = data_mapping['thing_pattern'] + " " + data_mapping['property_pattern']
df_master['master_pattern'] = df_master['thing'] + " " + df_master['property']

# Create a set of unique patterns from master for fast lookup
master_patterns = set(df_master['master_pattern'])

# Check whether each pattern in data_mapping exists in df_master and record the result in the "MDM" field
data_mapping['MDM'] = data_mapping['pattern'].apply(lambda x: x in master_patterns)

# Remove specified fields
fields_to_remove = ['equip_type_code', 'tx_period', 'tx_type', 'on_change_yn', 'scaling_const', 'description', 'updated_time', 'status_code', 'is_timeout']
merged_data = data_mapping.drop(columns=fields_to_remove)

# Save the updated DataFrame to a new CSV file
output_file_path = 'data_import/raw_data.csv'
merged_data.to_csv(output_file_path, index=False, encoding='utf-8-sig')

print(f"Updated data saved to {output_file_path}")

# Filter the DataFrame where MDM is TRUE
data_mapping_mdm_true = merged_data[merged_data['MDM']]

# Save the filtered DataFrame to a new CSV file
mdm_true_output_file_path = 'data_import/data_mapping_mdm.csv'
data_mapping_mdm_true.to_csv(mdm_true_output_file_path, index=False, encoding='utf-8-sig')

print(f"MDM TRUE data saved to {mdm_true_output_file_path}")

File diff suppressed because one or more lines are too long

@ -0,0 +1,9 @@
import shutil

source_file = 'data_import/raw_data.csv'

destination_file = 'data_preprocess/preprocessed_data.csv'

shutil.copy(source_file, destination_file)

print(f"File copied from {source_file} to {destination_file}")

@ -0,0 +1,133 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Changes made in ships_idx 1000: 251\n",
|
||||
"Changes made in ships_idx 1001: 54\n",
|
||||
"Changes made in ships_idx 1002: 46\n",
|
||||
"Changes made in ships_idx 1003: 162\n",
|
||||
"Changes made in ships_idx 1004: 8\n",
|
||||
"Changes made in ships_idx 1005: 18\n",
|
||||
"Changes made in ships_idx 1008: 22\n",
|
||||
"Changes made in ships_idx 1009: 5\n",
|
||||
"Changes made in ships_idx 1010: 135\n",
|
||||
"Changes made in ships_idx 1011: 46\n",
|
||||
"Changes made in ships_idx 1012: 2\n",
|
||||
"Changes made in ships_idx 1013: 130\n",
|
||||
"Changes made in ships_idx 1014: 46\n",
|
||||
"Changes made in ships_idx 1015: 147\n",
|
||||
"Changes made in ships_idx 1016: 191\n",
|
||||
"Changes made in ships_idx 1017: 111\n",
|
||||
"Changes made in ships_idx 1018: 682\n",
|
||||
"Changes made in ships_idx 1019: 2\n",
|
||||
"Changes made in ships_idx 1020: 10\n",
|
||||
"Changes made in ships_idx 1021: 2\n",
|
||||
"Changes made in ships_idx 1022: 7\n",
|
||||
"Changes made in ships_idx 1023: 7\n",
|
||||
"Changes made in ships_idx 1024: 136\n",
|
||||
"Changes made in ships_idx 1025: 10\n",
|
||||
"Changes made in ships_idx 1026: 6\n",
|
||||
"Changes made in ships_idx 1027: 6\n",
|
||||
"Changes made in ships_idx 1028: 6\n",
|
||||
"Changes made in ships_idx 1029: 132\n",
|
||||
"Changes made in ships_idx 1030: 86\n",
|
||||
"Changes made in ships_idx 1031: 55\n",
|
||||
"Changes made in ships_idx 1032: 225\n",
|
||||
"Changes made in ships_idx 1033: 147\n",
|
||||
"Changes made in ships_idx 1035: 132\n",
|
||||
"Changes made in ships_idx 1036: 12\n",
|
||||
"Changes made in ships_idx 1037: 3\n",
|
||||
"Changes made in ships_idx 1038: 8\n",
|
||||
"Changes made in ships_idx 1039: 232\n",
|
||||
"Changes made in ships_idx 1042: 20\n",
|
||||
"Changes made in ships_idx 1043: 154\n",
|
||||
"Changes made in ships_idx 1044: 121\n",
|
||||
"Changes made in ships_idx 1045: 255\n",
|
||||
"Changes made in ships_idx 1046: 6\n",
|
||||
"Changes made in ships_idx 1047: 12\n",
|
||||
"Changes made in ships_idx 1048: 82\n",
|
||||
"Changes made in ships_idx 1049: 912\n",
|
||||
"Changes made in ships_idx 1050: 46\n",
|
||||
"Changes made in ships_idx 1051: 63\n",
|
||||
"Total number of changes made: 4951\n",
|
||||
"Updated data saved to raw_data_add_tag.csv\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"# Load the preprocessed data CSV file\n",
|
||||
"file_path = '../../data_import/raw_data.csv' # Adjust this path to your actual file location\n",
|
||||
"data = pd.read_csv(file_path, dtype=str)\n",
|
||||
"\n",
|
||||
"# Initialize a counter for the total number of changes\n",
|
||||
"total_changes = 0\n",
|
||||
"\n",
|
||||
"# Initialize a dictionary to count changes per ships_idx\n",
|
||||
"ships_idx_changes = {}\n",
|
||||
"\n",
|
||||
"# Process each group by ships_idx\n",
|
||||
"for ships_idx, group in data.groupby('ships_idx'):\n",
|
||||
" # Find duplicated tag_descriptions within the group\n",
|
||||
" duplicated_descriptions = group['tag_description'].duplicated(keep=False)\n",
|
||||
" \n",
|
||||
" # Count how many tag_descriptions are duplicated within this ships_idx\n",
|
||||
" num_changes = duplicated_descriptions.sum()\n",
|
||||
"\n",
|
||||
" # If there are any duplicates\n",
|
||||
" if num_changes > 0:\n",
|
||||
" # Increment the total changes count\n",
|
||||
" total_changes += num_changes\n",
|
||||
" \n",
|
||||
" # Record the number of changes for this ships_idx\n",
|
||||
" ships_idx_changes[ships_idx] = num_changes\n",
|
||||
"\n",
|
||||
" # Apply the concatenation of tag_name to tag_description for duplicates\n",
|
||||
" data.loc[duplicated_descriptions & (data['ships_idx'] == ships_idx), 'tag_description'] = \\\n",
|
||||
" data['tag_name'] + ' ' + data['tag_description']\n",
|
||||
"\n",
|
||||
"# Output the changes per ships_idx\n",
|
||||
"for ships_idx, count in ships_idx_changes.items():\n",
|
||||
" print(f\"Changes made in ships_idx {ships_idx}: {count}\")\n",
|
||||
"\n",
|
||||
"# Output the total number of changes\n",
|
||||
"print(f\"Total number of changes made: {total_changes}\")\n",
|
||||
"\n",
|
||||
"# Optionally, save the updated DataFrame back to a CSV\n",
|
||||
"output_file_path = 'raw_data_add_tag.csv'\n",
|
||||
"data.to_csv(output_file_path, index=False, encoding='utf-8-sig')\n",
|
||||
"\n",
|
||||
"print(f\"Updated data saved to {output_file_path}\")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "torch",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Updated data saved to raw_data_s.csv\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import re\n",
|
||||
"\n",
|
||||
"# Load the data_mapping CSV file\n",
|
||||
"data_mapping_file_path = '../../data_import/raw_data.csv' # Adjust this path to your actual file location\n",
|
||||
"# data_mapping_file_path = 'raw_data_add_tag.csv' # Adjust this path to your actual file location\n",
|
||||
"data_mapping = pd.read_csv(data_mapping_file_path, dtype=str)\n",
|
||||
"\n",
|
||||
"# Backup the original tag_description\n",
|
||||
"data_mapping['org_tag_description'] = data_mapping['tag_description']\n",
|
||||
"\n",
|
||||
"# Ensure all values in the 'tag_description' column are strings\n",
|
||||
"data_mapping['tag_description'] = data_mapping['tag_description'].fillna('').astype(str)\n",
|
||||
"data_mapping['tag_description'] = data_mapping['tag_description'].str.replace(r'[()]', ' ', regex=True)\n",
|
||||
"\n",
|
||||
"# Function to find tokens containing numbers\n",
|
||||
"def find_tokens_with_numbers(description):\n",
|
||||
" tokens = description.split() # Tokenize by spaces\n",
|
||||
" number_tokens = [token for token in tokens if re.search(r'\\d', token)]\n",
|
||||
" return number_tokens\n",
|
||||
"\n",
|
||||
"# Function to process tokens\n",
|
||||
"def process_token(token):\n",
|
||||
" # Step 1: Replace '_' or '-' adjacent to numbers with spaces\n",
|
||||
" token = re.sub(r'(_|-)(?=\\d)', ' ', token)\n",
|
||||
" token = re.sub(r'(?<=\\d)(_|-)', ' ', token)\n",
|
||||
"\n",
|
||||
" # Step 2: Insert spaces between letters and numbers where no separator exists\n",
|
||||
" token = re.sub(r'([A-Za-z])(\\d+)', r'\\1 \\2', token)\n",
|
||||
" token = re.sub(r'(\\d+)([A-Za-z])', r'\\1 \\2', token)\n",
|
||||
"\n",
|
||||
" # Step 3: Handle cases like \"NO.1\" or \"No.1\" to become \"No. 1\"\n",
|
||||
" token = re.sub(r'([A-Za-z]+)\\.(\\d+)', r'\\1. \\2', token)\n",
|
||||
"\n",
|
||||
" # Clean multiple spaces and strip\n",
|
||||
" token = re.sub(r'\\s+', ' ', token).strip()\n",
|
||||
" return token\n",
|
||||
"\n",
|
||||
"# Apply the process to each row in the 'tag_description' column\n",
|
||||
"for index, row in data_mapping.iterrows():\n",
|
||||
" original_description = row['tag_description']\n",
|
||||
" number_tokens = find_tokens_with_numbers(original_description)\n",
|
||||
"\n",
|
||||
" # Process each token containing numbers\n",
|
||||
" processed_tokens = [process_token(token) for token in number_tokens]\n",
|
||||
"\n",
|
||||
" # Replace the original tokens with processed tokens in the tag_description\n",
|
||||
" new_description = original_description\n",
|
||||
" for original_token, processed_token in zip(number_tokens, processed_tokens):\n",
|
||||
" new_description = new_description.replace(original_token, processed_token)\n",
|
||||
"\n",
|
||||
" # Update the data_mapping with the modified description\n",
|
||||
" data_mapping.at[index, 'tag_description'] = new_description\n",
|
||||
"\n",
|
||||
"# Save the updated data_mapping to a new CSV file\n",
|
||||
"output_file_path = 'raw_data_s.csv'\n",
|
||||
"data_mapping.to_csv(output_file_path, index=False, encoding='utf-8-sig')\n",
|
||||
"\n",
|
||||
"print(f\"Updated data saved to {output_file_path}\")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "torch",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Updated data saved to ../preprocessed_data.csv\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import re\n",
|
||||
"\n",
|
||||
"# Load the data_mapping CSV file\n",
|
||||
"data_mapping_file_path = 'raw_data_s.csv' # Adjust this path to your actual file location\n",
|
||||
"data_mapping = pd.read_csv(data_mapping_file_path, dtype=str)\n",
|
||||
" \n",
|
||||
" # Ensure all values in the 'tag_description' column are strings\n",
|
||||
"data_mapping['tag_description'] = data_mapping['tag_description'].fillna('').astype(str)\n",
|
||||
"data_mapping['tag_description'] = data_mapping['tag_description'].str.replace(r'[-]', ' ', regex=True)\n",
|
||||
"\n",
|
||||
"# Initial replacement mapping\n",
|
||||
"initial_replacements = {\n",
|
||||
" \"MGE\": \"G/E\",\n",
|
||||
" \"GEN.\": \"G/E\",\n",
|
||||
" \"GEN\": \"G/E\",\n",
|
||||
" \"GE\": \"G/E\",\n",
|
||||
" \"G_E\": \"G/E\",\n",
|
||||
" \"ME\": \"M/E\",\n",
|
||||
" \"M_E\": \"M/E\",\n",
|
||||
" \"S_G\": \"S/G\",\n",
|
||||
" \"T_C\": \"T/C\",\n",
|
||||
" \"TC\": \"T/C\",\n",
|
||||
" \"L_O\": \"L.O\",\n",
|
||||
" \"LO\": \"L.O\",\n",
|
||||
" \"F_O\": \"F.O\",\n",
|
||||
" \"FO\": \"F.O\",\n",
|
||||
" \"D_G\": \"D/G\",\n",
|
||||
" \"DG\": \"D/G\",\n",
|
||||
" \"PP\": \"P/P\"\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Second replacement mapping\n",
|
||||
"second_replacements = {\n",
|
||||
" \"_G/E\": \" G/E\",\n",
|
||||
" \"G/E_\": \"G/E \",\n",
|
||||
" \"_M/E\": \" M/E\",\n",
|
||||
" \"M/E_\": \"M/E \",\n",
|
||||
" \"_S/G\": \" S/G\",\n",
|
||||
" \"S/G_\": \"S/G \",\n",
|
||||
" \"_T/C\": \" T/C\",\n",
|
||||
" \"T/C_\": \"T/C \",\n",
|
||||
" \"_L.O\": \" L.O\",\n",
|
||||
" \"L.O_\": \"L.O \",\n",
|
||||
" \"_F.O\": \" F.O\",\n",
|
||||
" \"F.O_\": \"F.O \",\n",
|
||||
" \"_D/G\": \" D/G\",\n",
|
||||
" \"D/G_\": \"D/G \",\n",
|
||||
" \"DG_\": \"DG \"\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Function to separate numbers from text in a token\n",
|
||||
"def separate_numbers_from_text(description):\n",
|
||||
" # This regex pattern finds occurrences where text is followed by numbers or vice versa\n",
|
||||
" return re.sub(r'(\\d+)(\\D)', r'\\1 \\2', re.sub(r'(\\D)(\\d+)', r'\\1 \\2', description))\n",
|
||||
"\n",
|
||||
"# Function to perform replacements using tokens\n",
|
||||
"def replace_tokens(description, replacements):\n",
|
||||
" tokens = description.split() # Tokenize by spaces\n",
|
||||
" tokens = [replacements.get(token, token) for token in tokens] # Replace based on the dictionary\n",
|
||||
" return ' '.join(tokens)\n",
|
||||
"\n",
|
||||
"# Function to perform replacements for substrings\n",
|
||||
"def replace_substrings(description, replacements):\n",
|
||||
" for old, new in replacements.items():\n",
|
||||
" description = description.replace(old, new)\n",
|
||||
" return description\n",
|
||||
"\n",
|
||||
"# Separate numbers from text before applying replacements\n",
|
||||
"data_mapping['tag_description'] = data_mapping['tag_description'].apply(separate_numbers_from_text)\n",
|
||||
"\n",
|
||||
"# Apply initial replacements\n",
|
||||
"data_mapping['tag_description'] = data_mapping['tag_description'].apply(replace_tokens, replacements=initial_replacements)\n",
|
||||
"\n",
|
||||
"# Apply second replacements as substrings\n",
|
||||
"data_mapping['tag_description'] = data_mapping['tag_description'].apply(replace_substrings, replacements=second_replacements)\n",
|
||||
"\n",
|
||||
"# Save the updated data_mapping to a new CSV file\n",
|
||||
"output_file_path = '../preprocessed_data.csv'\n",
|
||||
"data_mapping.to_csv(output_file_path, index=False, encoding='utf-8-sig')\n",
|
||||
"\n",
|
||||
"print(f\"Updated data saved to {output_file_path}\")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "torch",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -0,0 +1,441 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Final Group Allocation:\n",
|
||||
"Group 1: Ships_idx = [1003, 1028, 1049, 1044, 1020, 1041, 1045, 1036, 1005, 1006], PD type = 537, PD = 2006, SD = 14719\n",
|
||||
"Group 2: Ships_idx = [1025, 1035, 1021, 1026, 1002, 1030, 1024, 1037, 1038, 1029], PD type = 537, PD = 1958, SD = 8173\n",
|
||||
"Group 3: Ships_idx = [1016, 1046, 1031, 1009, 1048, 1043, 1042, 1019, 1018, 1007, 1000], PD type = 534, PD = 2079, SD = 15310\n",
|
||||
"Group 4: Ships_idx = [1004, 1032, 1039, 1014, 1040, 1017, 1022, 1051, 1008, 1050, 1013], PD type = 532, PD = 2066, SD = 12882\n",
|
||||
"Group 5: Ships_idx = [1047, 1015, 1027, 1010, 1011, 1001, 1034, 1023, 1012, 1033], PD type = 531, PD = 2064, SD = 10988\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"from collections import defaultdict\n",
|
||||
"\n",
|
||||
"# Function to calculate the number of unique combinations and total count for each ship\n",
|
||||
"def calculate_ship_count(group):\n",
|
||||
" ship_count = group.groupby('ships_idx')['thing_property'].agg(['nunique', 'size']).reset_index()\n",
|
||||
" ship_count.columns = ['ships_idx', 'comb_count', 'total_count']\n",
|
||||
" return ship_count\n",
|
||||
"\n",
|
||||
"# Function to calculate the combination count and total count for a group\n",
|
||||
"def calculate_group_count(group):\n",
|
||||
" comb_count = group['thing_property'].nunique()\n",
|
||||
" total_count = group['thing_property'].size\n",
|
||||
" return comb_count, total_count\n",
|
||||
"\n",
|
||||
"# Function to calculate the increase in combination count when a ship is added to a group\n",
|
||||
"def calculate_comb_count_increase(groups, g, ship_idx, mdm):\n",
|
||||
" temp_groups = defaultdict(list, {k: v.copy() for k, v in groups.items()})\n",
|
||||
" temp_groups[g].append(ship_idx)\n",
|
||||
" \n",
|
||||
" group_ships = temp_groups[g]\n",
|
||||
" group_data = mdm[mdm['ships_idx'].isin(group_ships)]\n",
|
||||
" \n",
|
||||
" new_comb_count, _ = calculate_group_count(group_data)\n",
|
||||
" \n",
|
||||
" current_group_data = mdm[mdm['ships_idx'].isin(groups[g])]\n",
|
||||
" current_comb_count, _ = calculate_group_count(current_group_data)\n",
|
||||
" \n",
|
||||
" increase = new_comb_count - current_comb_count\n",
|
||||
" \n",
|
||||
" return increase\n",
|
||||
"\n",
|
||||
"# Function to calculate the increase in total count when a ship is added to a group\n",
|
||||
"def calculate_total_count_increase(groups, g, ship_idx, mdm):\n",
|
||||
" temp_groups = defaultdict(list, {k: v.copy() for k, v in groups.items()})\n",
|
||||
" temp_groups[g].append(ship_idx)\n",
|
||||
" \n",
|
||||
" group_ships = temp_groups[g]\n",
|
||||
" group_data = mdm[mdm['ships_idx'].isin(group_ships)]\n",
|
||||
" \n",
|
||||
" _, new_total_count = calculate_group_count(group_data)\n",
|
||||
" \n",
|
||||
" current_group_data = mdm[mdm['ships_idx'].isin(groups[g])]\n",
|
||||
" _, current_total_count = calculate_group_count(current_group_data)\n",
|
||||
" \n",
|
||||
" increase = new_total_count - current_total_count\n",
|
||||
" \n",
|
||||
" return increase\n",
|
||||
"\n",
|
||||
"# Function to find the ship that will bring the total count closest to the target\n",
|
||||
"def find_closest_total_count_ship(groups, g, remaining_ships, mdm, target_total_count):\n",
|
||||
" total_count_differences = []\n",
|
||||
"\n",
|
||||
" current_group_data = mdm[mdm['ships_idx'].isin(groups[g])]\n",
|
||||
" _, current_total_count = calculate_group_count(current_group_data)\n",
|
||||
"\n",
|
||||
" for ship_idx in remaining_ships:\n",
|
||||
" increase = calculate_total_count_increase(groups, g, ship_idx, mdm)\n",
|
||||
" new_total_count = current_total_count + increase\n",
|
||||
" difference = abs(target_total_count - new_total_count)\n",
|
||||
" total_count_differences.append((ship_idx, difference, increase))\n",
|
||||
"\n",
|
||||
" if not total_count_differences:\n",
|
||||
" return None, 0\n",
|
||||
" \n",
|
||||
" closest_ship = min(total_count_differences, key=lambda x: x[1])\n",
|
||||
" selected_ship_idx, _, selected_increase = closest_ship\n",
|
||||
"\n",
|
||||
" return selected_ship_idx, selected_increase\n",
|
||||
"\n",
|
||||
"# Function to find the ship that gives the maximum increase in combination count\n",
|
||||
"def find_max_increase_ship(groups, g, remaining_ships, mdm):\n",
|
||||
" comb_count_increase = []\n",
|
||||
"\n",
|
||||
" for ship_idx in remaining_ships:\n",
|
||||
" increase = calculate_comb_count_increase(groups, g, ship_idx, mdm)\n",
|
||||
" comb_count_increase.append((ship_idx, increase))\n",
|
||||
"\n",
|
||||
" max_increase_ship = max(comb_count_increase, key=lambda x: x[1])\n",
|
||||
" selected_ship_idx, max_increase = max_increase_ship\n",
|
||||
" \n",
|
||||
" return selected_ship_idx, max_increase\n",
|
||||
"\n",
|
||||
"# Function to find the ship that will bring the combination count closest to the target\n",
|
||||
"def find_closest_comb_count_ship(groups, g, remaining_ships, mdm, target_comb_count):\n",
|
||||
" comb_count_differences = []\n",
|
||||
"\n",
|
||||
" current_group_data = mdm[mdm['ships_idx'].isin(groups[g])]\n",
|
||||
" current_comb_count, _ = calculate_group_count(current_group_data)\n",
|
||||
"\n",
|
||||
" for ship_idx in remaining_ships:\n",
|
||||
" increase = calculate_comb_count_increase(groups, g, ship_idx, mdm)\n",
|
||||
" new_comb_count = current_comb_count + increase\n",
|
||||
" difference = abs(target_comb_count - new_comb_count)\n",
|
||||
" comb_count_differences.append((ship_idx, difference, increase))\n",
|
||||
"\n",
|
||||
" if not comb_count_differences:\n",
|
||||
" return None, 0\n",
|
||||
"\n",
|
||||
" closest_ship = min(comb_count_differences, key=lambda x: x[1])\n",
|
||||
" selected_ship_idx, _, selected_increase = closest_ship\n",
|
||||
"\n",
|
||||
" return selected_ship_idx, selected_increase\n",
|
||||
"\n",
|
||||
"# Function to find the group with the maximum combination count\n",
|
||||
"def find_group_with_max_comb_count(groups, mdm):\n",
|
||||
" max_comb_count = -1\n",
|
||||
" max_group_idx = -1\n",
|
||||
"\n",
|
||||
" for g in range(len(groups)):\n",
|
||||
" group_ships = groups[g]\n",
|
||||
" group_data = mdm[mdm['ships_idx'].isin(group_ships)]\n",
|
||||
" comb_count, _ = calculate_group_count(group_data)\n",
|
||||
" \n",
|
||||
" if comb_count > max_comb_count:\n",
|
||||
" max_comb_count = comb_count\n",
|
||||
" max_group_idx = g\n",
|
||||
"\n",
|
||||
" return max_group_idx, max_comb_count\n",
|
||||
"\n",
|
||||
"# Function to find the group with the maximum total count\n",
|
||||
"def find_group_with_max_total_count(groups, mdm):\n",
|
||||
" max_total_count = -1\n",
|
||||
" max_group_idx = -1\n",
|
||||
"\n",
|
||||
" for g in range(len(groups)):\n",
|
||||
" group_ships = groups[g]\n",
|
||||
" group_data = mdm[mdm['ships_idx'].isin(group_ships)]\n",
|
||||
" _, total_count = calculate_group_count(group_data)\n",
|
||||
" \n",
|
||||
" if total_count > max_total_count:\n",
|
||||
" max_total_count = total_count\n",
|
||||
" max_group_idx = g\n",
|
||||
"\n",
|
||||
" return max_group_idx, max_total_count\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"from collections import defaultdict\n",
|
||||
"\n",
|
||||
"# Load the CSV file\n",
|
||||
"data_file_path = 'preprocessed_data.csv'\n",
|
||||
"data = pd.read_csv(data_file_path)\n",
|
||||
"\n",
|
||||
"# Filter the data where MDM is True\n",
|
||||
"mdm_true = data[data['MDM'] == True].copy() # .copy()를 사용하여 명시적으로 복사본 생성\n",
|
||||
"mdm_all = data.copy()\n",
|
||||
"\n",
|
||||
"# Create a new column combining 'thing' and 'property'\n",
|
||||
"mdm_true.loc[:, 'thing_property'] = mdm_true['thing'] + '_' + mdm_true['property']\n",
|
||||
"mdm_all.loc[:, 'thing_property'] = mdm_all['thing'] + '_' + mdm_all['property']\n",
|
||||
"\n",
|
||||
"# Initial setup for groups\n",
|
||||
"ship_count = calculate_ship_count(mdm_true)\n",
|
||||
"num_groups = 5\n",
|
||||
"groups = defaultdict(list)\n",
|
||||
"\n",
|
||||
"# Sort ships by combination count in descending order\n",
|
||||
"sorted_ships = ship_count.sort_values(by='comb_count', ascending=False)\n",
|
||||
"\n",
|
||||
"# Assign the first 5 ships to the groups\n",
|
||||
"for i in range(num_groups):\n",
|
||||
" groups[i].append(sorted_ships.iloc[i]['ships_idx'])\n",
|
||||
"\n",
|
||||
"remaining_ships = sorted_ships.iloc[num_groups:]['ships_idx'].values\n",
|
||||
"\n",
|
||||
"# Allocate remaining ships to the groups\n",
|
||||
"while len(remaining_ships) > 0:\n",
|
||||
" group_comb_counts = []\n",
|
||||
" for g in range(num_groups):\n",
|
||||
" group_ships = groups[g]\n",
|
||||
" group_data = mdm_true[mdm_true['ships_idx'].isin(group_ships)]\n",
|
||||
" comb_count, _ = calculate_group_count(group_data)\n",
|
||||
" group_comb_counts.append((g, comb_count))\n",
|
||||
"\n",
|
||||
" group_comb_counts.sort(key=lambda x: x[1])\n",
|
||||
" \n",
|
||||
" remaining_group = []\n",
|
||||
" for g, _ in group_comb_counts:\n",
|
||||
" if len(remaining_ships) == 0:\n",
|
||||
" break\n",
|
||||
" \n",
|
||||
" if group_comb_counts.index((g, _)) == 0:\n",
|
||||
" selected_ship_idx, comb_increase = find_max_increase_ship(groups, g, remaining_ships, mdm_true)\n",
|
||||
" \n",
|
||||
" else:\n",
|
||||
" max_group_idx, max_comb_count = find_group_with_max_comb_count(groups, mdm_true)\n",
|
||||
" selected_ship_idx, comb_increase = find_closest_comb_count_ship(groups, g, remaining_ships, mdm_true, max_comb_count)\n",
|
||||
"\n",
|
||||
" if comb_increase == 0:\n",
|
||||
" remaining_group.append(g)\n",
|
||||
" else:\n",
|
||||
" groups[g].append(selected_ship_idx)\n",
|
||||
" remaining_ships = remaining_ships[remaining_ships != selected_ship_idx]\n",
|
||||
"\n",
|
||||
" for g in remaining_group:\n",
|
||||
" if len(remaining_ships) == 0:\n",
|
||||
" break\n",
|
||||
" max_group_idx, max_total_count = find_group_with_max_total_count(groups, mdm_true)\n",
|
||||
" selected_ship_idx, count_increase = find_closest_total_count_ship(groups, g, remaining_ships, mdm_true, max_total_count)\n",
|
||||
" if selected_ship_idx is not None:\n",
|
||||
" groups[g].append(selected_ship_idx)\n",
|
||||
" remaining_ships = remaining_ships[remaining_ships != selected_ship_idx]\n",
|
||||
"\n",
|
||||
"# Calculate comb_count for each group and store it in a list\n",
|
||||
"group_comb_counts = []\n",
|
||||
"for g in range(num_groups):\n",
|
||||
" group_ships = groups[g]\n",
|
||||
" group_data_true = mdm_true[mdm_true['ships_idx'].isin(group_ships)]\n",
|
||||
" comb_count, total_count = calculate_group_count(group_data_true)\n",
|
||||
"\n",
|
||||
" # Calculate total count including MDM=False\n",
|
||||
" group_data_all = mdm_all[mdm_all['ships_idx'].isin(group_ships)]\n",
|
||||
" _, total_count_all = calculate_group_count(group_data_all)\n",
|
||||
" \n",
|
||||
" group_comb_counts.append((g, comb_count, total_count_all))\n",
|
||||
"\n",
|
||||
"# Sort the groups by comb_count in descending order\n",
|
||||
"group_comb_counts.sort(key=lambda x: x[1], reverse=True)\n",
|
||||
"\n",
|
||||
"# Reorder the groups dictionary based on the sorted order\n",
|
||||
"sorted_groups = defaultdict(list)\n",
|
||||
"for i, (g, _, _) in enumerate(group_comb_counts):\n",
|
||||
" sorted_groups[i] = groups[g]\n",
|
||||
"\n",
|
||||
"# Final output of group allocation\n",
|
||||
"print(\"Final Group Allocation:\")\n",
|
||||
"for g in range(num_groups):\n",
|
||||
" group_ships = sorted_groups[g]\n",
|
||||
" group_data_true = mdm_true[mdm_true['ships_idx'].isin(group_ships)]\n",
|
||||
" comb_count, total_count = calculate_group_count(group_data_true)\n",
|
||||
"\n",
|
||||
" # Calculate total count including MDM=False\n",
|
||||
" group_data_all = mdm_all[mdm_all['ships_idx'].isin(group_ships)]\n",
|
||||
" _, total_count_all = calculate_group_count(group_data_all)\n",
|
||||
"\n",
|
||||
" print(f\"Group {g + 1}: Ships_idx = {group_ships}, PD type = {comb_count}, PD = {total_count}, SD = {total_count_all}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"CSV file has been generated: 'combined_group_allocation.csv'\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.model_selection import GroupKFold\n",
|
||||
"\n",
|
||||
"# Prepare data for custom group allocation (BGKF)\n",
|
||||
"comb_counts = []\n",
|
||||
"total_counts = []\n",
|
||||
"ship_counts = []\n",
|
||||
"custom_results = []\n",
|
||||
"\n",
|
||||
"for g in range(num_groups):\n",
|
||||
" group_ships = groups[g]\n",
|
||||
" group_data_true = mdm_true[mdm_true['ships_idx'].isin(group_ships)]\n",
|
||||
" comb_count, total_count = calculate_group_count(group_data_true)\n",
|
||||
" \n",
|
||||
" # Calculate total count including MDM=False\n",
|
||||
" group_data_all = mdm_all[mdm_all['ships_idx'].isin(group_ships)]\n",
|
||||
" _, total_count_all = calculate_group_count(group_data_all)\n",
|
||||
" \n",
|
||||
" custom_results.append({\n",
|
||||
" 'Group': g + 1,\n",
|
||||
" 'Allocation': 'BGKF',\n",
|
||||
" 'Comb_count': comb_count,\n",
|
||||
" 'Total_count': total_count,\n",
|
||||
" 'Total_count_all': total_count_all,\n",
|
||||
" 'Ship_count': len(group_ships),\n",
|
||||
" 'Ships_idx': list(group_ships)\n",
|
||||
" })\n",
|
||||
"\n",
|
||||
"# Sort the custom group allocation by comb_count in descending order\n",
|
||||
"custom_results.sort(key=lambda x: x['Comb_count'], reverse=True)\n",
|
||||
"\n",
|
||||
"# Adjust group numbers after sorting\n",
|
||||
"for i, result in enumerate(custom_results):\n",
|
||||
" result['Group'] = i + 1\n",
|
||||
"\n",
|
||||
"# Prepare data for GroupKFold allocation (GKF)\n",
|
||||
"gkf = GroupKFold(n_splits=5)\n",
|
||||
"gkf_results = []\n",
|
||||
"\n",
|
||||
"for i, (train_idx, test_idx) in enumerate(gkf.split(mdm_true, groups=mdm_true['ships_idx'])):\n",
|
||||
" test_group = mdm_true.iloc[test_idx]\n",
|
||||
" comb_count, total_count = calculate_group_count(test_group)\n",
|
||||
" \n",
|
||||
" # Calculate total count including MDM=False\n",
|
||||
" test_group_ships = test_group['ships_idx'].unique()\n",
|
||||
" test_group_all = mdm_all[mdm_all['ships_idx'].isin(test_group_ships)]\n",
|
||||
" _, total_count_all = calculate_group_count(test_group_all)\n",
|
||||
" \n",
|
||||
" gkf_results.append({\n",
|
||||
" 'Group': i + 1,\n",
|
||||
" 'Allocation': 'GKF',\n",
|
||||
" 'Comb_count': comb_count,\n",
|
||||
" 'Total_count': total_count,\n",
|
||||
" 'Total_count_all': total_count_all,\n",
|
||||
" 'Ship_count': test_group['ships_idx'].nunique(),\n",
|
||||
" 'Ships_idx': list(test_group['ships_idx'].unique())\n",
|
||||
" })\n",
|
||||
"\n",
|
||||
"# Sort the GKF allocation by comb_count in descending order\n",
|
||||
"gkf_results.sort(key=lambda x: x['Comb_count'], reverse=True)\n",
|
||||
"\n",
|
||||
"# Adjust group numbers after sorting\n",
|
||||
"for i, result in enumerate(gkf_results):\n",
|
||||
" result['Group'] = i + 1\n",
|
||||
"\n",
|
||||
"# Combine BGKF and GKF results into one DataFrame\n",
|
||||
"combined_results = custom_results + gkf_results\n",
|
||||
"combined_df = pd.DataFrame(combined_results)\n",
|
||||
"\n",
|
||||
"# Output the combined results to a single CSV file\n",
|
||||
"combined_df.to_csv('combined_group_allocation.csv', index=False)\n",
|
||||
"\n",
|
||||
"print(\"CSV file has been generated: 'combined_group_allocation.csv'\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Group 1 datasets saved in dataset/1\n",
|
||||
"Group 2 datasets saved in dataset/2\n",
|
||||
"Group 3 datasets saved in dataset/3\n",
|
||||
"Group 4 datasets saved in dataset/4\n",
|
||||
"Group 5 datasets saved in dataset/5\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.model_selection import KFold\n",
|
||||
"\n",
|
||||
"def save_datasets_for_group(groups, mdm, data, output_dir='dataset', n_splits=4):\n",
|
||||
" for i in range(len(groups)):\n",
|
||||
" group_folder = os.path.join(output_dir, str(i + 1))\n",
|
||||
" os.makedirs(group_folder, exist_ok=True)\n",
|
||||
" \n",
|
||||
" # Create the test dataset by including only group i\n",
|
||||
" test_group_ships = groups[i]\n",
|
||||
" test_data = mdm[mdm['ships_idx'].isin(test_group_ships)]\n",
|
||||
" \n",
|
||||
" # Extract corresponding entries from the external test dataset\n",
|
||||
" test_all_data = data[data['ships_idx'].isin(test_group_ships)]\n",
|
||||
" \n",
|
||||
" # Create the train dataset by excluding group i\n",
|
||||
" train_group_ships = []\n",
|
||||
" for g in range(len(groups)):\n",
|
||||
" if g != i:\n",
|
||||
" train_group_ships.extend(groups[g])\n",
|
||||
" train_data = mdm[mdm['ships_idx'].isin(train_group_ships)]\n",
|
||||
" \n",
|
||||
" # Use KFold to split train_data into train and valid datasets\n",
|
||||
" kf_inner = KFold(n_splits=n_splits, shuffle=True, random_state=42)\n",
|
||||
" train_idx_inner, valid_idx_inner = next(kf_inner.split(train_data))\n",
|
||||
" \n",
|
||||
" final_train_data = train_data.iloc[train_idx_inner]\n",
|
||||
" valid_data = train_data.iloc[valid_idx_inner]\n",
|
||||
" \n",
|
||||
" # Combine train and valid data to create train_all\n",
|
||||
" train_all_data = pd.concat([final_train_data, valid_data])\n",
|
||||
" \n",
|
||||
" # Save datasets to CSV files\n",
|
||||
" train_file_path = os.path.join(group_folder, 'train.csv')\n",
|
||||
" valid_file_path = os.path.join(group_folder, 'valid.csv')\n",
|
||||
" test_file_path = os.path.join(group_folder, 'test.csv')\n",
|
||||
" test_all_file_path = os.path.join(group_folder, 'test_all.csv')\n",
|
||||
" train_all_file_path = os.path.join(group_folder, 'train_all.csv')\n",
|
||||
" \n",
|
||||
" final_train_data.to_csv(train_file_path, index=False, encoding='utf-8-sig')\n",
|
||||
" valid_data.to_csv(valid_file_path, index=False, encoding='utf-8-sig')\n",
|
||||
" # test_data.to_csv(test_file_path, index=False, encoding='utf-8-sig')\n",
|
||||
" test_all_data.to_csv(test_file_path, index=False, encoding='utf-8-sig')\n",
|
||||
" train_all_data.to_csv(train_all_file_path, index=False, encoding='utf-8-sig')\n",
|
||||
" \n",
|
||||
" print(f\"Group {i + 1} datasets saved in {group_folder}\")\n",
|
||||
"\n",
|
||||
"# Example usage:\n",
|
||||
"save_datasets_for_group(groups, mdm_true, data, n_splits=4)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "torch",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Performance for all_with_p_s.csv:\n",
|
||||
"TP: 1724, TN: 11907, FP: 919, FN: 272\n",
|
||||
"Precision: 0.6523, Recall: 0.8637, Accuracy: 0.9196\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"# Set the group number\n",
|
||||
"group_number = 1 # Change this to the desired group number\n",
|
||||
"\n",
|
||||
"# File paths for the two datasets\n",
|
||||
"test_s_path = f'../post_process/0.result/{group_number}/test_s.csv'\n",
|
||||
"\n",
|
||||
"# Load the CSV files\n",
|
||||
"test_s_csv = pd.read_csv(test_s_path, low_memory=False)\n",
|
||||
"test_s_csv.fillna('', inplace=True)\n",
|
||||
"\n",
|
||||
"def evaluate_performance(test_csv):\n",
|
||||
" # Initialize counters for TP, TN, FP, FN\n",
|
||||
" TP = 0\n",
|
||||
" TN = 0\n",
|
||||
" FP = 0\n",
|
||||
" FN = 0\n",
|
||||
"\n",
|
||||
" # Iterate over the DataFrame rows\n",
|
||||
" for index, row in test_csv.iterrows():\n",
|
||||
" # True Positive (TP): s_correct is True and MDM is True\n",
|
||||
" if row['s_correct'] and row['MDM']:\n",
|
||||
" TP += 1\n",
|
||||
" # True Negative (TN): s_thing is null and MDM is False\n",
|
||||
" elif row['s_thing'] == '' and not row['MDM']:\n",
|
||||
" TN += 1\n",
|
||||
" # False Positive (FP): \n",
|
||||
" # 1) s_thing is not null and MDM is False \n",
|
||||
" # OR \n",
|
||||
" # 2) s_thing is not null and s_correct is False and MDM is True\n",
|
||||
" elif (row['s_thing'] != '' and not row['MDM']) or (row['s_thing'] != '' and not row['s_correct'] and row['MDM']):\n",
|
||||
" FP += 1\n",
|
||||
" # False Negative (FN): s_thing is null and MDM is True\n",
|
||||
" elif row['s_thing'] == '' and row['MDM']:\n",
|
||||
" FN += 1\n",
|
||||
"\n",
|
||||
" # Calculate total\n",
|
||||
" total = TP + TN + FP + FN\n",
|
||||
"\n",
|
||||
" # Calculate Precision, Recall, and Accuracy\n",
|
||||
" precision = TP / (TP + FP) if (TP + FP) > 0 else 0\n",
|
||||
" recall = TP / (TP + FN) if (TP + FN) > 0 else 0\n",
|
||||
" accuracy = (TP + TN) / total if total > 0 else 0\n",
|
||||
"\n",
|
||||
" return TP, TN, FP, FN, precision, recall, accuracy\n",
|
||||
"\n",
|
||||
"# Evaluate both datasets\n",
|
||||
"tp_s_results = evaluate_performance(test_s_csv)\n",
|
||||
"\n",
|
||||
"# Print the results for both datasets\n",
|
||||
"print(\"Performance for all_with_p_s.csv:\")\n",
|
||||
"print(f\"TP: {tp_s_results[0]}, TN: {tp_s_results[1]}, FP: {tp_s_results[2]}, FN: {tp_s_results[3]}\")\n",
|
||||
"print(f\"Precision: {tp_s_results[4]:.4f}, Recall: {tp_s_results[5]:.4f}, Accuracy: {tp_s_results[6]:.4f}\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "base",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -0,0 +1,47 @@
import pandas as pd
import re
import os

# Loop through group numbers from 1 to 5
for group_number in range(1, 6):

    # Path to the train_all file
    train_all_path = f'data_preprocess/dataset/{group_number}/train_all.csv'

    # Read the train_all data
    train_all_csv = pd.read_csv(train_all_path, low_memory=False)

    # Concatenate tag_description based on the combination of thing and property
    tag_description_concatenated = train_all_csv.groupby(['thing', 'property'])['tag_description'].apply(lambda x: ' '.join(x)).reset_index()

    # Concatenate tag_name based on the combination of thing and property
    tag_name_concatenated = train_all_csv.groupby(['thing', 'property'])['tag_name'].apply(lambda x: ' '.join(x)).reset_index()

    # Calculate mapping_count
    mapping_count = train_all_csv.groupby(['thing', 'property']).size().reset_index(name='mapping_count')

    # Merge the three DataFrames: mapping_count, tag_description_concatenated, and tag_name_concatenated
    thing_property_grouped = pd.merge(mapping_count, tag_description_concatenated, on=['thing', 'property'])
    thing_property_grouped = pd.merge(thing_property_grouped, tag_name_concatenated, on=['thing', 'property'])

    # Calculate token_count by splitting tag_description using r'\S+'
    thing_property_grouped['td_token_count'] = thing_property_grouped['tag_description'].apply(lambda x: len(re.findall(r'\S+', x)))

    # Create pattern by replacing digits in 'thing' and 'property' with '#'
    thing_property_grouped['pattern'] = thing_property_grouped['thing'].str.replace(r'\d', '#', regex=True) + " " + thing_property_grouped['property'].str.replace(r'\d', '#', regex=True)

    # Calculate the total number of unique thing_property combinations
    total_thing_property_count = thing_property_grouped.shape[0]

    # Specify the output path
    output_path = f'post_process/tfidf_class/0.class_document/{group_number}/sdl_class_rdoc.csv'

    # Create the directory if it doesn't exist
    output_dir = os.path.dirname(output_path)
    os.makedirs(output_dir, exist_ok=True)

    # Save the result to the CSV file
    thing_property_grouped.to_csv(output_path, index=False, encoding='utf-8-sig')

    print(f"Concatenated data saved to {output_path}")
    print(f"Total number of unique thing_property combinations: {total_thing_property_count}")

@ -0,0 +1,134 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Accuracy (MDM=True) for Group 1: 79.41%\n",
|
||||
"Accuracy (MDM=True) for Group 2: 79.32%\n",
|
||||
"Accuracy (MDM=True) for Group 3: 82.49%\n",
|
||||
"Accuracy (MDM=True) for Group 4: 85.61%\n",
|
||||
"Accuracy (MDM=True) for Group 5: 79.72%\n",
|
||||
"Average Accuracy (MDM=True) across all groups: 81.31%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
||||
"from sklearn.metrics.pairwise import cosine_similarity\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"# Initialize a list to store the accuracies for each group\n",
|
||||
"accuracies = []\n",
|
||||
"\n",
|
||||
"# Loop through group numbers from 1 to 5\n",
|
||||
"for group_number in range(1, 6):\n",
|
||||
" \n",
|
||||
" # Load the CSV files from the specified group\n",
|
||||
" sdl_class_rdoc_path = f'0.class_document/{group_number}/sdl_class_rdoc.csv'\n",
|
||||
" test_path = f'../../data_preprocess/dataset/{group_number}/test.csv'\n",
|
||||
" \n",
|
||||
" # Check if test file exists, if not, skip this iteration\n",
|
||||
" if not os.path.exists(test_path):\n",
|
||||
" print(f\"test file for Group {group_number} does not exist. Skipping...\")\n",
|
||||
" continue\n",
|
||||
" \n",
|
||||
" sdl_class_rdoc_csv = pd.read_csv(sdl_class_rdoc_path, low_memory=False)\n",
|
||||
" test_csv = pd.read_csv(test_path, low_memory=False)\n",
|
||||
" \n",
|
||||
" # Replace NaN values with empty strings in relevant columns\n",
|
||||
" sdl_class_rdoc_csv['tag_description'] = sdl_class_rdoc_csv['tag_description'].fillna('')\n",
|
||||
" test_csv['tag_description'] = test_csv['tag_description'].fillna('')\n",
|
||||
" \n",
|
||||
" # Initialize new columns in test_csv\n",
|
||||
" test_csv['c_thing'] = ''\n",
|
||||
" test_csv['c_property'] = ''\n",
|
||||
" test_csv['c_score'] = ''\n",
|
||||
" test_csv['c_duplicate'] = 0 # Initialize c_duplicate to store duplicate counts\n",
|
||||
" \n",
|
||||
" # Combine both sdl_class_rdoc and test CSVs tag_descriptions for TF-IDF Vectorizer training\n",
|
||||
" combined_tag_descriptions = sdl_class_rdoc_csv['tag_description'].tolist() + test_csv['tag_description'].tolist()\n",
|
||||
" \n",
|
||||
" # Create a TF-IDF Vectorizer\n",
|
||||
" vectorizer = TfidfVectorizer(\n",
|
||||
" token_pattern=r'\\S+',\n",
|
||||
" ngram_range=(1, 6), # Use ngrams from 1 to 6\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" # Fit the TF-IDF vectorizer on the combined tag_descriptions\n",
|
||||
" vectorizer.fit(combined_tag_descriptions)\n",
|
||||
" \n",
|
||||
" # Transform both sdl_class_rdoc and test CSVs into TF-IDF matrices\n",
|
||||
" sdl_class_rdoc_tfidf_matrix = vectorizer.transform(sdl_class_rdoc_csv['tag_description'])\n",
|
||||
" test_tfidf_matrix = vectorizer.transform(test_csv['tag_description'])\n",
|
||||
" \n",
|
||||
" # Calculate cosine similarity between test and class-level sdl_class_rdoc vectors\n",
|
||||
" similarity_matrix = cosine_similarity(test_tfidf_matrix, sdl_class_rdoc_tfidf_matrix)\n",
|
||||
" \n",
|
||||
" # Find the most similar class-level tag_description for each test description\n",
|
||||
" most_similar_indices = similarity_matrix.argmax(axis=1)\n",
|
||||
" most_similar_scores = similarity_matrix.max(axis=1)\n",
|
||||
" \n",
|
||||
" # Assign the corresponding thing, property, and similarity score to the test CSV\n",
|
||||
" test_csv['c_thing'] = sdl_class_rdoc_csv.iloc[most_similar_indices]['thing'].values\n",
|
||||
" test_csv['c_property'] = sdl_class_rdoc_csv.iloc[most_similar_indices]['property'].values\n",
|
||||
" test_csv['c_score'] = most_similar_scores\n",
|
||||
" \n",
|
||||
" # Check if the predicted 'c_thing' and 'c_property' match the actual 'thing' and 'property'\n",
|
||||
" test_csv['cthing_correct'] = test_csv['thing'] == test_csv['c_thing']\n",
|
||||
" test_csv['cproperty_correct'] = test_csv['property'] == test_csv['c_property']\n",
|
||||
" test_csv['ctp_correct'] = test_csv['cthing_correct'] & test_csv['cproperty_correct']\n",
|
||||
" \n",
|
||||
" # Calculate accuracy based only on MDM = True\n",
|
||||
" mdm_true_count = len(test_csv[test_csv['MDM'] == True])\n",
|
||||
" accuracy = (test_csv['ctp_correct'].sum() / mdm_true_count) * 100\n",
|
||||
" accuracies.append(accuracy)\n",
|
||||
" \n",
|
||||
" print(f\"Accuracy (MDM=True) for Group {group_number}: {accuracy:.2f}%\")\n",
|
||||
" \n",
|
||||
" # Specify output file paths\n",
|
||||
" output_path = f'0.class_document/{group_number}/test_p_c.csv'\n",
|
||||
" test_csv.to_csv(output_path, index=False, encoding='utf-8-sig')\n",
|
||||
" \n",
|
||||
" # Filter for rows where MDM is True and ctp_correct is False\n",
|
||||
" false_positive_rows = test_csv[(test_csv['MDM'] == True) & (test_csv['ctp_correct'] == False)]\n",
|
||||
" \n",
|
||||
" # Save false positives to a separate file\n",
|
||||
" fp_output_path = f'0.class_document/{group_number}/fp_class.csv'\n",
|
||||
" false_positive_rows.to_csv(fp_output_path, index=False, encoding='utf-8-sig')\n",
|
||||
"\n",
|
||||
"# Calculate and print the average accuracy across all groups\n",
|
||||
"average_accuracy = sum(accuracies) / len(accuracies)\n",
|
||||
"print(f\"Average Accuracy (MDM=True) across all groups: {average_accuracy:.2f}%\")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "torch",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "KeyError",
|
||||
"evalue": "'p_correct'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
|
||||
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/indexes/base.py:3805\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 3804\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3805\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcasted_key\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3806\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
|
||||
"File \u001b[0;32mindex.pyx:167\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
|
||||
"File \u001b[0;32mindex.pyx:196\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
|
||||
"File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:7081\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
|
||||
"File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:7089\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;31mKeyError\u001b[0m: 'p_correct'",
|
||||
"\nThe above exception was the direct cause of the following exception:\n",
|
||||
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[11], line 22\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# Assign c_thing, c_property to p_thing, p_property and set p_MDM to True if conditions are met\u001b[39;00m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m index, row \u001b[38;5;129;01min\u001b[39;00m test_csv\u001b[38;5;241m.\u001b[39miterrows():\n\u001b[0;32m---> 22\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[43mrow\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mp_correct\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m \u001b[38;5;129;01mand\u001b[39;00m row[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mctp_correct\u001b[39m\u001b[38;5;124m'\u001b[39m]:\n\u001b[1;32m 23\u001b[0m update_count \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;66;03m# Increment the counter\u001b[39;00m\n\u001b[1;32m 25\u001b[0m \u001b[38;5;66;03m# Check for duplicates within the same ships_idx\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/series.py:1121\u001b[0m, in \u001b[0;36mSeries.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 1118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_values[key]\n\u001b[1;32m 1120\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m key_is_scalar:\n\u001b[0;32m-> 1121\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_value\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1123\u001b[0m \u001b[38;5;66;03m# Convert generator to list before going through hashable part\u001b[39;00m\n\u001b[1;32m 1124\u001b[0m \u001b[38;5;66;03m# (We will iterate through the generator there to check for slices)\u001b[39;00m\n\u001b[1;32m 1125\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_iterator(key):\n",
|
||||
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/series.py:1237\u001b[0m, in \u001b[0;36mSeries._get_value\u001b[0;34m(self, label, takeable)\u001b[0m\n\u001b[1;32m 1234\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_values[label]\n\u001b[1;32m 1236\u001b[0m \u001b[38;5;66;03m# Similar to Index.get_value, but we do not fall back to positional\u001b[39;00m\n\u001b[0;32m-> 1237\u001b[0m loc \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mindex\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlabel\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1239\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_integer(loc):\n\u001b[1;32m 1240\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_values[loc]\n",
|
||||
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/pandas/core/indexes/base.py:3812\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 3807\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(casted_key, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m (\n\u001b[1;32m 3808\u001b[0m \u001b[38;5;28misinstance\u001b[39m(casted_key, abc\u001b[38;5;241m.\u001b[39mIterable)\n\u001b[1;32m 3809\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28many\u001b[39m(\u001b[38;5;28misinstance\u001b[39m(x, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m casted_key)\n\u001b[1;32m 3810\u001b[0m ):\n\u001b[1;32m 3811\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m InvalidIndexError(key)\n\u001b[0;32m-> 3812\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(key) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01merr\u001b[39;00m\n\u001b[1;32m 3813\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 3814\u001b[0m \u001b[38;5;66;03m# If we have a listlike key, _check_indexing_error will raise\u001b[39;00m\n\u001b[1;32m 3815\u001b[0m \u001b[38;5;66;03m# InvalidIndexError. Otherwise we fall through and re-raise\u001b[39;00m\n\u001b[1;32m 3816\u001b[0m \u001b[38;5;66;03m# the TypeError.\u001b[39;00m\n\u001b[1;32m 3817\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_indexing_error(key)\n",
|
||||
"\u001b[0;31mKeyError\u001b[0m: 'p_correct'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
||||
"from sklearn.metrics.pairwise import cosine_similarity\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"# Set the group number\n",
|
||||
"group_number = 1 # Change this to the desired group number\n",
|
||||
"\n",
|
||||
"# Load the CSV files from the specified group\n",
|
||||
"sdl_class_rdoc_path = f'0.class_document/{group_number}/sdl_class_rdoc.csv'\n",
|
||||
"test_path = f'0.class_document/{group_number}/test_p_c.csv'\n",
|
||||
"\n",
|
||||
"sdl_class_rdoc_csv = pd.read_csv(sdl_class_rdoc_path, low_memory=False)\n",
|
||||
"test_csv = pd.read_csv(test_path, low_memory=False)\n",
|
||||
"\n",
|
||||
"update_count = 0\n",
|
||||
"duplicate_count = 0\n",
|
||||
"non_duplicate_count = 0\n",
|
||||
"\n",
|
||||
"# Assign c_thing, c_property to p_thing, p_property and set p_MDM to True if conditions are met\n",
|
||||
"for index, row in test_csv.iterrows():\n",
|
||||
" if not row['p_correct'] and row['ctp_correct']:\n",
|
||||
" update_count += 1 # Increment the counter\n",
|
||||
"\n",
|
||||
" # Check for duplicates within the same ships_idx\n",
|
||||
" same_idx_rows = test_csv[(test_csv['ships_idx'] == row['ships_idx']) &\n",
|
||||
" (test_csv['p_thing'] == row['c_thing']) &\n",
|
||||
" (test_csv['p_property'] == row['c_property'])]\n",
|
||||
"\n",
|
||||
" if len(same_idx_rows) > 0:\n",
|
||||
" duplicate_count += 1\n",
|
||||
" else:\n",
|
||||
" non_duplicate_count += 1\n",
|
||||
"\n",
|
||||
"# Print the results\n",
|
||||
"print(f\"Total updates where p_correct is False and ctp_correct is True: {update_count}\")\n",
|
||||
"print(f\"Number of rows with duplicates in the same ships_idx: {duplicate_count}\")\n",
|
||||
"print(f\"Number of rows without duplicates in the same ships_idx: {non_duplicate_count}\")\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Number of updates made: 45\n",
|
||||
"Updated test CSV saved to 0.class_document/1/test_p_c_r.csv\n",
|
||||
"Refine CSV saved to refine.csv\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"update_count = 0\n",
|
||||
"\n",
|
||||
"# Initialize a list to hold rows that meet the conditions\n",
|
||||
"refine_rows = []\n",
|
||||
"\n",
|
||||
"# Assign c_thing, c_property to p_thing, p_property and set p_MDM to True if conditions are met\n",
|
||||
"for index, row in test_csv.iterrows():\n",
|
||||
" if (not row['p_MDM'] and row['c_score'] >= 0.9 and \n",
|
||||
" (row['p_thing'] != row['c_thing'] or row['p_property'] != row['c_property'])):\n",
|
||||
" test_csv.at[index, 'p_thing'] = row['c_thing']\n",
|
||||
" test_csv.at[index, 'p_property'] = row['c_property']\n",
|
||||
" test_csv.at[index, 'p_MDM'] = True\n",
|
||||
" update_count += 1 # Increment the counter\n",
|
||||
" refine_rows.append(row) # Add the row to the refine list\n",
|
||||
"\n",
|
||||
"# Convert the list of refine rows into a DataFrame\n",
|
||||
"refine_df = pd.DataFrame(refine_rows)\n",
|
||||
"\n",
|
||||
"# Save the refine DataFrame to a CSV file\n",
|
||||
"refine_output_path = f'refine.csv'\n",
|
||||
"refine_df.to_csv(refine_output_path, index=False, encoding='utf-8-sig')\n",
|
||||
"\n",
|
||||
"# Print the number of updates made\n",
|
||||
"print(f\"Number of updates made: {update_count}\")\n",
|
||||
"\n",
|
||||
"# Save the updated test CSV\n",
|
||||
"output_file_path = f'0.class_document/{group_number}/test_p_c_r.csv'\n",
|
||||
"test_csv.to_csv(output_file_path, index=False, encoding='utf-8-sig')\n",
|
||||
" \n",
|
||||
"print(f\"Updated test CSV saved to {output_file_path}\")\n",
|
||||
"print(f\"Refine CSV saved to {refine_output_path}\")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "torch",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
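The notebook above applies the c_* refinements by walking the test set row by row with iterrows. Below is a minimal vectorized sketch of the same refine step, assuming the same column names and that p_MDM is parsed as a boolean; the file path is illustrative only.

import pandas as pd

# Sketch only: vectorized form of the refine loop above, under the assumption that
# test_p_c.csv carries a boolean p_MDM column and a numeric c_score column.
test_csv = pd.read_csv('0.class_document/1/test_p_c.csv', low_memory=False)

mask = (
    ~test_csv['p_MDM']
    & (test_csv['c_score'] >= 0.9)
    & ((test_csv['p_thing'] != test_csv['c_thing']) | (test_csv['p_property'] != test_csv['c_property']))
)

refine_df = test_csv.loc[mask].copy()  # rows that get overwritten, kept for inspection
test_csv.loc[mask, ['p_thing', 'p_property']] = test_csv.loc[mask, ['c_thing', 'c_property']].values
test_csv.loc[mask, 'p_MDM'] = True
print(f"Number of updates made: {mask.sum()}")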
@ -0,0 +1,114 @@
|
|||
import pandas as pd
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
group_number = 1
|
||||
# Load the CSV files
|
||||
# test_path = f'post_process/tfidf_class/0.class_document/{group_number}/test_p_c.csv'  # unrefined predictions; superseded by the refined file below
|
||||
test_path = f'post_process/tfidf_class/0.class_document/{group_number}/test_p_c_r.csv'
|
||||
ship_data_list_reference_doc_file_path = f'post_process/tfidf_class/0.class_document/{group_number}/sdl_class_rdoc.csv'
|
||||
|
||||
test_csv = pd.read_csv(test_path, low_memory=False)
|
||||
sdl_rdoc = pd.read_csv(ship_data_list_reference_doc_file_path)
|
||||
|
||||
# Initialize new columns in test_csv
|
||||
test_csv['s_score'] = -1
|
||||
test_csv['s_thing'] = ''
|
||||
test_csv['s_property'] = ''
|
||||
test_csv['s_correct'] = False
|
||||
|
||||
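# Restrict scoring to rows already predicted as MDM; duplicate predictions among these are resolved below
|
||||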
duplicate_filtered = test_csv[(test_csv['p_MDM'] == True)].copy()
|
||||
|
||||
# Create a mapping from thing/property to reference_doc
|
||||
thing_property_to_reference_doc = sdl_rdoc.set_index(['thing', 'property'])['tag_description'].to_dict()
|
||||
|
||||
# Calculate s_score for duplicate rows
|
||||
for ships_idx, group in tqdm(duplicate_filtered.groupby('ships_idx'), desc="Processing duplicates"):
|
||||
for (p_thing, p_property), sub_group in group.groupby(['p_thing', 'p_property']):
|
||||
sub_group = sub_group.copy()
|
||||
tag_descriptions = sub_group['tag_description'].tolist()
|
||||
|
||||
# Get the reference document for the corresponding p_thing and p_property
|
||||
reference_doc = thing_property_to_reference_doc.get((p_thing, p_property), '')
|
||||
|
||||
if reference_doc:
|
||||
# Combine the tag_descriptions and the reference_doc for fit_transform
|
||||
combined_descriptions = tag_descriptions + [reference_doc]
|
||||
|
||||
# Create a new TF-IDF Vectorizer for this specific group
|
||||
vectorizer = TfidfVectorizer(
|
||||
token_pattern=r'\S+',
|
||||
norm='l2', # Use L2 normalization
|
||||
ngram_range=(1, 7), # Use n-grams from unigrams up to 7-grams
|
||||
)
|
||||
|
||||
# Fit and transform the combined descriptions
|
||||
tfidf_matrix = vectorizer.fit_transform(combined_descriptions)
|
||||
|
||||
# Separate the test_tfidf_matrix and reference_vector
|
||||
test_tfidf_matrix = tfidf_matrix[:-1] # All but the last one
|
||||
reference_vector = tfidf_matrix[-1] # The last one
|
||||
|
||||
# Calculate the cosine similarity between the test descriptions and the reference_doc
|
||||
sub_group['s_score'] = cosine_similarity(test_tfidf_matrix, reference_vector).flatten()
|
||||
else:
|
||||
sub_group['s_score'] = 0
|
||||
|
||||
# Write the s_score values back into duplicate_filtered (merged into test_csv later via update())
|
||||
duplicate_filtered.loc[sub_group.index, 's_score'] = sub_group['s_score']
|
||||
|
||||
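# At this point every candidate row in duplicate_filtered has an s_score; the loop below keeps
|
||||
# exactly one row per (ships_idx, p_thing, p_property) group, falling back to the first row
|
||||
# when no reference document was available (s_score stays -1).
|
||||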
for ships_idx, group in tqdm(duplicate_filtered.groupby('ships_idx'), desc="Processing duplicates"):
|
||||
for (p_thing, p_property), sub_group in group.groupby(['p_thing', 'p_property']):
|
||||
if (sub_group['s_score'] == -1).any():
|
||||
best_index = sub_group.index.min()
|
||||
else:
|
||||
# Find the index of the row with the highest s_score
|
||||
best_index = sub_group['s_score'].idxmax()
|
||||
row_position = sub_group.index.get_loc(best_index)
|
||||
|
||||
# Assign s_thing and s_property only to the row with the highest s_score
|
||||
duplicate_filtered.at[best_index, 's_thing'] = sub_group.at[best_index, 'p_thing']
|
||||
duplicate_filtered.at[best_index, 's_property'] = sub_group.at[best_index, 'p_property']
|
||||
|
||||
# Now, update the original test_csv with the changes made in duplicate_filtered
|
||||
test_csv.update(duplicate_filtered[['s_thing', 's_property', 's_score']])
|
||||
|
||||
# Calculate s_correct
|
||||
test_csv['s_correct'] = ((test_csv['thing'] == test_csv['s_thing']) &
|
||||
(test_csv['property'] == test_csv['s_property']) &
|
||||
(test_csv['MDM']))
|
||||
|
||||
# Calculate the percentage of correct s_thing and s_property
|
||||
mdm_true_count = test_csv['MDM'].sum()
|
||||
s_correct_count = test_csv['s_correct'].sum()
|
||||
s_correct_percentage = (s_correct_count / mdm_true_count) * 100
|
||||
|
||||
print(f"s_correct count: {s_correct_count}")
|
||||
print(f"MDM true count: {mdm_true_count}")
|
||||
print(f"s_correct percentage: {s_correct_percentage:.2f}%")
|
||||
|
||||
|
||||
# Save the updated DataFrame to a new CSV file
|
||||
output_path = f'post_process/0.result/{group_number}/test_s.csv'
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
test_csv.to_csv(output_path, index=False, encoding='utf-8-sig')
|
||||
|
||||
print(f"Updated data saved to {output_path}")
|
||||
|
||||
# Check for duplicates in s_thing and s_property within each ships_idx
|
||||
print("\nShips_idx with duplicate s_thing and s_property:")
|
||||
duplicate_ships_idx = []
|
||||
|
||||
for ships_idx, group in test_csv.groupby('ships_idx'):
|
||||
# Exclude rows with empty s_thing or s_property
|
||||
non_empty_group = group[(group['s_thing'] != '') & (group['s_property'] != '')]
|
||||
duplicate_entries = non_empty_group[non_empty_group.duplicated(subset=['s_thing', 's_property'], keep=False)]
|
||||
if not duplicate_entries.empty:
|
||||
duplicate_ships_idx.append(ships_idx)
|
||||
print(f"Ships_idx: {ships_idx}")
|
||||
print(duplicate_entries[['s_thing', 's_property']])
|
||||
|
||||
if not duplicate_ships_idx:
|
||||
print("No duplicates found.")
|
|
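The core of the s_score computation above is a per-group TF-IDF fit against a single reference document. The following self-contained sketch shows that scoring step in isolation; the tag descriptions and reference document are made up, while the vectorizer settings mirror the ones used above.

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Hypothetical candidate descriptions competing for one (p_thing, p_property) slot
tag_descriptions = ["FUEL OIL INLET TEMP", "F.O. INLET TEMPERATURE", "L.O. INLET PRESS"]
reference_doc = "FUEL OIL INLET TEMPERATURE"

vectorizer = TfidfVectorizer(token_pattern=r'\S+', norm='l2', ngram_range=(1, 7))
tfidf_matrix = vectorizer.fit_transform(tag_descriptions + [reference_doc])

scores = cosine_similarity(tfidf_matrix[:-1], tfidf_matrix[-1]).flatten()
best = scores.argmax()  # this row would receive s_thing / s_property
print(list(zip(tag_descriptions, scores.round(3))), "-> keep:", tag_descriptions[best])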
@ -0,0 +1,198 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Loaded data for group 1:\n",
|
||||
"Train data shape: (6125, 16)\n",
|
||||
"Valid data shape: (2042, 16)\n",
|
||||
"Test data shape: (14719, 15)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import os\n",
|
||||
"# Example usage:1\n",
|
||||
"group_number = 1 # You can change this to any group number you want to load (1, 2, 3, 4, or 5)\n",
|
||||
"\n",
|
||||
"# Select the mode for processing\n",
|
||||
"mode = 'tn_td_unit' # Change this to 'only_td', 'tn_td', etc., as needed\n",
|
||||
"\n",
|
||||
"def load_group_data(group_number):\n",
|
||||
" # Define the folder path based on the group number\n",
|
||||
" group_folder = os.path.join('../../data_preprocess/dataset', str(group_number))\n",
|
||||
" \n",
|
||||
" # Define file paths for train, valid, and test datasets\n",
|
||||
" train_file_path = os.path.join(group_folder, 'train.csv')\n",
|
||||
" valid_file_path = os.path.join(group_folder, 'valid.csv')\n",
|
||||
" test_file_path = os.path.join(group_folder, 'test.csv')\n",
|
||||
" \n",
|
||||
" # Check if the files exist\n",
|
||||
" if not os.path.exists(train_file_path) or not os.path.exists(valid_file_path) or not os.path.exists(test_file_path):\n",
|
||||
" raise FileNotFoundError(f\"One or more files for group {group_number} do not exist.\")\n",
|
||||
" \n",
|
||||
" # Load the CSV files into DataFrames\n",
|
||||
" train_data = pd.read_csv(train_file_path)\n",
|
||||
" valid_data = pd.read_csv(valid_file_path)\n",
|
||||
" test_data = pd.read_csv(test_file_path)\n",
|
||||
" \n",
|
||||
" return train_data, valid_data, test_data\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" train_data, valid_data, test_data = load_group_data(group_number)\n",
|
||||
" print(f\"Loaded data for group {group_number}:\")\n",
|
||||
" print(f\"Train data shape: {train_data.shape}\")\n",
|
||||
" print(f\"Valid data shape: {valid_data.shape}\")\n",
|
||||
" print(f\"Test data shape: {test_data.shape}\")\n",
|
||||
"except FileNotFoundError as e:\n",
|
||||
" print(e)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "313f98ef12eb442bac319282e5ffe5d6",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Saving the dataset (0/1 shards): 0%| | 0/6125 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "0c1834a4e7264a969085ad609320fdd6",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Saving the dataset (0/1 shards): 0%| | 0/14719 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "464f88daab334658aac93305ea6dac71",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Saving the dataset (0/1 shards): 0%| | 0/2042 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Dataset saved to 'combined_data'\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"from datasets import Dataset, DatasetDict\n",
|
||||
"\n",
|
||||
"# Function to process DataFrame based on mode\n",
|
||||
"def process_df(df, mode='only_td'):\n",
|
||||
" output_list = []\n",
|
||||
" for idx, row in df.iterrows():\n",
|
||||
" try:\n",
|
||||
" if mode == 'only_td':\n",
|
||||
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END>\"\n",
|
||||
" elif mode == 'tn_td':\n",
|
||||
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END>\"\n",
|
||||
" elif mode == 'tn_td_min_max':\n",
|
||||
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\"\n",
|
||||
" elif mode == 'td_min_max':\n",
|
||||
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\" \n",
|
||||
" elif mode == 'td_unit':\n",
|
||||
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\" \n",
|
||||
" elif mode == 'tn_td_unit':\n",
|
||||
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\" \n",
|
||||
" else:\n",
|
||||
" raise ValueError(\"Invalid mode specified\")\n",
|
||||
" \n",
|
||||
" output_list.append({\n",
|
||||
" 'translation': {\n",
|
||||
" 'ships_idx': row['ships_idx'],\n",
|
||||
" 'input': input_str,\n",
|
||||
" 'thing_property': f\"<THING_START>{str(row['thing'])}<THING_END><PROPERTY_START>{str(row['property'])}<PROPERTY_END>\",\n",
|
||||
" 'answer': f\"{str(row['thing'])} {str(row['property'])}\",\n",
|
||||
" }\n",
|
||||
" })\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Error processing row at index {idx}: {row}\")\n",
|
||||
" print(f\"Exception: {e}\")\n",
|
||||
" return output_list\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Combine the mode and group information into a single dictionary\n",
|
||||
"combined_dict = {\n",
|
||||
" \"mode\": mode,\n",
|
||||
" \"fold_group\": group_number\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Save the combined dictionary to a JSON file\n",
|
||||
"with open(\"mode.json\", \"w\") as json_file:\n",
|
||||
" json.dump(combined_dict, json_file)\n",
|
||||
" \n",
|
||||
"try:\n",
|
||||
" # Process the data and create a DatasetDict\n",
|
||||
" combined_data = DatasetDict({\n",
|
||||
" 'train': Dataset.from_list(process_df(train_data, mode=mode)),\n",
|
||||
" 'test': Dataset.from_list(process_df(test_data, mode=mode)),\n",
|
||||
" 'validation': Dataset.from_list(process_df(valid_data, mode=mode)),\n",
|
||||
" })\n",
|
||||
" # Save the DatasetDict to disk\n",
|
||||
" combined_data.save_to_disk(f\"combined_data/{mode}/{group_number}\")\n",
|
||||
" print(\"Dataset saved to 'combined_data'\")\n",
|
||||
"except Exception as e:\n",
|
||||
" print(f\"Error creating DatasetDict: {e}\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
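For reference, this is roughly what a single record produced by process_df in 'tn_td_unit' mode looks like; all field values below are invented purely for illustration.

# Illustration only: one 'tn_td_unit' record as emitted by process_df (values are made up).
row = {'ships_idx': 1001, 'tag_name': 'ME1 FO IN TEMP',
       'tag_description': 'M/E NO.1 FUEL OIL INLET TEMP', 'unit': 'degC',
       'thing': 'Engine1', 'property': 'FuelOilInletTemp'}

record = {
    'translation': {
        'ships_idx': row['ships_idx'],
        'input': (f"<TN_START>{row['tag_name']}<TN_END>"
                  f"<TD_START>{row['tag_description']}<TD_END>"
                  f"<UNIT_START>{row['unit']}<UNIT_END>"),
        'thing_property': (f"<THING_START>{row['thing']}<THING_END>"
                           f"<PROPERTY_START>{row['property']}<PROPERTY_END>"),
        'answer': f"{row['thing']} {row['property']}",
    }
}
print(record['translation']['input'])
# -> <TN_START>ME1 FO IN TEMP<TN_END><TD_START>M/E NO.1 FUEL OIL INLET TEMP<TD_END><UNIT_START>degC<UNIT_END>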
@ -0,0 +1,477 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# t5 training for combined concatenated outputs (thing + property) \n",
|
||||
"\n",
|
||||
"refer to `t5_train_tp.py` and `guide_for_tp.md` for faster training workflow"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The mode has been set to: tn_td_unit\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "d8d70681f4594917b7af4583a4237168",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Map: 0%| | 0/6125 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "106e0cefe50c40f0a83371693cf48cf7",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Map: 0%| | 0/14719 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "952f8ec73df0418490cb43beaaf5a7df",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Map: 0%| | 0/2042 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# import data and load dataset\n",
|
||||
"from datasets import load_from_disk\n",
|
||||
"import json\n",
|
||||
"from transformers import AutoTokenizer\n",
|
||||
"\n",
|
||||
"model_name = \"t5-base\"\n",
|
||||
"train_epochs = 80\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Read the mode from the JSON file\n",
|
||||
"with open(\"mode.json\", \"r\") as json_file:\n",
|
||||
" mode_dict = json.load(json_file)\n",
|
||||
"\n",
|
||||
"# Add the model key to the dictionary\n",
|
||||
"mode_dict[\"model\"] = model_name\n",
|
||||
"mode_dict[\"train_epochs\"] = train_epochs\n",
|
||||
"\n",
|
||||
"# Access the fold_group value\n",
|
||||
"fold_group = mode_dict.get(\"fold_group\")\n",
|
||||
"\n",
|
||||
"# Save the updated dictionary back to the JSON file\n",
|
||||
"with open(\"mode.json\", \"w\") as json_file:\n",
|
||||
" json.dump(mode_dict, json_file)\n",
|
||||
"\n",
|
||||
"# Set the mode variable from the JSON content\n",
|
||||
"mode = mode_dict.get(\"mode\", \"default_value\") # 'default_value' is a fallback if 'mode' is not found\n",
|
||||
"\n",
|
||||
"print(f\"The mode has been set to: {mode}\")\n",
|
||||
"\n",
|
||||
"# Path to saved combined_dataset\n",
|
||||
"file_path = f'combined_data/{mode}/{fold_group}'\n",
|
||||
"split_datasets = load_from_disk(file_path)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
||||
"# Define additional special tokens\n",
|
||||
"# additional_special_tokens = [\"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\"]\n",
|
||||
"additional_special_tokens = [\"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\", \"<TN_START>\", \"<TN_END>\", \"<TD_START>\", \"<TD_END>\", \"<MIN_START>\", \"<MIN_END>\", \"<MAX_START>\", \"<MAX_END>\", \"<UNIT_START>\", \"<UNIT_END>\"]\n",
|
||||
"# Add the additional special tokens to the tokenizer\n",
|
||||
"tokenizer.add_special_tokens({\"additional_special_tokens\": additional_special_tokens})\n",
|
||||
"\n",
|
||||
"max_length = 64\n",
|
||||
"\n",
|
||||
"def preprocess_function(examples):\n",
|
||||
" inputs = [ex[\"input\"] for ex in examples['translation']]\n",
|
||||
" targets = [ex[\"thing_property\"] for ex in examples['translation']]\n",
|
||||
" # text_target sets the corresponding label to inputs\n",
|
||||
" # there is no need to create a separate 'labels'\n",
|
||||
" model_inputs = tokenizer(\n",
|
||||
" inputs, text_target=targets, max_length=max_length, truncation=True\n",
|
||||
" )\n",
|
||||
" return model_inputs\n",
|
||||
"\n",
|
||||
"# map method maps preprocess_function to [train, valid, test] datasets of the datasetDict\n",
|
||||
"tokenized_datasets = split_datasets.map(\n",
|
||||
" preprocess_function,\n",
|
||||
" batched=True,\n",
|
||||
" remove_columns=split_datasets[\"train\"].column_names,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"\n",
|
||||
" <div>\n",
|
||||
" \n",
|
||||
" <progress value='3840' max='3840' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
||||
" [3840/3840 42:37, Epoch 80/80]\n",
|
||||
" </div>\n",
|
||||
" <table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: left;\">\n",
|
||||
" <th>Step</th>\n",
|
||||
" <th>Training Loss</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <td>500</td>\n",
|
||||
" <td>2.812300</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>1000</td>\n",
|
||||
" <td>0.699300</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>1500</td>\n",
|
||||
" <td>0.440900</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>2000</td>\n",
|
||||
" <td>0.332100</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>2500</td>\n",
|
||||
" <td>0.276500</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>3000</td>\n",
|
||||
" <td>0.245900</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>3500</td>\n",
|
||||
" <td>0.229300</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table><p>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
||||
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"TrainOutput(global_step=3840, training_loss=0.6754856963952383, metrics={'train_runtime': 2559.4201, 'train_samples_per_second': 191.45, 'train_steps_per_second': 1.5, 'total_flos': 3.156037495934976e+16, 'train_loss': 0.6754856963952383, 'epoch': 80.0})"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import os\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"# we use the pre-trained t5-base model\n",
|
||||
"from transformers import AutoModelForSeq2SeqLM\n",
|
||||
"model_checkpoint = model_name\n",
|
||||
"model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)\n",
|
||||
"\n",
|
||||
"# data collator\n",
|
||||
"from transformers import DataCollatorForSeq2Seq\n",
|
||||
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
|
||||
"\n",
|
||||
"# evaluation \n",
|
||||
"import evaluate\n",
|
||||
"metric = evaluate.load(\"sacrebleu\")\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def compute_metrics(eval_preds):\n",
|
||||
" preds, labels = eval_preds\n",
|
||||
" # In case the model returns more than the prediction logits\n",
|
||||
" if isinstance(preds, tuple):\n",
|
||||
" preds = preds[0]\n",
|
||||
"\n",
|
||||
" decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
|
||||
"\n",
|
||||
" # Replace -100s in the labels as we can't decode them\n",
|
||||
" labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
|
||||
" decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
|
||||
"\n",
|
||||
" # Some simple post-processing\n",
|
||||
" decoded_preds = [pred.strip() for pred in decoded_preds]\n",
|
||||
" decoded_labels = [[label.strip()] for label in decoded_labels]\n",
|
||||
"\n",
|
||||
" result = metric.compute(predictions=decoded_preds, references=decoded_labels)\n",
|
||||
" return {\"bleu\": result[\"score\"]}\n",
|
||||
"\n",
|
||||
"from transformers import Seq2SeqTrainingArguments\n",
|
||||
"\n",
|
||||
"# load environment variables to disable GPU p2p mode for multi-gpu training without p2p mode\n",
|
||||
"# not required for single-gpu training\n",
|
||||
"import os\n",
|
||||
"os.environ['NCCL_P2P_DISABLE'] = '1'\n",
|
||||
"os.environ['NCCL_IB_DISABLE'] = '1'\n",
|
||||
"\n",
|
||||
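"# Checkpoints are written once per epoch and only the most recent one is kept (save_total_limit=1)\n",
|
||||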
"args = Seq2SeqTrainingArguments(\n",
|
||||
" f\"train_{fold_group}_{model_name}_{mode}_{train_epochs}\",\n",
|
||||
" evaluation_strategy=\"no\",\n",
|
||||
" # logging_dir=\"tensorboard-log\",\n",
|
||||
" # logging_strategy=\"epoch\",\n",
|
||||
" save_strategy=\"epoch\",\n",
|
||||
" learning_rate=2e-5,\n",
|
||||
" per_device_train_batch_size=32,\n",
|
||||
" per_device_eval_batch_size=64,\n",
|
||||
" auto_find_batch_size=True,\n",
|
||||
" ddp_find_unused_parameters=False,\n",
|
||||
" weight_decay=0.01,\n",
|
||||
" save_total_limit=1,\n",
|
||||
" num_train_epochs=train_epochs,\n",
|
||||
" predict_with_generate=True,\n",
|
||||
" bf16=True,\n",
|
||||
" push_to_hub=False,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"from transformers import Seq2SeqTrainer\n",
|
||||
"\n",
|
||||
"trainer = Seq2SeqTrainer(\n",
|
||||
" model,\n",
|
||||
" args,\n",
|
||||
" train_dataset=tokenized_datasets[\"train\"],\n",
|
||||
" eval_dataset=tokenized_datasets[\"validation\"],\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Train the model\n",
|
||||
"trainer.train()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "base",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
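The tagged-input scheme used throughout relies on every marker encoding to a single token once registered as a special token. The following small sanity-check sketch is not part of the notebooks above, but can be run before training or inference to confirm this.

from transformers import AutoTokenizer

# Sketch: verify each marker maps to exactly one token id once registered as a special token.
tokenizer = AutoTokenizer.from_pretrained("t5-base")
markers = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>",
           "<TN_START>", "<TN_END>", "<TD_START>", "<TD_END>",
           "<MIN_START>", "<MIN_END>", "<MAX_START>", "<MAX_END>",
           "<UNIT_START>", "<UNIT_END>"]
tokenizer.add_special_tokens({"additional_special_tokens": markers})

for m in markers:
    ids = tokenizer(m, add_special_tokens=False)["input_ids"]
    assert len(ids) == 1, f"{m} was split into {ids}"
print("all markers encode to single token ids")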
@ -0,0 +1,447 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Goal: end to end inference and evaluation\n",
|
||||
"\n",
|
||||
"given a csv, make predictions and evaluate predictions, then return results in a csv"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The mode has been set to: tn_td_unit t5-base\n",
|
||||
"Using model checkpoint: train_1_t5-base_tn_td_unit_80/checkpoint-3840\n",
|
||||
"Columns in df_org:\n",
|
||||
"['thing', 'property', 'ships_idx', 'tag_name', 'tag_description', 'signal_type', 'min', 'max', 'unit', 'data_type', 'thing_pattern', 'property_pattern', 'pattern', 'MDM', 'org_tag_description']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import os\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"# Read the mode from the JSON file\n",
|
||||
"with open(\"mode.json\", \"r\") as json_file:\n",
|
||||
" mode_dict = json.load(json_file)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Set the mode variable from the JSON content\n",
|
||||
"mode = mode_dict.get(\"mode\", \"none\") # 'default_value' is a fallback if 'mode' is not found\n",
|
||||
"model_name = mode_dict.get(\"model\", \"none\") # 'default_value' is a fallback if 'mode' is not found\n",
|
||||
"train_epochs = mode_dict.get(\"train_epochs\", \"none\") # 'default_value' is a fallback if 'mode' is not found\n",
|
||||
"fold_group = mode_dict.get(\"fold_group\", \"none\") # 'default_value' is a fallback if 'mode' is not found\n",
|
||||
"\n",
|
||||
"print(f\"The mode has been set to: {mode} {model_name}\")\n",
|
||||
"\n",
|
||||
"# Define the base directory where checkpoints are stored\n",
|
||||
"base_dir = f\"train_{fold_group}_{model_name}_{mode}_{train_epochs}\"\n",
|
||||
"\n",
|
||||
"# List all subdirectories in the base directory\n",
|
||||
"subdirectories = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]\n",
|
||||
"\n",
|
||||
"# Filter for checkpoint directories that match the pattern \"checkpoint-\"\n",
|
||||
"checkpoints = [d for d in subdirectories if d.startswith(\"checkpoint-\")]\n",
|
||||
"\n",
|
||||
"# Select the latest checkpoint (the one with the highest number)\n",
|
||||
"if checkpoints:\n",
|
||||
" latest_checkpoint = checkpoints[0]\n",
|
||||
" model_checkpoint = os.path.join(base_dir, latest_checkpoint)\n",
|
||||
" print(f\"Using model checkpoint: {model_checkpoint}\")\n",
|
||||
"else:\n",
|
||||
" print(\"No checkpoints were found.\")\n",
|
||||
" model_checkpoint = None # Handle this case as needed\n",
|
||||
"\n",
|
||||
"# Load the data\n",
|
||||
"data_path = f\"../../data_preprocess/dataset/{fold_group}/test.csv\" # Adjust the CSV file path as necessary\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" df = pd.read_csv(data_path)\n",
|
||||
"except UnicodeDecodeError:\n",
|
||||
" df = pd.read_csv(data_path, encoding='ISO-8859-1')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Drop rows where 'tag_description' is NaN and reset the index\n",
|
||||
"df = df.dropna(subset=['tag_description']).reset_index(drop=True)\n",
|
||||
"\n",
|
||||
"# Preserve df_org\n",
|
||||
"df_org = df.copy()\n",
|
||||
"\n",
|
||||
"# Print the column names of df_org\n",
|
||||
"print(\"Columns in df_org:\")\n",
|
||||
"print(df_org.columns.tolist())\n",
|
||||
"\n",
|
||||
"selected_columns = ['thing', 'property', 'tag_description', 'min', 'max', 'MDM', 'pattern']\n",
|
||||
"df[selected_columns] = df[selected_columns].astype(\"string\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The test_dataset contains 14718 items.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from datasets import Dataset\n",
|
||||
"\n",
|
||||
"def process_df(df, mode='only_td'):\n",
|
||||
" output_list = []\n",
|
||||
" for _, row in df.iterrows():\n",
|
||||
" try:\n",
|
||||
" if mode == 'only_td':\n",
|
||||
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END>\"\n",
|
||||
" elif mode == 'tn_td':\n",
|
||||
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END>\"\n",
|
||||
" elif mode == 'tn_td_min_max':\n",
|
||||
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\"\n",
|
||||
" elif mode == 'td_min_max':\n",
|
||||
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\" \n",
|
||||
" elif mode == 'td_unit':\n",
|
||||
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\" \n",
|
||||
" elif mode == 'tn_td_unit':\n",
|
||||
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\" \n",
|
||||
" else:\n",
|
||||
" raise ValueError(\"Invalid mode specified\")\n",
|
||||
"\n",
|
||||
" output_list.append({\n",
|
||||
" 'translation': {\n",
|
||||
" 'ships_idx': row['ships_idx'],\n",
|
||||
" 'input': input_str,\n",
|
||||
" 'thing_property': f\"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>\",\n",
|
||||
" 'answer_thing': f\"{row['thing']}\",\n",
|
||||
" 'answer_property': f\"{row['property']}\",\n",
|
||||
" 'MDM': f\"{row['MDM']}\",\n",
|
||||
" }\n",
|
||||
" })\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"Error processing row: {row}\")\n",
|
||||
" print(f\"Exception: {e}\")\n",
|
||||
" return output_list\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Process the DataFrame\n",
|
||||
"processed_data = process_df(df, mode=mode)\n",
|
||||
"\n",
|
||||
"# Create a Dataset object\n",
|
||||
"test_dataset = Dataset.from_list(processed_data)\n",
|
||||
"\n",
|
||||
"# Print the number of items in the dataset\n",
|
||||
"print(f\"The test_dataset contains {len(test_dataset)} items.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers.pipelines.pt_utils import KeyDataset\n",
|
||||
"from transformers import pipeline\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"import os\n",
|
||||
"from transformers import AutoTokenizer\n",
|
||||
"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name, return_tensors=\"pt\")\n",
|
||||
"# Define additional special tokens\n",
|
||||
"# additional_special_tokens = [\"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\"]\n",
|
||||
"additional_special_tokens = [\"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\", \"<TN_START>\", \"<TN_END>\", \"<TD_START>\", \"<TD_END>\", \"<MIN_START>\", \"<MIN_END>\", \"<MAX_START>\", \"<MAX_END>\", \"<UNIT_START>\", \"<UNIT_END>\"]\n",
|
||||
"\n",
|
||||
"# Add the additional special tokens to the tokenizer\n",
|
||||
"tokenizer.add_special_tokens({\"additional_special_tokens\": additional_special_tokens})\n",
|
||||
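"# Note: t5-base's tokenizer has 32100 base tokens, so the tokens added above are assumed to receive ids 32100, 32101, ... in this order (<THING_START> -> 32100, <THING_END> -> 32101, ...), which the hard-coded ids in process_tensor_output below rely on\n",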
"# tokenizer.add_special_tokens({'sep_token': \"<SEP>\"})\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"pipe = pipeline(\"translation_XX_to_YY\", model=model_checkpoint, tokenizer=tokenizer, return_tensors=True, max_length=128, device=0)\n",
|
||||
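"# With return_tensors=True the pipeline yields raw token ids under 'translation_token_ids' instead of decoded text; decoding is done manually in process_tensor_output\n",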
"\n",
|
||||
"# check what token-ids the special tokens are\n",
|
||||
"# tokenizer.encode(\"<THING_START><THING_END><PROPERTY_START><PROPERTY_END>\")\n",
|
||||
"\n",
|
||||
"def extract_seq(tokens, start_value, end_value):\n",
|
||||
" if start_value not in tokens or end_value not in tokens:\n",
|
||||
" return None # Or handle this case according to your requirements\n",
|
||||
" start_id = tokens.index(start_value)\n",
|
||||
" end_id = tokens.index(end_value)\n",
|
||||
"\n",
|
||||
" return tokens[start_id+1:end_id]\n",
|
||||
"\n",
|
||||
"# problem, what if end tokens are not in?\n",
|
||||
"def process_tensor_output(output):\n",
|
||||
" tokens = output[0]['translation_token_ids'].tolist()\n",
|
||||
" thing_seq = extract_seq(tokens, 32100, 32101) # 32100 = <THING_START>, 32101 = <THING_END>\n",
|
||||
" property_seq = extract_seq(tokens, 32102, 32103) # 32102 = <PROPERTY_START>, 32103 = <PROPERTY_END>\n",
|
||||
" p_thing = None\n",
|
||||
" p_property = None\n",
|
||||
" if (thing_seq is not None):\n",
|
||||
" p_thing = tokenizer.decode(thing_seq)\n",
|
||||
" if (property_seq is not None):\n",
|
||||
" p_property = tokenizer.decode(property_seq)\n",
|
||||
" return p_thing, p_property"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"making inference on test set\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"14718it [00:44, 330.24it/s] "
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"inference done\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"p_thing_list = []\n",
|
||||
"p_property_list = []\n",
|
||||
"print(\"making inference on test set\")\n",
|
||||
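"# KeyDataset feeds only the 'input' field of each example to the pipeline; batch_size=256 controls how many inputs are batched per forward pass\n",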
"for out in tqdm(pipe(KeyDataset(test_dataset[\"translation\"], \"input\"), batch_size=256)):\n",
|
||||
" p_thing, p_property = process_tensor_output(out)\n",
|
||||
" p_thing_list.append(p_thing)\n",
|
||||
" p_property_list.append(p_property)\n",
|
||||
"print(\"inference done\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Thing prediction accuracy: 0.9895314057826521\n",
|
||||
"Correct thing predictions: 1985, Incorrect thing predictions: 21\n",
|
||||
"Property prediction accuracy: 0.9661016949152542\n",
|
||||
"Correct property predictions: 1938, Incorrect property predictions: 12780\n",
|
||||
"total accuracy: 0.9596211365902293\n",
|
||||
"Correct total predictions: 1925, Incorrect total predictions: 81\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"answer_thing = [item['answer_thing'] for item in test_dataset[\"translation\"]]\n",
|
||||
"answer_property = [item['answer_property'] for item in test_dataset[\"translation\"]]\n",
|
||||
"mdm_list = [item['MDM'] for item in test_dataset[\"translation\"]]\n",
|
||||
"\n",
|
||||
"mdm_count = 0\n",
|
||||
"for i in range(len(mdm_list)):\n",
|
||||
" if(mdm_list[i] == \"True\"):mdm_count = mdm_count + 1 \n",
|
||||
"\n",
|
||||
"def correctness_test(input, reference, mdm_list):\n",
|
||||
" assert(len(input) == len(reference))\n",
|
||||
" correctness_list = []\n",
|
||||
" for i in range(len(input)):\n",
|
||||
" if(mdm_list[i] == \"True\"):\n",
|
||||
" correctness_list.append(input[i] == reference[i])\n",
|
||||
" else:correctness_list.append(False)\n",
|
||||
" return correctness_list\n",
|
||||
"\n",
|
||||
"# Compare with answer to evaluate correctness\n",
|
||||
"thing_correctness = correctness_test(p_thing_list, answer_thing, mdm_list)\n",
|
||||
"property_correctness = correctness_test(p_property_list, answer_property, mdm_list)\n",
|
||||
"\n",
|
||||
"correctness_mdm = []\n",
|
||||
"for i in range(len(mdm_list)):\n",
|
||||
" if(thing_correctness[i] & property_correctness[i]):\n",
|
||||
" correctness_mdm.append(True)\n",
|
||||
" else: \n",
|
||||
" correctness_mdm.append(False)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
"# Calculate accuracy\n",
|
||||
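"# Accuracies are computed over MDM rows only: mdm_count is the denominator, and correctness_test already marks non-MDM rows as False\n",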
"thing_accuracy = sum(thing_correctness) / mdm_count\n",
|
||||
"property_accuracy = sum(property_correctness) / mdm_count\n",
|
||||
"total_accuracy = sum(correctness_mdm) / mdm_count\n",
|
||||
"\n",
|
||||
"# Count True/False values\n",
|
||||
"thing_true_count = thing_correctness.count(True)\n",
|
||||
"thing_false_count = 0\n",
|
||||
"for i in range(len(thing_correctness)):\n",
|
||||
" if mdm_list[i] == \"True\" and thing_correctness[i] == False:\n",
|
||||
" thing_false_count += 1\n",
|
||||
"\n",
|
||||
"property_true_count = property_correctness.count(True)\n",
|
||||
"property_false_count = property_correctness.count(False)\n",
|
||||
"total_true_count = correctness_mdm.count(True)\n",
|
||||
"total_false_count = mdm_count - correctness_mdm.count(True)\n",
|
||||
"\n",
|
||||
"# Print results\n",
|
||||
"print(\"Thing prediction accuracy:\", thing_accuracy)\n",
|
||||
"print(f\"Correct thing predictions: {thing_true_count}, Incorrect thing predictions: {thing_false_count}\")\n",
|
||||
"print(\"Property prediction accuracy:\", property_accuracy)\n",
|
||||
"print(f\"Correct property predictions: {property_true_count}, Incorrect property predictions: {property_false_count}\")\n",
|
||||
"print(\"total accuracy:\", total_accuracy)\n",
|
||||
"print(f\"Correct total predictions: {total_true_count}, Incorrect total predictions: {total_false_count}\")\n",
|
||||
"\n",
|
||||
"# Create a DataFrame with the results\n",
|
||||
"dict = {\n",
|
||||
" 'p_thing': p_thing_list,\n",
|
||||
" 'p_property': p_property_list,\n",
|
||||
" 'p_thing_correct': thing_correctness,\n",
|
||||
" 'p_property_correct': property_correctness\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"df_pred = pd.DataFrame(dict)\n",
|
||||
"\n",
|
||||
"# Read the mode from the JSON file\n",
|
||||
"with open(\"mode.json\", \"r\") as json_file:\n",
|
||||
" mode_dict = json.load(json_file)\n",
|
||||
"\n",
|
||||
"# Add the model key to the dictionary\n",
|
||||
"mode_dict[\"model\"] = model_name\n",
|
||||
"mode_dict[\"train_epochs\"] = train_epochs\n",
|
||||
"\n",
|
||||
"# Save the updated dictionary back to the JSON file\n",
|
||||
"with open(\"mode.json\", \"w\") as json_file:\n",
|
||||
" json.dump(mode_dict, json_file)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Check if the file exists and is not empty\n",
|
||||
"if os.path.exists(\"results.json\") and os.path.getsize(\"results.json\") > 0:\n",
|
||||
" # Read the existing results.json file\n",
|
||||
" with open(\"results.json\", \"r\") as json_file:\n",
|
||||
" try:\n",
|
||||
" results_dict = json.load(json_file)\n",
|
||||
" except json.JSONDecodeError:\n",
|
||||
" results_dict = {}\n",
|
||||
"else:\n",
|
||||
" results_dict = {}\n",
|
||||
"\n",
|
||||
"# Add the new model_checkpoint key with the accuracy values as an object\n",
|
||||
"\n",
|
||||
"model_key = model_checkpoint \n",
|
||||
"\n",
|
||||
"results_dict[model_key] = {\n",
|
||||
" \"thing_accuracy\": thing_accuracy,\n",
|
||||
" \"thing_true\": thing_true_count,\n",
|
||||
" \"thing_false\": thing_false_count,\n",
|
||||
" \"property_accuracy\": property_accuracy,\n",
|
||||
" \"property_true\": property_true_count,\n",
|
||||
" \"property_false\": property_false_count,\n",
|
||||
" \"total_accuracy\": total_accuracy,\n",
|
||||
" \"total_true\": total_true_count,\n",
|
||||
" \"total_false\": total_false_count \n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Save the updated dictionary back to the results.json file\n",
|
||||
"with open(\"results.json\", \"w\") as json_file:\n",
|
||||
" json.dump(results_dict, json_file, indent=4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Updated data saved to ../0.result/1/test_p.csv\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"# Create a DataFrame with the results\n",
|
||||
"df_pred = pd.DataFrame({\n",
|
||||
" 'p_thing': p_thing_list,\n",
|
||||
" 'p_property': p_property_list,\n",
|
||||
" 'p_thing_correct': thing_correctness,\n",
|
||||
" 'p_property_correct': property_correctness,\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
"# Merge predictions with the original DataFrame (df_org)\n",
|
||||
"df_org['p_thing'] = df_pred['p_thing']\n",
|
||||
"df_org['p_property'] = df_pred['p_property']\n",
|
||||
"df_org['p_thing_correct'] = df_pred['p_thing_correct']\n",
|
||||
"df_org['p_property_correct'] = df_pred['p_property_correct']\n",
|
||||
"df_org['p_correct'] = df_pred['p_thing_correct'] & df_org['p_property_correct']\n",
|
||||
"\n",
|
||||
"df_master = pd.read_csv('../../data_import/data_model_master_export.csv')\n",
|
||||
"\n",
|
||||
"df_org['pattern'] = df_org['thing'].str.replace(r'\\d', '#', regex=True) + \" \" + df_org['property'].str.replace(r'\\d', '#', regex=True)\n",
|
||||
"df_org['p_pattern'] = df_org['p_thing'].str.replace(r'\\d', '#', regex=True) + \" \" + df_org['p_property'].str.replace(r'\\d', '#', regex=True)\n",
|
||||
"df_master['master_pattern'] = df_master['thing'] + \" \" + df_master['property']\n",
|
||||
"\n",
|
||||
"# Create a set of unique patterns from master for fast lookup\n",
|
||||
"master_patterns = set(df_master['master_pattern'])\n",
|
||||
"df_org['p_MDM'] = df_org['p_pattern'].apply(lambda x: x in master_patterns)\n",
|
||||
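"# p_MDM marks whether the predicted thing/property pattern exists in the master data model, mirroring how the MDM flag was derived during preprocessing\n",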
"\n",
|
||||
"\n",
|
||||
"output_path = f\"../0.result/{fold_group}/test_p.csv\"\n",
|
||||
"debug_output_path = f\"0.dresult/{fold_group}/test_p.csv\"\n",
|
||||
"\n",
|
||||
"# 폴더가 없으면 생성\n",
|
||||
"os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
|
||||
"df_org.to_csv(output_path, index=False, encoding='utf-8-sig')\n",
|
||||
"\n",
|
||||
"os.makedirs(os.path.dirname(debug_output_path), exist_ok=True)\n",
|
||||
"df_org.to_csv(debug_output_path, index=False, encoding='utf-8-sig')\n",
|
||||
"\n",
|
||||
"print(f\"Updated data saved to {output_path}\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|