149 lines
5.6 KiB
Plaintext
149 lines
5.6 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "30a1ff83e388495ab06f4b8177746d4b",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Saving the dataset (0/1 shards): 0%| | 0/6260 [00:00<?, ? examples/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "3d009461ac044864b674dc59898160b2",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Saving the dataset (0/1 shards): 0%| | 0/12969 [00:00<?, ? examples/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "ce28a831723d4b4698e6ce4a216c56db",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Saving the dataset (0/1 shards): 0%| | 0/2087 [00:00<?, ? examples/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Dataset saved to 'combined_data'\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"import os\n",
|
|
"import json\n",
|
|
"from datasets import Dataset, DatasetDict\n",
|
|
"\n",
|
|
"group_number = 5\n",
|
|
"mode = 'td_unit'\n",
|
|
"\n",
|
|
"def load_group_data(group_number):\n",
|
|
" group_folder = os.path.join('../../data_preprocess/dataset', str(group_number))\n",
|
|
" train_file_path = os.path.join(group_folder, 'train.csv')\n",
|
|
" valid_file_path = os.path.join(group_folder, 'valid.csv')\n",
|
|
" test_file_path = os.path.join(group_folder, 'test.csv')\n",
|
|
" \n",
|
|
" if not all(os.path.exists(f) for f in [train_file_path, valid_file_path, test_file_path]):\n",
|
|
" raise FileNotFoundError(f\"Files for group {group_number} do not exist.\")\n",
|
|
" \n",
|
|
" return pd.read_csv(train_file_path), pd.read_csv(valid_file_path), pd.read_csv(test_file_path)\n",
|
|
"\n",
|
|
"train_data, valid_data, test_data = load_group_data(group_number)\n",
|
|
"\n",
|
|
"def process_df(df, mode='only_td'):\n",
|
|
" output_list = []\n",
|
|
" for idx, row in df.iterrows():\n",
|
|
" try:\n",
|
|
" if mode == 'only_td':\n",
|
|
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END>\"\n",
|
|
" elif mode == 'tn_td':\n",
|
|
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END>\"\n",
|
|
" elif mode == 'tn_td_min_max':\n",
|
|
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\"\n",
|
|
" elif mode == 'td_min_max':\n",
|
|
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END>\"\n",
|
|
" elif mode == 'td_unit':\n",
|
|
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\"\n",
|
|
" elif mode == 'tn_td_unit':\n",
|
|
" input_str = f\"<TN_START>{str(row['tag_name'])}<TN_END><TD_START>{str(row['tag_description'])}<TD_END><UNIT_START>{str(row['unit'])}<UNIT_END>\"\n",
|
|
" elif mode == 'td_min_max_unit':\n",
|
|
" input_str = f\"<TD_START>{str(row['tag_description'])}<TD_END><MIN_START>{row['min']}<MIN_END><MAX_START>{row['max']}<MAX_END><UNIT_START>{str(row['unit'])}<UNIT_END>\"\n",
|
|
" else:\n",
|
|
" raise ValueError(\"Invalid mode specified\")\n",
|
|
" \n",
|
|
" output_list.append({\n",
|
|
" 'translation': {\n",
|
|
" 'ships_idx': row['ships_idx'],\n",
|
|
" 'input': input_str,\n",
|
|
" 'thing_property': f\"<THING_START>{str(row['thing'])}<THING_END><PROPERTY_START>{str(row['property'])}<PROPERTY_END>\",\n",
|
|
" 'answer': f\"{str(row['thing'])} {str(row['property'])}\",\n",
|
|
" }\n",
|
|
" })\n",
|
|
" except Exception as e:\n",
|
|
" print(f\"Error processing row at index {idx}: {e}\")\n",
|
|
" return output_list\n",
|
|
"\n",
|
|
"combined_dict = {\"mode\": mode, \"fold_group\": group_number}\n",
|
|
"with open(\"mode.json\", \"w\") as json_file:\n",
|
|
" json.dump(combined_dict, json_file)\n",
|
|
"\n",
|
|
"combined_data = DatasetDict({\n",
|
|
" 'train': Dataset.from_list(process_df(train_data, mode=mode)),\n",
|
|
" 'test': Dataset.from_list(process_df(test_data, mode=mode)),\n",
|
|
" 'validation': Dataset.from_list(process_df(valid_data, mode=mode)),\n",
|
|
"})\n",
|
|
"combined_data.save_to_disk(f\"combined_data/{mode}/{group_number}\")\n",
|
|
"print(\"Dataset saved to 'combined_data'\")\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.14"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|