{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# t5 training for combined concatenated outputs (thing + property) \n",
"\n",
"refer to `t5_train_tp.py` and `guide_for_tp.md` for faster training workflow"
]
},
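{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training cell below expects `combined_data/{mode}/{fold_group}` to contain a `DatasetDict` whose rows hold a `translation` dict with a raw `input` string and a `thing_property` target wrapped in the special tokens registered below. A minimal sketch of one record (the field values here are hypothetical, not taken from the actual dataset):\n",
"\n",
"```python\n",
"# hypothetical record; real values come from the project's mapping data\n",
"example = {\n",
"    \"translation\": {\n",
"        \"input\": \"ME C.F.W. PUMP INLET PRESS.\",\n",
"        \"thing_property\": \"<THING_START>CoolingFreshWaterPump<THING_END>\"\n",
"                          \"<PROPERTY_START>InletPressure<PROPERTY_END>\",\n",
"    }\n",
"}\n",
"```"
]
},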
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='3140' max='3920' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [3140/3920 05:42 < 01:25, 9.17 it/s, Epoch 64.06/80]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" <th>Bleu</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>100</td>\n",
" <td>9.068100</td>\n",
" <td>1.485702</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>200</td>\n",
" <td>0.886400</td>\n",
" <td>0.219002</td>\n",
" <td>20.999970</td>\n",
" </tr>\n",
" <tr>\n",
" <td>300</td>\n",
" <td>0.302500</td>\n",
" <td>0.100100</td>\n",
" <td>50.318311</td>\n",
" </tr>\n",
" <tr>\n",
" <td>400</td>\n",
" <td>0.168400</td>\n",
" <td>0.053922</td>\n",
" <td>52.052581</td>\n",
" </tr>\n",
" <tr>\n",
" <td>500</td>\n",
" <td>0.113800</td>\n",
" <td>0.046394</td>\n",
" <td>53.469249</td>\n",
" </tr>\n",
" <tr>\n",
" <td>600</td>\n",
" <td>0.084500</td>\n",
" <td>0.040225</td>\n",
" <td>53.980484</td>\n",
" </tr>\n",
" <tr>\n",
" <td>700</td>\n",
" <td>0.066900</td>\n",
" <td>0.026786</td>\n",
" <td>58.959618</td>\n",
" </tr>\n",
" <tr>\n",
" <td>800</td>\n",
" <td>0.053300</td>\n",
" <td>0.025612</td>\n",
" <td>52.672595</td>\n",
" </tr>\n",
" <tr>\n",
" <td>900</td>\n",
" <td>0.042600</td>\n",
" <td>0.019917</td>\n",
" <td>58.475230</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1000</td>\n",
" <td>0.038200</td>\n",
" <td>0.021234</td>\n",
" <td>52.335545</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1100</td>\n",
" <td>0.032500</td>\n",
" <td>0.021687</td>\n",
" <td>52.400191</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1200</td>\n",
" <td>0.030100</td>\n",
" <td>0.022106</td>\n",
" <td>59.836717</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1300</td>\n",
" <td>0.026800</td>\n",
" <td>0.020341</td>\n",
" <td>55.878989</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1400</td>\n",
" <td>0.023200</td>\n",
" <td>0.019192</td>\n",
" <td>53.356706</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1500</td>\n",
" <td>0.022500</td>\n",
" <td>0.018187</td>\n",
" <td>59.718873</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1600</td>\n",
" <td>0.020900</td>\n",
" <td>0.017806</td>\n",
" <td>62.848480</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1700</td>\n",
" <td>0.017200</td>\n",
" <td>0.018625</td>\n",
" <td>62.796542</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1800</td>\n",
" <td>0.015500</td>\n",
" <td>0.020747</td>\n",
" <td>62.920445</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1900</td>\n",
" <td>0.013800</td>\n",
" <td>0.027109</td>\n",
" <td>68.566983</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2000</td>\n",
" <td>0.013900</td>\n",
" <td>0.024757</td>\n",
" <td>65.792365</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2100</td>\n",
" <td>0.011600</td>\n",
" <td>0.021626</td>\n",
" <td>68.714757</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2200</td>\n",
" <td>0.011800</td>\n",
" <td>0.025541</td>\n",
" <td>73.793641</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2300</td>\n",
" <td>0.011000</td>\n",
" <td>0.017915</td>\n",
" <td>71.351766</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2400</td>\n",
" <td>0.010500</td>\n",
" <td>0.020459</td>\n",
" <td>76.285575</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2500</td>\n",
" <td>0.009700</td>\n",
" <td>0.019714</td>\n",
" <td>78.722420</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2600</td>\n",
" <td>0.008700</td>\n",
" <td>0.026323</td>\n",
" <td>73.858894</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2700</td>\n",
" <td>0.008600</td>\n",
" <td>0.023967</td>\n",
" <td>78.752238</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2800</td>\n",
" <td>0.008500</td>\n",
" <td>0.025074</td>\n",
" <td>78.772012</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2900</td>\n",
" <td>0.008400</td>\n",
" <td>0.022061</td>\n",
" <td>83.261974</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3000</td>\n",
" <td>0.008800</td>\n",
" <td>0.022081</td>\n",
" <td>80.992463</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3100</td>\n",
" <td>0.007100</td>\n",
" <td>0.024494</td>\n",
" <td>81.058833</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
" warnings.warn(\n",
"/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 113\u001b[0m\n\u001b[1;32m 97\u001b[0m early_stopping_callback \u001b[38;5;241m=\u001b[39m EarlyStoppingCallback(\n\u001b[1;32m 98\u001b[0m early_stopping_patience\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m,\n\u001b[1;32m 99\u001b[0m \n\u001b[1;32m 100\u001b[0m )\n\u001b[1;32m 102\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Seq2SeqTrainer(\n\u001b[1;32m 103\u001b[0m model,\n\u001b[1;32m 104\u001b[0m args,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 110\u001b[0m callbacks\u001b[38;5;241m=\u001b[39m[early_stopping_callback] \n\u001b[1;32m 111\u001b[0m )\n\u001b[0;32m--> 113\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 114\u001b[0m os\u001b[38;5;241m.\u001b[39m_exit(\u001b[38;5;241m0\u001b[39m)\n",
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/trainer.py:1859\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1857\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1858\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1859\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1860\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1861\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1862\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1863\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1864\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/accelerate/utils/memory.py:142\u001b[0m, in \u001b[0;36mfind_executable_batch_size.<locals>.decorator\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo executable batch size found, reached zero.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 142\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 144\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m should_reduce_batch_size(e):\n",
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/trainer.py:2203\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 2202\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 2203\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2205\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 2206\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 2207\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m 2208\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 2209\u001b[0m ):\n\u001b[1;32m 2210\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 2211\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/transformers/trainer.py:3147\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 3145\u001b[0m scaled_loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m 3146\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 3147\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3149\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mdetach() \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mgradient_accumulation_steps\n",
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/accelerate/accelerator.py:2013\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[0;34m(self, loss, **kwargs)\u001b[0m\n\u001b[1;32m 2011\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler\u001b[38;5;241m.\u001b[39mscale(loss)\u001b[38;5;241m.\u001b[39mbackward(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 2012\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2013\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/torch/_tensor.py:525\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 515\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 517\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 518\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 523\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 524\u001b[0m )\n\u001b[0;32m--> 525\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 526\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 527\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/torch/autograd/__init__.py:267\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 262\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 264\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 266\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 267\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 268\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.10/site-packages/torch/autograd/graph.py:744\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 742\u001b[0m unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[1;32m 743\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 744\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 745\u001b[0m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 746\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m 747\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 748\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"import json\n",
"import os\n",
"\n",
"import evaluate\n",
"import numpy as np\n",
"from datasets import load_from_disk\n",
"from transformers import (\n",
"    AutoModelForSeq2SeqLM,\n",
"    AutoTokenizer,\n",
"    DataCollatorForSeq2Seq,\n",
"    EarlyStoppingCallback,\n",
"    Seq2SeqTrainer,\n",
"    Seq2SeqTrainingArguments,\n",
")\n",
"\n",
"model_name = \"google/t5-efficient-tiny\"\n",
"# alternatives: google/t5-efficient-mini, t5-small, t5-base\n",
"\n",
"train_epochs = 80\n",
"\n",
"# record the chosen model and epoch count in the shared run config\n",
"with open(\"mode.json\", \"r\") as json_file:\n",
"    mode_dict = json.load(json_file)\n",
"\n",
"mode_dict.update({\"model\": model_name, \"train_epochs\": train_epochs})\n",
"fold_group = mode_dict.get(\"fold_group\")\n",
"\n",
"with open(\"mode.json\", \"w\") as json_file:\n",
"    json.dump(mode_dict, json_file)\n",
"\n",
"mode = mode_dict.get(\"mode\", \"default_value\")\n",
"file_path = f'combined_data/{mode}/{fold_group}'\n",
"split_datasets = load_from_disk(file_path)\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"# markers that delimit each field of the concatenated (thing + property) target\n",
"additional_special_tokens = [\n",
"    \"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\",\n",
"    \"<TN_START>\", \"<TN_END>\", \"<TD_START>\", \"<TD_END>\",\n",
"    \"<MIN_START>\", \"<MIN_END>\", \"<MAX_START>\", \"<MAX_END>\",\n",
"    \"<UNIT_START>\", \"<UNIT_END>\",\n",
"]\n",
"tokenizer.add_special_tokens({\"additional_special_tokens\": additional_special_tokens})\n",
"\n",
"max_length = 64\n",
"\n",
"def preprocess_function(examples):\n",
"    inputs = [ex[\"input\"] for ex in examples[\"translation\"]]\n",
"    targets = [ex[\"thing_property\"] for ex in examples[\"translation\"]]\n",
"    return tokenizer(inputs, text_target=targets, max_length=max_length, truncation=True)\n",
"\n",
"tokenized_datasets = split_datasets.map(\n",
"    preprocess_function,\n",
"    batched=True,\n",
"    remove_columns=split_datasets[\"train\"].column_names,\n",
")\n",
"\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
"# grow the embedding matrix to cover the newly added special tokens\n",
"model.resize_token_embeddings(len(tokenizer))\n",
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
"metric = evaluate.load(\"sacrebleu\")\n",
"\n",
"def compute_metrics(eval_preds):\n",
"    preds, labels = eval_preds\n",
"    if isinstance(preds, tuple):\n",
"        preds = preds[0]\n",
"\n",
"    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
"    # restore the pad token where -100 was used to mask the loss\n",
"    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
"    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
"    decoded_preds = [pred.strip() for pred in decoded_preds]\n",
"    # sacrebleu expects a list of references per prediction\n",
"    decoded_labels = [[label.strip()] for label in decoded_labels]\n",
"\n",
"    result = metric.compute(predictions=decoded_preds, references=decoded_labels)\n",
"    return {\"bleu\": result[\"score\"]}\n",
"\n",
"# work around NCCL hangs seen on some multi-GPU hosts\n",
"os.environ['NCCL_P2P_DISABLE'] = '1'\n",
"os.environ['NCCL_IB_DISABLE'] = '1'\n",
"\n",
"args = Seq2SeqTrainingArguments(\n",
"    f\"train_{fold_group}_{model_name}_{mode}_{train_epochs}\",\n",
"    save_strategy=\"steps\",\n",
"    learning_rate=1e-3,\n",
"    per_device_train_batch_size=32,\n",
"    per_device_eval_batch_size=64,\n",
"    auto_find_batch_size=True,\n",
"    ddp_find_unused_parameters=False,\n",
"    weight_decay=0.01,\n",
"    save_total_limit=1,\n",
"    num_train_epochs=train_epochs,\n",
"    predict_with_generate=True,\n",
"    bf16=True,\n",
"    push_to_hub=False,\n",
"    evaluation_strategy=\"steps\",\n",
"    eval_steps=100,\n",
"    save_steps=100,\n",
"    logging_steps=100,\n",
"    load_best_model_at_end=True,\n",
"    metric_for_best_model=\"bleu\",\n",
"    lr_scheduler_type=\"linear\",\n",
"    warmup_steps=100,\n",
")\n",
"\n",
"# stop early if BLEU fails to improve for 5 consecutive evaluations\n",
"early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=5)\n",
"\n",
"trainer = Seq2SeqTrainer(\n",
"    model,\n",
"    args,\n",
"    train_dataset=tokenized_datasets[\"train\"],\n",
"    eval_dataset=tokenized_datasets[\"validation\"],\n",
"    data_collator=data_collator,\n",
"    tokenizer=tokenizer,\n",
"    compute_metrics=compute_metrics,\n",
"    callbacks=[early_stopping_callback],\n",
")\n",
"\n",
"trainer.train()\n",
"# hard-exit so the kernel releases the GPUs as soon as training finishes\n",
"os._exit(0)\n"
]
}
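,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The repeated `max_length` warnings above come from evaluation with `predict_with_generate`: no generation length is configured, so `generate` falls back to the model-agnostic default of 20 tokens. Passing `generation_max_length=max_length` in `Seq2SeqTrainingArguments` should silence the warning and let BLEU score full-length outputs. Once training ends, the saved checkpoint can be reloaded for a quick smoke test, roughly as sketched below (the checkpoint path and sample input are placeholders, not values from this run):\n",
"\n",
"```python\n",
"from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
"\n",
"# placeholder path: substitute the checkpoint directory the trainer above wrote\n",
"ckpt = \"path/to/best-checkpoint\"\n",
"tokenizer = AutoTokenizer.from_pretrained(ckpt)\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(ckpt)\n",
"\n",
"inputs = tokenizer(\"ME C.F.W. PUMP INLET PRESS.\", return_tensors=\"pt\")\n",
"# set max_new_tokens explicitly to avoid the 20-token default seen in the logs\n",
"outputs = model.generate(**inputs, max_new_tokens=64)\n",
"print(tokenizer.decode(outputs[0], skip_special_tokens=False))\n",
"```"
]
}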
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}