hipom_data_mapping/data_preprocess/split_data.ipynb

442 lines
19 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Final Group Allocation:\n",
"Group 1: Ships_idx = [1025, 1032, 1042, 1046, 1023, 1037, 1024, 1014, 1019, 1008], PD type = 529, PD = 1992, SD = 9855\n",
"Group 2: Ships_idx = [1003, 1028, 1018, 1020, 1033, 1050, 1030, 1051, 1004, 1036], PD type = 528, PD = 2113, SD = 13074\n",
"Group 3: Ships_idx = [1016, 1026, 1043, 1031, 1012, 1021, 1000, 1011, 1006, 1005, 1038], PD type = 521, PD = 2140, SD = 10722\n",
"Group 4: Ships_idx = [1047, 1049, 1010, 1027, 1013, 1022, 1048, 1017, 1045, 1007], PD type = 521, PD = 2102, SD = 15451\n",
"Group 5: Ships_idx = [1039, 1035, 1044, 1009, 1015, 1040, 1001, 1034, 1041, 1002, 1029], PD type = 500, PD = 2183, SD = 12969\n"
]
}
],
"source": [
"import pandas as pd\n",
"from collections import defaultdict\n",
"\n",
"def calculate_ship_count(group):\n",
"    \"\"\"Summarise each ship in `group`.\n",
"\n",
"    Returns a DataFrame with one row per `ships_idx` and the columns\n",
"    ['ships_idx', 'comb_count', 'total_count']: the number of distinct\n",
"    `thing_property` values and the number of rows recorded for that ship.\n",
"    \"\"\"\n",
"    per_ship = group.groupby('ships_idx')['thing_property'].agg(\n",
"        comb_count='nunique',\n",
"        total_count='size',\n",
"    )\n",
"    return per_ship.reset_index()\n",
"\n",
"def calculate_group_count(group):\n",
"    \"\"\"Return ``(comb_count, total_count)`` for the rows in `group`.\n",
"\n",
"    comb_count  -- number of distinct `thing_property` values (NaN is not counted)\n",
"    total_count -- number of rows\n",
"    \"\"\"\n",
"    thing_property = group['thing_property']\n",
"    return thing_property.nunique(), thing_property.size\n",
"\n",
"def _group_counts_before_after(groups, g, ship_idx, mdm):\n",
"    \"\"\"Return group `g`'s ``(comb_count, total_count)`` without and with `ship_idx`.\n",
"\n",
"    Both pairs come from `calculate_group_count` over the rows of `mdm` whose\n",
"    `ships_idx` belongs to the group. Only the group's own ship list is copied\n",
"    (not the whole `groups` mapping). If `ship_idx` is already in the group the\n",
"    two pairs are equal.\n",
"    \"\"\"\n",
"    current_ships = list(groups[g])\n",
"    before = calculate_group_count(mdm[mdm['ships_idx'].isin(current_ships)])\n",
"    after = calculate_group_count(mdm[mdm['ships_idx'].isin(current_ships + [ship_idx])])\n",
"    return before, after\n",
"\n",
"\n",
"def calculate_comb_count_increase(groups, g, ship_idx, mdm):\n",
"    \"\"\"Number of new distinct `thing_property` values group `g` gains if `ship_idx` joins it.\"\"\"\n",
"    (comb_before, _), (comb_after, _) = _group_counts_before_after(groups, g, ship_idx, mdm)\n",
"    return comb_after - comb_before\n",
"\n",
"\n",
"def calculate_total_count_increase(groups, g, ship_idx, mdm):\n",
"    \"\"\"Number of extra `mdm` rows group `g` gains if `ship_idx` joins it.\"\"\"\n",
"    (_, total_before), (_, total_after) = _group_counts_before_after(groups, g, ship_idx, mdm)\n",
"    return total_after - total_before\n",
"\n",
"def find_closest_total_count_ship(groups, g, remaining_ships, mdm, target_total_count):\n",
"    \"\"\"Choose the remaining ship that brings group `g`'s row count closest to a target.\n",
"\n",
"    Returns ``(ship_idx, increase)``, where `increase` is the number of rows the ship\n",
"    adds, or ``(None, 0)`` when `remaining_ships` is empty. Ties go to the ship that\n",
"    appears first in `remaining_ships`.\n",
"    \"\"\"\n",
"    _, base_total = calculate_group_count(mdm[mdm['ships_idx'].isin(groups[g])])\n",
"\n",
"    scored = []\n",
"    for candidate in remaining_ships:\n",
"        gain = calculate_total_count_increase(groups, g, candidate, mdm)\n",
"        gap = abs(target_total_count - (base_total + gain))\n",
"        scored.append((candidate, gap, gain))\n",
"\n",
"    if not scored:\n",
"        return None, 0\n",
"\n",
"    best_ship, _, best_gain = min(scored, key=lambda entry: entry[1])\n",
"    return best_ship, best_gain\n",
"\n",
"def find_max_increase_ship(groups, g, remaining_ships, mdm):\n",
"    \"\"\"Choose the remaining ship that adds the most new `thing_property` values to group `g`.\n",
"\n",
"    Returns ``(ship_idx, increase)``; ties go to the ship that appears first in\n",
"    `remaining_ships`. Returns ``(None, 0)`` when `remaining_ships` is empty, like the\n",
"    find_closest_* helpers, instead of letting ``max()`` raise ValueError.\n",
"    \"\"\"\n",
"    comb_count_increase = [\n",
"        (ship_idx, calculate_comb_count_increase(groups, g, ship_idx, mdm))\n",
"        for ship_idx in remaining_ships\n",
"    ]\n",
"\n",
"    if not comb_count_increase:\n",
"        return None, 0\n",
"\n",
"    selected_ship_idx, max_increase = max(comb_count_increase, key=lambda x: x[1])\n",
"    return selected_ship_idx, max_increase\n",
"\n",
"def find_closest_comb_count_ship(groups, g, remaining_ships, mdm, target_comb_count):\n",
"    \"\"\"Choose the remaining ship that brings group `g`'s distinct-value count closest to a target.\n",
"\n",
"    Returns ``(ship_idx, increase)``, where `increase` is the number of new distinct\n",
"    `thing_property` values the ship adds, or ``(None, 0)`` when `remaining_ships`\n",
"    is empty. Ties go to the ship that appears first in `remaining_ships`.\n",
"    \"\"\"\n",
"    base_comb, _ = calculate_group_count(mdm[mdm['ships_idx'].isin(groups[g])])\n",
"\n",
"    scored = []\n",
"    for candidate in remaining_ships:\n",
"        gain = calculate_comb_count_increase(groups, g, candidate, mdm)\n",
"        gap = abs(target_comb_count - (base_comb + gain))\n",
"        scored.append((candidate, gap, gain))\n",
"\n",
"    best = min(scored, key=lambda entry: entry[1], default=None)\n",
"    if best is None:\n",
"        return None, 0\n",
"\n",
"    best_ship, _, best_gain = best\n",
"    return best_ship, best_gain\n",
"\n",
"def _find_group_with_max(groups, mdm, count_position):\n",
"    \"\"\"Return ``(group index, count)`` for the group with the largest chosen count.\n",
"\n",
"    `count_position` picks the element of ``calculate_group_count``'s\n",
"    ``(comb_count, total_count)`` result to compare: 0 or 1. Groups are scanned\n",
"    in index order ``0..len(groups) - 1``; the first strict maximum wins, and\n",
"    ``(-1, -1)`` is returned when `groups` is empty.\n",
"    \"\"\"\n",
"    best_group, best_count = -1, -1\n",
"    for g in range(len(groups)):\n",
"        group_data = mdm[mdm['ships_idx'].isin(groups[g])]\n",
"        count = calculate_group_count(group_data)[count_position]\n",
"        if count > best_count:\n",
"            best_group, best_count = g, count\n",
"    return best_group, best_count\n",
"\n",
"\n",
"def find_group_with_max_comb_count(groups, mdm):\n",
"    \"\"\"Return ``(group index, comb_count)`` of the group with the most distinct `thing_property` values.\"\"\"\n",
"    return _find_group_with_max(groups, mdm, count_position=0)\n",
"\n",
"\n",
"def find_group_with_max_total_count(groups, mdm):\n",
"    \"\"\"Return ``(group index, total_count)`` of the group with the most `mdm` rows.\"\"\"\n",
"    return _find_group_with_max(groups, mdm, count_position=1)\n",
"\n",
"# Load the CSV file (`data` itself is kept unchanged; later cells reuse it)\n",
"data_file_path = 'preprocessed_data.csv'\n",
"data = pd.read_csv(data_file_path)\n",
"\n",
"# Keep only the rows flagged MDM; .copy() creates an explicit copy so the column\n",
"# added below is set on an independent frame rather than on a slice of `data`\n",
"mdm_true = data[data['MDM'] == True].copy()\n",
"mdm_all = data.copy()\n",
"\n",
"# Combine 'thing' and 'property' into the key that all group counts are based on\n",
"mdm_true.loc[:, 'thing_property'] = mdm_true['thing'] + '_' + mdm_true['property']\n",
"mdm_all.loc[:, 'thing_property'] = mdm_all['thing'] + '_' + mdm_all['property']\n",
"\n",
"# Initial setup: per-ship counts over the MDM rows, and one ship list per group\n",
"ship_count = calculate_ship_count(mdm_true)\n",
"num_groups = 5\n",
"groups = defaultdict(list)\n",
"\n",
"# Ships with the most distinct thing_property values come first\n",
"sorted_ships = ship_count.sort_values(by='comb_count', ascending=False)\n",
"\n",
"# Seed every group with one of the top `num_groups` ships\n",
"for seed_rank in range(num_groups):\n",
"    seed_row = sorted_ships.iloc[seed_rank]\n",
"    groups[seed_rank].append(seed_row['ships_idx'])\n",
"\n",
"# The rest (a numpy array, still in descending comb_count order) is allocated below\n",
"remaining_ships = sorted_ships.iloc[num_groups:]['ships_idx'].values\n",
"\n",
"# Allocate the remaining ships round by round. Each round visits the groups in\n",
"# ascending order of comb count (distinct thing_property values over MDM rows):\n",
"#   * the first (poorest) group takes the ship that adds the most new values;\n",
"#   * every other group takes the ship whose addition brings its comb count\n",
"#     closest to the current maximum comb count over all groups.\n",
"# A group whose pick would add no new values is set aside and, at the end of the\n",
"# round, takes the ship that brings its row count closest to the current maximum\n",
"# row count over all groups.\n",
"while len(remaining_ships) > 0:\n",
"    group_comb_counts = []\n",
"    for g in range(num_groups):\n",
"        group_data = mdm_true[mdm_true['ships_idx'].isin(groups[g])]\n",
"        comb_count, _ = calculate_group_count(group_data)\n",
"        group_comb_counts.append((g, comb_count))\n",
"\n",
"    group_comb_counts.sort(key=lambda x: x[1])\n",
"\n",
"    stalled_groups = []\n",
"    for rank, (g, _) in enumerate(group_comb_counts):\n",
"        if len(remaining_ships) == 0:\n",
"            break\n",
"\n",
"        # rank 0 is the group with the fewest distinct values this round\n",
"        if rank == 0:\n",
"            selected_ship_idx, comb_increase = find_max_increase_ship(groups, g, remaining_ships, mdm_true)\n",
"        else:\n",
"            _, max_comb_count = find_group_with_max_comb_count(groups, mdm_true)\n",
"            selected_ship_idx, comb_increase = find_closest_comb_count_ship(groups, g, remaining_ships, mdm_true, max_comb_count)\n",
"\n",
"        if comb_increase == 0:\n",
"            stalled_groups.append(g)\n",
"        else:\n",
"            groups[g].append(selected_ship_idx)\n",
"            remaining_ships = remaining_ships[remaining_ships != selected_ship_idx]\n",
"\n",
"    for g in stalled_groups:\n",
"        if len(remaining_ships) == 0:\n",
"            break\n",
"        _, max_total_count = find_group_with_max_total_count(groups, mdm_true)\n",
"        selected_ship_idx, _ = find_closest_total_count_ship(groups, g, remaining_ships, mdm_true, max_total_count)\n",
"        if selected_ship_idx is not None:\n",
"            groups[g].append(selected_ship_idx)\n",
"            remaining_ships = remaining_ships[remaining_ships != selected_ship_idx]\n",
"\n",
"# Summarise each group as (group index, comb count over MDM rows,\n",
"# row count over all rows of its ships including MDM=False)\n",
"group_comb_counts = []\n",
"for g in range(num_groups):\n",
"    members = groups[g]\n",
"    comb_count, total_count = calculate_group_count(mdm_true[mdm_true['ships_idx'].isin(members)])\n",
"    _, total_count_all = calculate_group_count(mdm_all[mdm_all['ships_idx'].isin(members)])\n",
"    group_comb_counts.append((g, comb_count, total_count_all))\n",
"\n",
"# Largest comb count first; list.sort stays stable with reverse=True,\n",
"# so tied groups keep their original index order\n",
"group_comb_counts.sort(key=lambda x: x[1], reverse=True)\n",
"\n",
"# Renumber: sorted_groups[0] is the group with the largest comb count\n",
"sorted_groups = defaultdict(list)\n",
"for new_index, summary in enumerate(group_comb_counts):\n",
"    sorted_groups[new_index] = groups[summary[0]]\n",
"\n",
"# Report the final allocation, largest comb count first\n",
"print(\"Final Group Allocation:\")\n",
"for g in range(num_groups):\n",
"    group_ships = sorted_groups[g]\n",
"    in_group_true = mdm_true['ships_idx'].isin(group_ships)\n",
"    in_group_all = mdm_all['ships_idx'].isin(group_ships)\n",
"\n",
"    # PD type / PD: distinct thing_property values and row count among MDM=True rows\n",
"    comb_count, total_count = calculate_group_count(mdm_true[in_group_true])\n",
"    # SD: row count of the same ships regardless of MDM\n",
"    total_count_all = calculate_group_count(mdm_all[in_group_all])[1]\n",
"\n",
"    print(f\"Group {g + 1}: Ships_idx = {group_ships}, PD type = {comb_count}, PD = {total_count}, SD = {total_count_all}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CSV file has been generated: 'combined_group_allocation.csv'\n"
]
}
],
"source": [
"import pandas as pd\n",
"from sklearn.model_selection import GroupKFold\n",
"\n",
"# Summarise the custom group allocation (BGKF): one record per group\n",
"custom_results = []\n",
"\n",
"for g in range(num_groups):\n",
"    group_ships = groups[g]\n",
"    group_data_true = mdm_true[mdm_true['ships_idx'].isin(group_ships)]\n",
"    comb_count, total_count = calculate_group_count(group_data_true)\n",
"\n",
"    # Total count including MDM=False rows of the same ships\n",
"    group_data_all = mdm_all[mdm_all['ships_idx'].isin(group_ships)]\n",
"    _, total_count_all = calculate_group_count(group_data_all)\n",
"\n",
"    custom_results.append({\n",
"        'Group': g + 1,\n",
"        'Allocation': 'BGKF',\n",
"        'Comb_count': comb_count,\n",
"        'Total_count': total_count,\n",
"        'Total_count_all': total_count_all,\n",
"        'Ship_count': len(group_ships),\n",
"        'Ships_idx': list(group_ships)\n",
"    })\n",
"\n",
"# Order by comb_count (descending) and renumber the groups in that order\n",
"custom_results.sort(key=lambda x: x['Comb_count'], reverse=True)\n",
"for i, result in enumerate(custom_results):\n",
"    result['Group'] = i + 1\n",
"\n",
"# Baseline: sklearn GroupKFold allocation (GKF) over the MDM rows, grouped by ship,\n",
"# with the same number of folds as the custom allocation\n",
"gkf = GroupKFold(n_splits=num_groups)\n",
"gkf_results = []\n",
"\n",
"for i, (_train_idx, test_idx) in enumerate(gkf.split(mdm_true, groups=mdm_true['ships_idx'])):\n",
"    test_group = mdm_true.iloc[test_idx]\n",
"    comb_count, total_count = calculate_group_count(test_group)\n",
"\n",
"    # Total count including MDM=False rows of the held-out ships\n",
"    test_group_ships = test_group['ships_idx'].unique()\n",
"    test_group_all = mdm_all[mdm_all['ships_idx'].isin(test_group_ships)]\n",
"    _, total_count_all = calculate_group_count(test_group_all)\n",
"\n",
"    gkf_results.append({\n",
"        'Group': i + 1,\n",
"        'Allocation': 'GKF',\n",
"        'Comb_count': comb_count,\n",
"        'Total_count': total_count,\n",
"        'Total_count_all': total_count_all,\n",
"        'Ship_count': test_group['ships_idx'].nunique(),\n",
"        'Ships_idx': list(test_group_ships)\n",
"    })\n",
"\n",
"# Order by comb_count (descending) and renumber the groups in that order\n",
"gkf_results.sort(key=lambda x: x['Comb_count'], reverse=True)\n",
"for i, result in enumerate(gkf_results):\n",
"    result['Group'] = i + 1\n",
"\n",
"# Stack the BGKF and GKF summaries into one table and write it to CSV\n",
"combined_results = [*custom_results, *gkf_results]\n",
"combined_df = pd.DataFrame(combined_results)\n",
"combined_df.to_csv('combined_group_allocation.csv', index=False)\n",
"\n",
"print(\"CSV file has been generated: 'combined_group_allocation.csv'\")\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Group 1 datasets saved in dataset/1\n",
"Group 2 datasets saved in dataset/2\n",
"Group 3 datasets saved in dataset/3\n",
"Group 4 datasets saved in dataset/4\n",
"Group 5 datasets saved in dataset/5\n"
]
}
],
"source": [
"import os\n",
"import pandas as pd\n",
"from sklearn.model_selection import KFold\n",
"\n",
"def save_datasets_for_group(groups, mdm, data, output_dir='dataset', n_splits=4, random_state=42):\n",
"    \"\"\"Write one hold-out split per group into ``output_dir/<k>`` (k = 1..len(groups)).\n",
"\n",
"    For folder k the ships in ``groups[k - 1]`` are held out and four CSVs are written\n",
"    (``index=False``, ``encoding='utf-8-sig'``):\n",
"\n",
"    * ``test.csv``      -- every row of `data` whose ships_idx is held out.\n",
"    * ``train.csv``     -- rows of `mdm` for all other groups' ships, minus the valid rows.\n",
"    * ``valid.csv``     -- the test part of the first split of a shuffled\n",
"      ``KFold(n_splits)`` over those `mdm` rows (about 1/n_splits of them).\n",
"    * ``train_all.csv`` -- train rows followed by valid rows.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    groups : mapping with integer keys 0..len(groups)-1, each a list of ships_idx\n",
"    mdm : DataFrame the train/valid rows are taken from\n",
"    data : DataFrame the test rows are taken from\n",
"    output_dir : root folder; per-group subfolders are created if missing\n",
"    n_splits : number of KFold splits used to carve out the validation set\n",
"    random_state : seed of the KFold shuffle (42 reproduces the previous splits)\n",
"    \"\"\"\n",
"    # With an integer seed, KFold.split yields the same split on every call,\n",
"    # so a single splitter can be reused for all groups.\n",
"    kf_inner = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)\n",
"    csv_options = {'index': False, 'encoding': 'utf-8-sig'}\n",
"\n",
"    for i in range(len(groups)):\n",
"        group_folder = os.path.join(output_dir, str(i + 1))\n",
"        os.makedirs(group_folder, exist_ok=True)\n",
"\n",
"        # Held-out ships: their test rows are taken from `data`, not from `mdm`\n",
"        test_group_ships = groups[i]\n",
"        test_all_data = data[data['ships_idx'].isin(test_group_ships)]\n",
"\n",
"        # Training ships: every group except group i\n",
"        train_group_ships = [ship for g in range(len(groups)) if g != i for ship in groups[g]]\n",
"        train_data = mdm[mdm['ships_idx'].isin(train_group_ships)]\n",
"\n",
"        # Only the first KFold split is used: its test part becomes the validation set\n",
"        train_idx_inner, valid_idx_inner = next(kf_inner.split(train_data))\n",
"        final_train_data = train_data.iloc[train_idx_inner]\n",
"        valid_data = train_data.iloc[valid_idx_inner]\n",
"        train_all_data = pd.concat([final_train_data, valid_data])\n",
"\n",
"        final_train_data.to_csv(os.path.join(group_folder, 'train.csv'), **csv_options)\n",
"        valid_data.to_csv(os.path.join(group_folder, 'valid.csv'), **csv_options)\n",
"        # NOTE(review): test.csv is written from `data` (all rows of the held-out ships,\n",
"        # not only the `mdm` rows used for train/valid); confirm this is intended.\n",
"        test_all_data.to_csv(os.path.join(group_folder, 'test.csv'), **csv_options)\n",
"        train_all_data.to_csv(os.path.join(group_folder, 'train_all.csv'), **csv_options)\n",
"\n",
"        print(f\"Group {i + 1} datasets saved in {group_folder}\")\n",
"\n",
"# Write the splits in reported order: sorted_groups[k - 1] is the 'Group k' printed by the\n",
"# first cell and listed in combined_group_allocation.csv, so dataset/<k> matches that\n",
"# numbering (the unsorted `groups` would number the folders by allocation order instead).\n",
"save_datasets_for_group(sorted_groups, mdm_true, data, n_splits=4)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}