199 lines
7.1 KiB
Plaintext
199 lines
7.1 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Loaded data for group 1:\n",
|
|
"Train data shape: (6125, 16)\n",
|
|
"Valid data shape: (2042, 16)\n",
|
|
"Test data shape: (14719, 15)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"import os\n",
|
# Example usage: choose which dataset group to load.
group_number = 1 # You can change this to any group number you want to load (1, 2, 3, 4, or 5)

# Select the mode for processing. The mode decides which columns process_df
# packs into each input string; the modes handled there are 'only_td', 'tn_td',
# 'tn_td_min_max', 'td_min_max', 'td_unit' and 'tn_td_unit'.
mode = 'tn_td_unit' # Change this to 'only_td', 'tn_td', etc., as needed
|
|
"\n",
|
def load_group_data(group_number, base_dir='../../data_preprocess/dataset'):
    """Load the train/valid/test CSV splits for one dataset group.

    Args:
        group_number: Identifier of the group to load. It is converted with
            ``str`` and used as the sub-folder name under ``base_dir``.
        base_dir: Folder that holds one sub-folder per group. Defaults to the
            relative path this notebook has always read from, so existing
            calls that pass only ``group_number`` behave as before.

    Returns:
        tuple: ``(train_data, valid_data, test_data)`` DataFrames read from
        ``train.csv``, ``valid.csv`` and ``test.csv`` in the group folder.

    Raises:
        FileNotFoundError: If any of the three CSV files is missing. The
            message lists every missing path so the problem can be fixed
            without re-checking the folder by hand.
    """
    split_names = ('train', 'valid', 'test')

    # Define the folder path based on the group number
    group_folder = os.path.join(base_dir, str(group_number))

    # Define file paths for train, valid, and test datasets
    split_paths = {
        split: os.path.join(group_folder, f'{split}.csv')
        for split in split_names
    }

    # Check all files up front so nothing is loaded from an incomplete group
    missing = [path for path in split_paths.values() if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError(
            f"One or more files for group {group_number} do not exist. "
            f"Missing: {', '.join(missing)}"
        )

    # Load the CSV files into DataFrames, in (train, valid, test) order
    return tuple(pd.read_csv(split_paths[split]) for split in split_names)
|
|
"\n",
|
|
"\n",
|
# Load the three splits for the configured group and report their shapes.
# A missing file is reported by printing the error message instead of raising.
try:
    train_data, valid_data, test_data = load_group_data(group_number)
    print(f"Loaded data for group {group_number}:")
    print("\n".join(
        f"{label} data shape: {frame.shape}"
        for label, frame in (
            ("Train", train_data),
            ("Valid", valid_data),
            ("Test", test_data),
        )
    ))
except FileNotFoundError as missing_error:
    print(missing_error)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "313f98ef12eb442bac319282e5ffe5d6",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Saving the dataset (0/1 shards): 0%| | 0/6125 [00:00<?, ? examples/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "0c1834a4e7264a969085ad609320fdd6",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Saving the dataset (0/1 shards): 0%| | 0/14719 [00:00<?, ? examples/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "464f88daab334658aac93305ea6dac71",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Saving the dataset (0/1 shards): 0%| | 0/2042 [00:00<?, ? examples/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Dataset saved to 'combined_data'\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import json\n",
|
|
"from datasets import Dataset, DatasetDict\n",
|
|
"\n",
|
# Function to process DataFrame based on mode
def process_df(df, mode='only_td'):
    """Convert each DataFrame row into a ``{'translation': {...}}`` record.

    Args:
        df: DataFrame to convert. Every row must provide ``ships_idx``,
            ``thing`` and ``property``, plus the columns the chosen mode reads
            (``tag_name``, ``tag_description``, ``unit``, ``min``, ``max``).
        mode: Which columns are packed into the ``input`` string. One of
            ``'only_td'``, ``'tn_td'``, ``'tn_td_min_max'``, ``'td_min_max'``,
            ``'td_unit'`` or ``'tn_td_unit'``.

    Returns:
        list: One dict per successfully processed row, shaped as
        ``{'translation': {'ships_idx', 'input', 'thing_property', 'answer'}}``.
        Rows that raise while being processed are printed and left out.

    Raises:
        ValueError: If ``mode`` is not one of the supported values. The check
            runs once before any row is touched; inside the per-row handler
            the error would be swallowed and an empty list returned silently.
    """
    # One builder per mode; each receives a row and returns the tagged input string.
    input_builders = {
        'only_td': lambda row: (
            f"<TD_START>{str(row['tag_description'])}<TD_END>"
        ),
        'tn_td': lambda row: (
            f"<TN_START>{str(row['tag_name'])}<TN_END>"
            f"<TD_START>{str(row['tag_description'])}<TD_END>"
        ),
        'tn_td_min_max': lambda row: (
            f"<TN_START>{str(row['tag_name'])}<TN_END>"
            f"<TD_START>{str(row['tag_description'])}<TD_END>"
            f"<MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>"
        ),
        'td_min_max': lambda row: (
            f"<TD_START>{str(row['tag_description'])}<TD_END>"
            f"<MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>"
        ),
        'td_unit': lambda row: (
            f"<TD_START>{str(row['tag_description'])}<TD_END>"
            f"<UNIT_START>{str(row['unit'])}<UNIT_END>"
        ),
        'tn_td_unit': lambda row: (
            f"<TN_START>{str(row['tag_name'])}<TN_END>"
            f"<TD_START>{str(row['tag_description'])}<TD_END>"
            f"<UNIT_START>{str(row['unit'])}<UNIT_END>"
        ),
    }

    # Fail fast on a bad mode: it is a caller error, not a bad row.
    if mode not in input_builders:
        raise ValueError("Invalid mode specified")
    build_input = input_builders[mode]

    output_list = []
    for idx, row in df.iterrows():
        try:
            input_str = build_input(row)
            output_list.append({
                'translation': {
                    'ships_idx': row['ships_idx'],
                    'input': input_str,
                    'thing_property': f"<THING_START>{str(row['thing'])}<THING_END><PROPERTY_START>{str(row['property'])}<PROPERTY_END>",
                    'answer': f"{str(row['thing'])} {str(row['property'])}",
                }
            })
        except Exception as e:
            # Best effort per row: report the offending row and keep going.
            print(f"Error processing row at index {idx}: {row}")
            print(f"Exception: {e}")
    return output_list
|
|
"\n",
|
|
"\n",
|
# Record which mode and fold group this export was produced with
combined_dict = dict(mode=mode, fold_group=group_number)

# Write those settings next to the notebook as mode.json
with open("mode.json", "w") as mode_file:
    mode_file.write(json.dumps(combined_dict))

try:
    # Convert every split with the selected mode and bundle them in a DatasetDict
    combined_data = DatasetDict({
        split_name: Dataset.from_list(process_df(split_frame, mode=mode))
        for split_name, split_frame in (
            ('train', train_data),
            ('test', test_data),
            ('validation', valid_data),
        )
    })
    # Save under a folder keyed by mode and group number
    combined_data.save_to_disk(f"combined_data/{mode}/{group_number}")
    print("Dataset saved to 'combined_data'")
except Exception as export_error:
    print(f"Error creating DatasetDict: {export_error}")
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.14"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|