{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# T5 training for combined concatenated outputs (thing + property)\n", "\n", "Refer to `t5_train_tp.py` and `guide_for_tp.md` for a faster training workflow." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "90f850a9e8324109808e45e40f0eea47", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/6260 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "34e221d3425d414a9fb749a3ee28ad81", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/12969 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7c5504c54cba4520aa34d5a6a078a31d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/2087 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/home/hwang/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n" ] }, { "data": { "text/html": [ "\n", "
Step | \n", "Training Loss | \n", "Validation Loss | \n", "Bleu | \n", "
---|---|---|---|
200 | \n", "2.654300 | \n", "0.112380 | \n", "26.397731 | \n", "
400 | \n", "0.106600 | \n", "0.035335 | \n", "87.137364 | \n", "
600 | \n", "0.044600 | \n", "0.022964 | \n", "89.884682 | \n", "
800 | \n", "0.026300 | \n", "0.018220 | \n", "86.274312 | \n", "
1000 | \n", "0.017300 | \n", "0.016252 | \n", "86.389477 | \n", "
1200 | \n", "0.012400 | \n", "0.015651 | \n", "94.416285 | \n", "
1400 | \n", "0.011500 | \n", "0.014833 | \n", "91.596509 | \n", "
1600 | \n", "0.008800 | \n", "0.015168 | \n", "91.629519 | \n", "
1800 | \n", "0.006900 | \n", "0.015042 | \n", "95.375351 | \n", "
"
],
"text/plain": [
"