{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# T5 training for combined concatenated outputs (thing + property)\n",
    "\n",
    "Refer to `t5_train_tp.py` and `guide_for_tp.md` for a faster training workflow."
   ]
  },
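  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training cell below reads its configuration from `mode.json`. As a reference, the sketch in the next cell (not part of the original workflow) creates a minimal config with the two keys that cell actually reads, `mode` and `fold_group`; the `fold_group` value is a placeholder rather than a value taken from the original run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of the mode.json this notebook expects; written only if the file is missing.\n",
    "# The 'mode' value matches the one printed by the training cell; 'fold_group' is a placeholder.\n",
    "import json, os\n",
    "\n",
    "if not os.path.exists(\"mode.json\"):\n",
    "    with open(\"mode.json\", \"w\") as f:\n",
    "        json.dump({\"mode\": \"tn_td_unit\", \"fold_group\": \"fold_1\"}, f)"
   ]
  },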
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The mode has been set to: tn_td_unit\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d8d70681f4594917b7af4583a4237168",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map: 0%| | 0/6125 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "106e0cefe50c40f0a83371693cf48cf7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map: 0%| | 0/14719 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "952f8ec73df0418490cb43beaaf5a7df",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map: 0%| | 0/2042 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Import data and load the dataset\n",
    "from datasets import load_from_disk\n",
    "import json\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "model_name = \"t5-base\"\n",
    "train_epochs = 80\n",
    "\n",
    "# Read the mode from the JSON file\n",
    "with open(\"mode.json\", \"r\") as json_file:\n",
    "    mode_dict = json.load(json_file)\n",
    "\n",
    "# Record the model name and epoch count in the dictionary\n",
    "mode_dict[\"model\"] = model_name\n",
    "mode_dict[\"train_epochs\"] = train_epochs\n",
    "\n",
    "# Access the fold_group value\n",
    "fold_group = mode_dict.get(\"fold_group\")\n",
    "\n",
    "# Save the updated dictionary back to the JSON file\n",
    "with open(\"mode.json\", \"w\") as json_file:\n",
    "    json.dump(mode_dict, json_file)\n",
    "\n",
    "# Set the mode variable from the JSON content\n",
    "mode = mode_dict.get(\"mode\", \"default_value\")  # 'default_value' is a fallback if 'mode' is not found\n",
    "\n",
    "print(f\"The mode has been set to: {mode}\")\n",
    "\n",
    "# Path to the saved combined_dataset\n",
    "file_path = f'combined_data/{mode}/{fold_group}'\n",
    "split_datasets = load_from_disk(file_path)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "# Define additional special tokens\n",
    "# additional_special_tokens = [\"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\"]\n",
    "additional_special_tokens = [\"<THING_START>\", \"<THING_END>\", \"<PROPERTY_START>\", \"<PROPERTY_END>\", \"<TN_START>\", \"<TN_END>\", \"<TD_START>\", \"<TD_END>\", \"<MIN_START>\", \"<MIN_END>\", \"<MAX_START>\", \"<MAX_END>\", \"<UNIT_START>\", \"<UNIT_END>\"]\n",
    "\n",
    "# Add the additional special tokens to the tokenizer\n",
    "tokenizer.add_special_tokens({\"additional_special_tokens\": additional_special_tokens})\n",
    "\n",
    "max_length = 64\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    inputs = [ex[\"input\"] for ex in examples['translation']]\n",
    "    targets = [ex[\"thing_property\"] for ex in examples['translation']]\n",
    "    # Passing text_target lets the tokenizer build the labels directly,\n",
    "    # so there is no need to create a separate 'labels' field\n",
    "    model_inputs = tokenizer(\n",
    "        inputs, text_target=targets, max_length=max_length, truncation=True\n",
    "    )\n",
    "    return model_inputs\n",
    "\n",
    "# map() applies preprocess_function to every split (train/validation/test) of the DatasetDict\n",
    "tokenized_datasets = split_datasets.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    remove_columns=split_datasets[\"train\"].column_names,\n",
    ")"
   ]
  },
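  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional sanity check, not part of the original workflow: decode one tokenized training example to confirm that the inputs and the added special tokens survive the round trip through the tokenizer. It only assumes the `tokenizer` and `tokenized_datasets` objects produced by the cell above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of a quick check on the preprocessing above (assumes the previous cell has run).\n",
    "sample = tokenized_datasets[\"train\"][0]\n",
    "# Keep special tokens visible so markers such as <THING_START> can be inspected in the labels.\n",
    "print(tokenizer.decode(sample[\"input_ids\"], skip_special_tokens=False))\n",
    "print(tokenizer.decode(sample[\"labels\"], skip_special_tokens=False))"
   ]
  },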
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      " warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       " <div>\n",
       " \n",
       " <progress value='3840' max='3840' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       " [3840/3840 42:37, Epoch 80/80]\n",
       " </div>\n",
       " <table border=\"1\" class=\"dataframe\">\n",
       " <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       " <th>Step</th>\n",
       " <th>Training Loss</th>\n",
       " </tr>\n",
       " </thead>\n",
       " <tbody>\n",
       " <tr>\n",
       " <td>500</td>\n",
       " <td>2.812300</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <td>1000</td>\n",
       " <td>0.699300</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <td>1500</td>\n",
       " <td>0.440900</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <td>2000</td>\n",
       " <td>0.332100</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <td>2500</td>\n",
       " <td>0.276500</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <td>3000</td>\n",
       " <td>0.245900</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <td>3500</td>\n",
       " <td>0.229300</td>\n",
       " </tr>\n",
       " </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      " warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=3840, training_loss=0.6754856963952383, metrics={'train_runtime': 2559.4201, 'train_samples_per_second': 191.45, 'train_steps_per_second': 1.5, 'total_flos': 3.156037495934976e+16, 'train_loss': 0.6754856963952383, 'epoch': 80.0})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import os\n",
    "import json\n",
    "\n",
    "# Load the pre-trained t5-base model for fine-tuning\n",
    "from transformers import AutoModelForSeq2SeqLM\n",
    "model_checkpoint = model_name\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)\n",
    "\n",
    "# Data collator for dynamic padding of inputs and labels\n",
    "from transformers import DataCollatorForSeq2Seq\n",
    "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
    "\n",
    "# Evaluation metric (sacreBLEU)\n",
    "import evaluate\n",
    "metric = evaluate.load(\"sacrebleu\")\n",
    "import numpy as np\n",
    "\n",
    "def compute_metrics(eval_preds):\n",
    "    preds, labels = eval_preds\n",
    "    # In case the model returns more than the prediction logits\n",
    "    if isinstance(preds, tuple):\n",
    "        preds = preds[0]\n",
    "\n",
    "    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
    "\n",
    "    # Replace -100s in the labels as we can't decode them\n",
    "    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
    "    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
    "\n",
    "    # Some simple post-processing\n",
    "    decoded_preds = [pred.strip() for pred in decoded_preds]\n",
    "    decoded_labels = [[label.strip()] for label in decoded_labels]\n",
    "\n",
    "    result = metric.compute(predictions=decoded_preds, references=decoded_labels)\n",
    "    return {\"bleu\": result[\"score\"]}\n",
    "\n",
    "from transformers import Seq2SeqTrainingArguments\n",
    "\n",
    "# Disable NCCL P2P and InfiniBand so multi-GPU training works on hosts without P2P support;\n",
    "# not required for single-GPU training\n",
    "os.environ['NCCL_P2P_DISABLE'] = '1'\n",
    "os.environ['NCCL_IB_DISABLE'] = '1'\n",
    "\n",
    "args = Seq2SeqTrainingArguments(\n",
    "    f\"train_{fold_group}_{model_name}_{mode}_{train_epochs}\",\n",
    "    evaluation_strategy=\"no\",\n",
    "    # logging_dir=\"tensorboard-log\",\n",
    "    # logging_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=64,\n",
    "    auto_find_batch_size=True,\n",
    "    ddp_find_unused_parameters=False,\n",
    "    weight_decay=0.01,\n",
    "    save_total_limit=1,\n",
    "    num_train_epochs=train_epochs,\n",
    "    predict_with_generate=True,\n",
    "    bf16=True,\n",
    "    push_to_hub=False,\n",
    ")\n",
    "\n",
    "from transformers import Seq2SeqTrainer\n",
    "\n",
    "trainer = Seq2SeqTrainer(\n",
    "    model,\n",
    "    args,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n",
    "\n",
    "# Train the model\n",
    "trainer.train()"
   ]
  },
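  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional follow-up, not part of the original notebook: a minimal sketch of a qualitative check after training, assuming the `model`, `tokenizer`, `max_length`, and `split_datasets` objects defined in the cells above. It generates the combined thing/property output for one validation example so the special-token structure can be inspected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch only (assumes the training cell above has run): generate a prediction\n",
    "# for one validation example and compare it with the reference target.\n",
    "sample = split_datasets[\"validation\"][0][\"translation\"]\n",
    "inputs = tokenizer(sample[\"input\"], return_tensors=\"pt\").to(model.device)\n",
    "generated = model.generate(**inputs, max_length=max_length)\n",
    "print(\"input: \", sample[\"input\"])\n",
    "print(\"target:\", sample[\"thing_property\"])\n",
    "print(\"output:\", tokenizer.decode(generated[0], skip_special_tokens=False))"
   ]
  }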
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}