Feat: JAX implementation of T5 training and prediction

Richard Wong 2024-09-11 08:17:02 +09:00
commit f523560141
6 changed files with 3011 additions and 0 deletions

.gitignore (new file, 3 lines added)

@@ -0,0 +1,3 @@
*.ipynb
t5_*/
exports/

requirements.yaml (new file, 357 lines added)

@@ -0,0 +1,357 @@
name: jax
channels:
- conda-forge
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- _sysroot_linux-64_curr_repodata_hack=3=h69a702a_16
- absl-py=2.1.0=pyhd8ed1ab_0
- aiohappyeyeballs=2.4.0=pyhd8ed1ab_0
- aiohttp=3.10.5=py311h61187de_0
- aiosignal=1.3.1=pyhd8ed1ab_0
- alsa-lib=1.2.12=h4ab18f5_0
- aom=3.9.1=hac33072_0
- arrow=1.3.0=pyhd8ed1ab_0
- asttokens=2.4.1=pyhd8ed1ab_0
- attrs=24.2.0=pyh71513ae_0
- aws-c-auth=0.7.29=h03582ad_1
- aws-c-cal=0.7.4=hfd43aa1_1
- aws-c-common=0.9.28=hb9d3cd8_0
- aws-c-compression=0.2.19=h756ea98_1
- aws-c-event-stream=0.4.3=h235a6dd_1
- aws-c-http=0.8.8=h5e77a74_2
- aws-c-io=0.14.18=hc2627b9_9
- aws-c-mqtt=0.10.4=h01636a3_19
- aws-c-s3=0.6.5=h191b246_2
- aws-c-sdkutils=0.1.19=h756ea98_3
- aws-checksums=0.1.18=h756ea98_11
- aws-crt-cpp=0.28.2=h29c84ef_4
- aws-sdk-cpp=1.11.379=h5a9005d_9
- azure-core-cpp=1.13.0=h935415a_0
- azure-identity-cpp=1.8.0=hd126650_2
- azure-storage-blobs-cpp=12.12.0=hd2e3451_0
- azure-storage-common-cpp=12.7.0=h10ac4d7_1
- azure-storage-files-datalake-cpp=12.11.0=h325d260_1
- binaryornot=0.4.4=py_1
- binutils_impl_linux-64=2.40=ha1999f0_7
- binutils_linux-64=2.40=hb3c18ed_2
- blosc=1.21.6=hef167b5_0
- brotli=1.1.0=hb9d3cd8_2
- brotli-bin=1.1.0=hb9d3cd8_2
- brotli-python=1.1.0=py311hfdbb021_2
- bzip2=1.0.8=h4bc722e_7
- c-ares=1.33.1=heb4867d_0
- ca-certificates=2024.8.30=hbcca054_0
- cairo=1.18.0=hebfffa5_3
- certifi=2024.8.30=pyhd8ed1ab_0
- cffi=1.17.1=py311hf29c0ef_0
- chardet=5.2.0=py311h38be061_2
- charset-normalizer=3.3.2=pyhd8ed1ab_0
- chex=0.1.86=pyhd8ed1ab_0
- click=8.1.7=unix_pyh707e725_0
- colorama=0.4.6=pyhd8ed1ab_0
- comm=0.2.2=pyhd8ed1ab_0
- contourpy=1.3.0=py311hd18a35c_1
- cookiecutter=2.6.0=pyhca7485f_0
- cuda-cccl_linux-64=12.6.37=ha770c72_0
- cuda-crt-dev_linux-64=12.6.68=ha770c72_0
- cuda-crt-tools=12.6.68=ha770c72_0
- cuda-cudart=12.6.68=h5888daf_0
- cuda-cudart-dev=12.6.68=h5888daf_0
- cuda-cudart-dev_linux-64=12.6.68=h3f2d84a_0
- cuda-cudart-static=12.6.68=h5888daf_0
- cuda-cudart-static_linux-64=12.6.68=h3f2d84a_0
- cuda-cudart_linux-64=12.6.68=h3f2d84a_0
- cuda-cupti=12.6.68=h5888daf_0
- cuda-driver-dev_linux-64=12.6.68=h3f2d84a_0
- cuda-nvcc=12.6.68=hcdd1206_0
- cuda-nvcc-dev_linux-64=12.6.68=he91c749_0
- cuda-nvcc-impl=12.6.68=h85509e4_0
- cuda-nvcc-tools=12.6.68=he02047a_0
- cuda-nvcc_linux-64=12.6.68=h8a487aa_0
- cuda-nvrtc=12.6.68=h5888daf_0
- cuda-nvtx=12.6.68=h5888daf_0
- cuda-nvvm-dev_linux-64=12.6.68=ha770c72_0
- cuda-nvvm-impl=12.6.68=he02047a_0
- cuda-nvvm-tools=12.6.68=he02047a_0
- cuda-version=12.6=h7480c83_3
- cudnn=9.2.1.18=hbc370b7_0
- cycler=0.12.1=pyhd8ed1ab_0
- datasets=2.21.0=pyhd8ed1ab_0
- dav1d=1.2.1=hd590300_0
- dbus=1.13.6=h5008d03_3
- debugpy=1.8.5=py311hfdbb021_1
- decorator=5.1.1=pyhd8ed1ab_0
- dill=0.3.8=pyhd8ed1ab_0
- double-conversion=3.3.0=h59595ed_0
- etils=1.9.4=pyhd8ed1ab_0
- evaluate=0.4.1=pyhd8ed1ab_0
- exceptiongroup=1.2.2=pyhd8ed1ab_0
- executing=2.1.0=pyhd8ed1ab_0
- expat=2.6.3=h5888daf_0
- filelock=3.16.0=pyhd8ed1ab_0
- flax=0.9.0=pyhd8ed1ab_0
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
- font-ttf-inconsolata=3.000=h77eed37_0
- font-ttf-source-code-pro=2.038=h77eed37_0
- font-ttf-ubuntu=0.83=h77eed37_2
- fontconfig=2.14.2=h14ed4e7_0
- fonts-conda-ecosystem=1=0
- fonts-conda-forge=1=0
- fonttools=4.53.1=py311h9ecbd09_1
- freetype=2.12.1=h267a509_2
- frozenlist=1.4.1=py311h9ecbd09_1
- fsspec=2024.5.0=pyhff2d567_0
- gcc_impl_linux-64=13.3.0=hfea6d02_1
- gcc_linux-64=13.3.0=hc28eda2_2
- gflags=2.2.2=he1b5a44_1004
- glog=0.7.1=hbabe93e_0
- graphite2=1.3.13=h59595ed_1003
- gxx_impl_linux-64=13.3.0=hdbfa832_1
- gxx_linux-64=13.3.0=h6834431_2
- h2=4.1.0=pyhd8ed1ab_0
- harfbuzz=9.0.0=hda332d3_1
- hpack=4.0.0=pyh9f0ad1d_0
- huggingface_hub=0.24.6=pyhd8ed1ab_0
- hyperframe=6.0.1=pyhd8ed1ab_0
- icu=75.1=he02047a_0
- idna=3.8=pyhd8ed1ab_0
- importlib-metadata=8.4.0=pyha770c72_0
- importlib_metadata=8.4.0=hd8ed1ab_0
- importlib_resources=6.4.5=pyhd8ed1ab_0
- ipykernel=6.29.5=pyh3099207_0
- ipython=8.27.0=pyh707e725_0
- jax=0.4.31=pyhd8ed1ab_1
- jaxlib=0.4.31=cuda120py311hd88f13b_201
- jedi=0.19.1=pyhd8ed1ab_0
- jinja2=3.1.4=pyhd8ed1ab_0
- joblib=1.4.2=pyhd8ed1ab_0
- jsonschema=4.23.0=pyhd8ed1ab_0
- jsonschema-specifications=2023.12.1=pyhd8ed1ab_0
- jupyter_client=8.6.2=pyhd8ed1ab_0
- jupyter_core=5.7.2=py311h38be061_0
- jupytext=1.16.4=pyh80e38bb_0
- kernel-headers_linux-64=3.10.0=h4a8ded7_16
- keyutils=1.6.1=h166bdaf_0
- kiwisolver=1.4.7=py311hd18a35c_0
- krb5=1.21.3=h659f571_0
- lcms2=2.16=hb7c19ff_0
- ld_impl_linux-64=2.40=hf3520f5_7
- lerc=4.0.0=h27087fc_0
- libabseil=20240116.2=cxx17_he02047a_1
- libarrow=17.0.0=h8d2e343_13_cpu
- libarrow-acero=17.0.0=h5888daf_13_cpu
- libarrow-dataset=17.0.0=h5888daf_13_cpu
- libarrow-substrait=17.0.0=hf54134d_13_cpu
- libavif16=1.1.1=h104a339_1
- libblas=3.9.0=23_linux64_openblas
- libbrotlicommon=1.1.0=hb9d3cd8_2
- libbrotlidec=1.1.0=hb9d3cd8_2
- libbrotlienc=1.1.0=hb9d3cd8_2
- libcblas=3.9.0=23_linux64_openblas
- libclang-cpp18.1=18.1.8=default_hf981a13_4
- libclang13=18.1.8=default_h9def88c_4
- libcrc32c=1.1.2=h9c3ff4c_0
- libcublas=12.6.1.4=h5888daf_0
- libcufft=11.2.6.59=h5888daf_0
- libcups=2.3.3=h4637d8d_4
- libcurand=10.3.7.68=h5888daf_0
- libcurl=8.9.1=hdb1bdb2_0
- libcusolver=11.6.4.69=h5888daf_0
- libcusparse=12.5.3.3=h5888daf_0
- libdeflate=1.21=h4bc722e_0
- libdrm=2.4.123=hb9d3cd8_0
- libedit=3.1.20191231=he28a2e2_2
- libegl=1.7.0=ha4b6fd6_0
- libev=4.33=hd590300_2
- libevent=2.1.12=hf998b51_1
- libexpat=2.6.3=h5888daf_0
- libffi=3.4.2=h7f98852_5
- libgcc=14.1.0=h77fa898_1
- libgcc-devel_linux-64=13.3.0=h84ea5a7_101
- libgcc-ng=14.1.0=h69a702a_1
- libgfortran=14.1.0=h69a702a_1
- libgfortran-ng=14.1.0=h69a702a_1
- libgfortran5=14.1.0=hc5f4f2c_1
- libgl=1.7.0=ha4b6fd6_0
- libglib=2.80.3=h315aac3_2
- libglvnd=1.7.0=ha4b6fd6_0
- libglx=1.7.0=ha4b6fd6_0
- libgomp=14.1.0=h77fa898_1
- libgoogle-cloud=2.28.0=h26d7fe4_0
- libgoogle-cloud-storage=2.28.0=ha262f82_0
- libgrpc=1.62.2=h15f2491_0
- libiconv=1.17=hd590300_2
- libjpeg-turbo=3.0.0=hd590300_1
- liblapack=3.9.0=23_linux64_openblas
- libllvm18=18.1.8=h8b73ec9_2
- libnghttp2=1.58.0=h47da74e_1
- libnsl=2.0.1=hd590300_0
- libnvjitlink=12.6.68=h5888daf_0
- libopenblas=0.3.27=pthreads_hac2b453_1
- libparquet=17.0.0=h39682fd_13_cpu
- libpciaccess=0.18=hd590300_0
- libpng=1.6.43=h2797004_0
- libpq=16.4=h2d7952a_1
- libprotobuf=4.25.3=h08a7969_0
- libre2-11=2023.09.01=h5a48ba9_2
- libsanitizer=13.3.0=heb74ff8_1
- libsodium=1.0.20=h4ab18f5_0
- libsqlite=3.46.1=hadc24fc_0
- libssh2=1.11.0=h0841786_0
- libstdcxx=14.1.0=hc0a3c3a_1
- libstdcxx-devel_linux-64=13.3.0=h84ea5a7_101
- libstdcxx-ng=14.1.0=h4852527_1
- libthrift=0.20.0=h0e7cc3e_1
- libtiff=4.6.0=h46a8edc_4
- libutf8proc=2.8.0=h166bdaf_0
- libuuid=2.38.1=h0b41bf4_0
- libwebp-base=1.4.0=hd590300_0
- libxcb=1.16=hb9d3cd8_1
- libxcrypt=4.4.36=hd590300_1
- libxkbcommon=1.7.0=h2c5496b_1
- libxml2=2.12.7=he7c6b58_4
- libxslt=1.1.39=h76b75d6_0
- libzlib=1.3.1=h4ab18f5_1
- lxml=5.3.0=py311hcfaa980_1
- lz4-c=1.9.4=hcb278e6_0
- markdown-it-py=3.0.0=pyhd8ed1ab_0
- markupsafe=2.1.5=py311h9ecbd09_1
- matplotlib=3.9.2=py311h38be061_0
- matplotlib-base=3.9.2=py311h74b4f7c_0
- matplotlib-inline=0.1.7=pyhd8ed1ab_0
- mdit-py-plugins=0.4.1=pyhd8ed1ab_0
- mdurl=0.1.2=pyhd8ed1ab_0
- ml_dtypes=0.4.0=py311h7db5c69_2
- msgpack-python=1.0.8=py311hd18a35c_1
- multidict=6.1.0=py311h9ecbd09_0
- multiprocess=0.70.16=py311h9ecbd09_1
- munkres=1.1.4=pyh9f0ad1d_0
- mysql-common=9.0.1=h70512c7_0
- mysql-libs=9.0.1=ha479ceb_0
- nbformat=5.10.4=pyhd8ed1ab_0
- nccl=2.22.3.1=hbc370b7_1
- ncurses=6.5=he02047a_1
- nest-asyncio=1.6.0=pyhd8ed1ab_0
- nltk=3.9.1=pyhd8ed1ab_0
- numpy=1.26.4=py311h64a7726_0
- openjpeg=2.5.2=h488ebb8_0
- openssl=3.3.2=hb9d3cd8_0
- opt-einsum=3.3.0=hd8ed1ab_2
- opt_einsum=3.3.0=pyhc1e730c_2
- optax=0.2.2=pyhd8ed1ab_1
- orbax-checkpoint=0.4.4=pyhd8ed1ab_0
- orc=2.0.2=h669347b_0
- packaging=24.1=pyhd8ed1ab_0
- pandas=2.2.2=py311h14de704_1
- parso=0.8.4=pyhd8ed1ab_0
- pcre2=10.44=hba22ea6_2
- pexpect=4.9.0=pyhd8ed1ab_0
- pickleshare=0.7.5=py_1003
- pillow=10.4.0=py311h82a398c_0
- pip=24.2=pyh8b19718_1
- pixman=0.43.2=h59595ed_0
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_1
- platformdirs=4.3.2=pyhd8ed1ab_0
- portalocker=2.10.1=py311h38be061_0
- prompt-toolkit=3.0.47=pyha770c72_0
- protobuf=4.25.3=py311hbffca5d_1
- psutil=6.0.0=py311h9ecbd09_1
- pthread-stubs=0.4=h36c2ea0_1001
- ptyprocess=0.7.0=pyhd3deb0d_0
- pure_eval=0.2.3=pyhd8ed1ab_0
- pyarrow=17.0.0=py311hbd00459_1
- pyarrow-core=17.0.0=py311h4510849_1_cpu
- pyarrow-hotfix=0.6=pyhd8ed1ab_0
- pybind11-abi=4=hd8ed1ab_3
- pycparser=2.22=pyhd8ed1ab_0
- pygments=2.18.0=pyhd8ed1ab_0
- pyparsing=3.1.4=pyhd8ed1ab_0
- pyside6=6.7.2=py311hba19f1e_2
- pysocks=1.7.1=pyha2e5f31_6
- python=3.11.9=hb806964_0_cpython
- python-dateutil=2.9.0=pyhd8ed1ab_0
- python-fastjsonschema=2.20.0=pyhd8ed1ab_0
- python-slugify=8.0.4=pyhd8ed1ab_0
- python-tzdata=2024.1=pyhd8ed1ab_0
- python-xxhash=3.5.0=py311h9ecbd09_1
- python_abi=3.11=5_cp311
- pytz=2024.1=pyhd8ed1ab_0
- pyyaml=6.0.2=py311h9ecbd09_1
- pyzmq=26.2.0=py311h7deb3e3_2
- qhull=2020.2=h434a139_5
- qt6-main=6.7.2=hb12f9c5_5
- rav1e=0.6.6=he8a937b_2
- re2=2023.09.01=h7f4b329_2
- readline=8.2=h8228510_1
- referencing=0.35.1=pyhd8ed1ab_0
- regex=2024.7.24=py311h9ecbd09_1
- requests=2.32.3=pyhd8ed1ab_0
- responses=0.18.0=pyhd8ed1ab_0
- rich=13.7.1=pyhd8ed1ab_0
- rpds-py=0.20.0=py311h9e33e62_1
- s2n=1.5.2=h7b32b05_0
- sacrebleu=2.4.1=pyhd8ed1ab_1
- safetensors=0.4.5=py311h9e33e62_0
- scipy=1.14.1=py311he1f765f_0
- setuptools=73.0.1=pyhd8ed1ab_0
- six=1.16.0=pyh6c4a22f_0
- snappy=1.2.1=ha2e4443_0
- stack_data=0.6.2=pyhd8ed1ab_0
- svt-av1=2.2.1=h5888daf_0
- sysroot_linux-64=2.17=h4a8ded7_16
- tabulate=0.9.0=pyhd8ed1ab_1
- tensorstore=0.1.62=py311he109767_1
- text-unidecode=1.3=pyhd8ed1ab_1
- tk=8.6.13=noxft_h4845f30_101
- tokenizers=0.19.1=py311h6640629_0
- tomli=2.0.1=pyhd8ed1ab_0
- toolz=0.12.1=pyhd8ed1ab_0
- tornado=6.4.1=py311h9ecbd09_1
- tqdm=4.66.5=pyhd8ed1ab_0
- traitlets=5.14.3=pyhd8ed1ab_0
- transformers=4.41.2=pyhd8ed1ab_0
- types-python-dateutil=2.9.0.20240906=pyhd8ed1ab_0
- typing=3.10.0.0=pyhd8ed1ab_1
- typing-extensions=4.12.2=hd8ed1ab_0
- typing_extensions=4.12.2=pyha770c72_0
- tzdata=2024a=h8827d51_1
- urllib3=2.2.2=pyhd8ed1ab_1
- wayland=1.23.1=h3e06ad9_0
- wcwidth=0.2.13=pyhd8ed1ab_0
- wheel=0.44.0=pyhd8ed1ab_0
- xcb-util=0.4.1=hb711507_2
- xcb-util-cursor=0.1.4=h4ab18f5_2
- xcb-util-image=0.4.0=hb711507_2
- xcb-util-keysyms=0.4.1=hb711507_0
- xcb-util-renderutil=0.3.10=hb711507_0
- xcb-util-wm=0.4.2=hb711507_0
- xkeyboard-config=2.42=h4ab18f5_0
- xorg-fixesproto=5.0=h7f98852_1002
- xorg-inputproto=2.3.2=h7f98852_1002
- xorg-kbproto=1.0.7=h7f98852_1002
- xorg-libice=1.1.1=hd590300_0
- xorg-libsm=1.2.4=h7391055_0
- xorg-libx11=1.8.9=hb711507_1
- xorg-libxau=1.0.11=hd590300_0
- xorg-libxdmcp=1.1.3=h7f98852_0
- xorg-libxext=1.3.4=h0b41bf4_2
- xorg-libxfixes=5.0.3=h7f98852_1004
- xorg-libxi=1.7.10=h4bc722e_1
- xorg-libxrender=0.9.11=hd590300_0
- xorg-libxtst=1.2.5=h4bc722e_0
- xorg-libxxf86vm=1.1.5=h4bc722e_1
- xorg-recordproto=1.14.2=h7f98852_1002
- xorg-renderproto=0.11.1=h7f98852_1002
- xorg-xextproto=7.3.0=h0b41bf4_1003
- xorg-xproto=7.0.31=h7f98852_1007
- xxhash=0.8.2=hd590300_0
- xz=5.2.6=h166bdaf_0
- yaml=0.2.5=h7f98852_2
- yarl=1.11.1=py311h9ecbd09_0
- zeromq=4.3.5=ha4adb4c_5
- zipp=3.20.1=pyhd8ed1ab_0
- zlib=1.3.1=h4ab18f5_1
- zstandard=0.23.0=py311hbc35293_1
- zstd=1.5.6=ha6fb4c9_0

t5_jax.py (new file, 620 lines added)

@@ -0,0 +1,620 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.4
# kernelspec:
# display_name: jax
# language: python
# name: python3
# ---
# %% [markdown]
# # T5 implementation using jax
# %% [markdown]
# ## import
# %% [raw]
# import json
# import logging
# import math
# import os
# import sys
# import time
# from dataclasses import asdict, dataclass, field
# from enum import Enum
# from functools import partial
# from pathlib import Path
# from typing import Callable, Optional
#
# import datasets
# import evaluate
# import jax
# import jax.numpy as jnp
# import nltk # Here to have a nice missing dependency error message early on
# import numpy as np
# import optax
# from datasets import Dataset, load_dataset
# from filelock import FileLock
# from flax import jax_utils, traverse_util
# from flax.jax_utils import pad_shard_unpad, unreplicate
# from flax.training import train_state
# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
# from tqdm import tqdm
#
# import transformers
# from transformers import (
# CONFIG_MAPPING,
# FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
# AutoConfig,
# AutoTokenizer,
# FlaxAutoModelForSeq2SeqLM,
# HfArgumentParser,
# is_tensorboard_available,
# )
# from transformers.utils import is_offline_mode, send_example_telemetry
#
#
# logger = logging.getLogger(__name__)
#
# try:
# nltk.data.find("tokenizers/punkt")
# except (LookupError, OSError):
# if is_offline_mode():
# raise LookupError(
# "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
# )
# with FileLock(".lock") as lock:
# nltk.download("punkt", quiet=True)
#
#
# MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
# MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
# %%
import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "high")
jax.config.update("jax_enable_x64", False)
from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
from datasets import Dataset, load_dataset
import evaluate
from tqdm import tqdm
from datasets import load_from_disk
import nltk # Here to have a nice missing dependency error message early on
from filelock import FileLock # used by the punkt download guard below
from transformers.utils import is_offline_mode # used by the punkt download guard below
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import time
# %%
import os
os.environ['XLA_FLAGS'] = (
'--xla_gpu_enable_triton_softmax_fusion=True '
'--xla_gpu_triton_gemm_any=True '
)
os.environ.update({
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
})
# %%
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
# %%
# nltk.download('punkt')
try:
nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
if is_offline_mode():
raise LookupError(
"Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
)
with FileLock(".lock") as lock:
nltk.download("punkt", quiet=True)
# %% [markdown]
# ## Prepare datasets
# %%
# load model
model_name_or_path = "t5-small" # Replace with your specific model name
# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)
# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
model_name_or_path
)
# %%
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
# %%
# Path to saved combined_dataset
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = 't5_80_1'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
# prepare tokenizer
from transformers import T5TokenizerFast
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
max_length = 86
# In Flax, for seq2seq models we need to pass `decoder_input_ids`
# as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
# for that dynamically import the `shift_tokens_right` function from the model file
# given a dataset entry, run it through the tokenizer
# Setting padding="max_length" as we need fixed length inputs for jitted functions
def preprocess_function(example):
input = example['input']
target = example['output']
# passing text_target makes the tokenizer also return the target token ids
# under a "labels" key alongside the encoded inputs; the target is tokenized
# again below so that we also get its attention mask
model_inputs = tokenizer(
input,
text_target=target,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
# tokenize the target on its own so that `input_ids` and `attention_mask`
# here refer to the target sequence (used for labels and decoder_attention_mask)
labels = tokenizer(
text_target=target,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
model_inputs["labels"] = labels["input_ids"]
decoder_input_ids = shift_tokens_right_fn(
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
)
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
# We need decoder_attention_mask so we can ignore pad tokens from loss
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
return model_inputs
# map maps function to each "row" in the dataset
# aka the data in the immediate nesting
tokenized_datasets = split_datasets.map(
preprocess_function,
batched=True,
num_proc=1,
remove_columns=split_datasets["train"].column_names,
)
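# %%
# Illustrative sketch (not part of the training flow): what shift_tokens_right
# does to a toy label row. The decoder inputs are the labels shifted one step
# to the right with config.decoder_start_token_id prepended, so the decoder
# predicts token t from the tokens before t. The toy values below are made up.
_toy_labels = np.array([[100, 200, 300, config.pad_token_id]])
print(shift_tokens_right_fn(_toy_labels, config.pad_token_id, config.decoder_start_token_id))
# expected: [[decoder_start_token_id, 100, 200, 300]]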
# %%
tokenized_datasets
# %%
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"]
# %%
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
"""
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
"""
if shuffle:
batch_idx = jax.random.permutation(rng, len(dataset))
batch_idx = np.asarray(batch_idx)
else:
batch_idx = np.arange(len(dataset))
if drop_last:
steps_per_epoch = len(dataset) // batch_size
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
else:
steps_per_epoch = math.ceil(len(dataset) / batch_size)
batch_idx = np.array_split(batch_idx, steps_per_epoch)
for idx in batch_idx:
batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()}
yield batch
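# %%
# Quick check (illustrative only, assumes the train split has at least 8 rows):
# pull one shuffled batch and inspect its keys and shapes. Every value is a
# fixed-length array, so the jitted/pmapped steps below never retrace.
_chk_rng = jax.random.PRNGKey(0)
_chk_batch = next(data_loader(_chk_rng, train_dataset, batch_size=8, shuffle=True))
print({k: v.shape for k, v in _chk_batch.items()})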
# %% [markdown]
# Now we have model inputs in terms of the variable tokenized_datasets
# %%
# metric
metric = evaluate.load("sacrebleu")
def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels]
# rougeLSum expects newline after each sentence
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
return preds, labels
# def compute_metrics(preds, labels):
# decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
# decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
#
# # Some simple post-processing
# decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
#
# result = metric.compute(predictions=decoded_preds, references=decoded_labels)
# result = {k: round(v * 100, 4) for k, v in result.items()}
# prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
# result["gen_len"] = np.mean(prediction_lens)
# return result
def compute_metrics(preds, labels):
# In case the model returns more than the prediction logits
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
# Replace -100s in the labels as we can't decode them
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
# Some simple post-processing
decoded_preds = [pred.strip() for pred in decoded_preds]
decoded_labels = [[label.strip()] for label in decoded_labels]
result = metric.compute(predictions=decoded_preds, references=decoded_labels)
return {"bleu": result["score"]}
# %% [markdown]
# # Model
# %%
# Store some constant
seed = 117
num_epochs = 80
batch_size = 96
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 5e-5
weight_decay = 0.0
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = None
predict_with_generate = True
# %%
# Initialize our training
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)
# %%
# optimization functions
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
len(train_dataset),
train_batch_size,
num_train_epochs,
warmup_steps,
learning_rate,
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
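# Sanity check (illustrative only): flatten the mask produced for the loaded
# model parameters and count how many leaves would actually receive weight
# decay; biases (if any) and layer-norm scales should come out False.
_decay_mask_leaves = traverse_util.flatten_dict(decay_mask_fn(model.params))
print(sum(_decay_mask_leaves.values()), "of", len(_decay_mask_leaves), "parameter leaves would be decayed")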
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=adam_beta1,
b2=adam_beta2,
eps=adam_epsilon,
weight_decay=weight_decay,
mask=decay_mask_fn,
)
# %%
# Training functions
class TrainState(train_state.TrainState):
dropout_rng: jnp.ndarray
def replicate(self):
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
# Ensure model.params is properly initialized (this is just an example)
# Normally you would get this from a model initialization call with dummy input
params = model.params
# Cast parameters to bfloat16 if desired
params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng)
# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss
loss = loss * padding_mask
loss = loss.sum()
num_labels = padding_mask.sum()
return loss, num_labels
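# Tiny numeric check (illustrative only): with all-zero logits and no label
# smoothing, the per-token loss is log(vocab_size) and the padded position is
# excluded by the padding mask, so loss / num_labels comes out to about 1.609.
_chk_logits = jnp.zeros((1, 3, 5)) # batch=1, seq=3, vocab=5
_chk_labels = jnp.array([[1, 2, 0]])
_chk_mask = jnp.array([[1.0, 1.0, 0.0]]) # last position is padding
_chk_loss, _chk_n = loss_fn(_chk_logits, _chk_labels, _chk_mask, label_smoothing_factor=0.0)
print(_chk_loss / _chk_n, "vs", jnp.log(5.0))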
# Define gradient update step fn
def train_step(state, batch, label_smoothing_factor=0.0):
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params):
labels = batch.pop("labels")
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
return loss, num_labels
# compute gradients through computational graph
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
(loss, num_labels), grad = grad_fn(state.params)
num_labels = jax.lax.psum(num_labels, "batch")
# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
# true grad = total grad / total samples
grad = jax.lax.psum(grad, "batch")
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
return new_state, metrics
# Define eval fn
def eval_step(params, batch, label_smoothing_factor=0.0):
labels = batch.pop("labels")
logits = model(**batch, params=params, train=False)[0]
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
num_labels = jax.lax.psum(num_labels, "batch")
# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
metrics = {"loss": loss}
return metrics
# Define generation function
max_length = (
val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
def generate_step(params, batch):
model.params = params
output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
return output_ids.sequences
# Create parallel version of the train and eval step
p_train_step = jax.pmap(
partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
)
p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
p_generate_step = jax.pmap(generate_step, "batch")
# Replicate the train state on each device
state = state.replicate()
# %%
print("***** Running training *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")
# %%
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
# epochs = range(num_epochs)
for epoch in epochs:
# ======================== Training ================================
train_start = time.time()
# Create sampling rng
rng, input_rng = jax.random.split(rng)
train_metrics = []
# Generate an epoch by shuffling sampling indices from the train dataset
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
steps_per_epoch = len(train_dataset) // train_batch_size
# train
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
batch = shard(batch)
state, train_metric = p_train_step(state, batch)
train_metrics.append(train_metric)
train_time += time.time() - train_start
train_metric = unreplicate(train_metric)
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
f" {train_metric['learning_rate']})"
)
# ======================== Evaluating ==============================
eval_metrics = []
eval_preds = []
eval_labels = []
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
# Model forward
batch = next(eval_loader)
labels = batch["labels"]
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
state.params, batch, min_device_batch=per_device_eval_batch_size
)
eval_metrics.append(metrics)
# generation
if predict_with_generate:
generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
eval_labels.extend(labels)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# compute metrics
rouge_desc = ""
if predict_with_generate:
rouge_metrics = compute_metrics(eval_preds, eval_labels)
eval_metrics.update(rouge_metrics)
rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
# Print metrics and update progress bar
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
epochs.write(desc)
epochs.desc = desc
# Save metrics
# if has_tensorboard and jax.process_index() == 0:
# cur_step = epoch * (len(train_dataset) // train_batch_size)
# write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
output_dir = save_path
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(output_dir, params=params)
tokenizer.save_pretrained(output_dir)
# %% [markdown]
# #

t5_jax_prediction.py (new file, 386 lines added)

@@ -0,0 +1,386 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.4
# ---
# %% [markdown]
# # prediction code
# ## import and process test data
# %%
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
from datasets import Dataset, DatasetDict
import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "high")
jax.config.update("jax_enable_x64", False)
from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
from datasets import Dataset, load_dataset
import evaluate
from tqdm import tqdm
from datasets import load_from_disk
import nltk # Here to have a nice missing dependency error message early on
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import time
# %%
# data_path = f"../make_data/select_db/data_mapping_filtered.csv"
# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv"
data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'
# Ensure to include 'ships_idx' in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# Load the dataset
df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)
def process_df(df):
output_list = [{
'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>",
# 'input': f"<DESC>{row['tag_description']}<DESC>",
# 'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
# 'input': f"<DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
'answer': f"{row['thing']} {row['property']}",
'answer_thing': row['thing'],
'answer_property': row['property'],
} for _, row in df.iterrows()]
return output_list
# takes 1 minute to run without batching
test_dataset = Dataset.from_list(process_df(df))
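# %%
# Quick sanity check (illustrative only, made-up values): run process_df on a
# single hypothetical row to see the input/output templates it builds.
_example_df = pd.DataFrame([{
'ships_idx': 0,
'tag_name': 'GE1_RPM',
'tag_description': 'GENERATOR 1 SPEED',
'thing': 'generator1',
'property': 'speed',
'unit': 'rpm',
}])
print(process_df(_example_df)[0])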
# %% [markdown]
# ## Load model for attributes
# %%
# load model
model_name_or_path = "t5_80_1" # Replace with your specific model name
# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)
# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
model_name_or_path
)
# %% [markdown]
# ## Tokenizer
# %%
# prepare tokenizer
from transformers import T5TokenizerFast
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
max_length = 86
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
# given a dataset entry, run it through the tokenizer
# Setting padding="max_length" as we need fixed length inputs for jitted functions
def preprocess_function(example):
input = example['input']
target = example['output']
# passing text_target makes the tokenizer also return the target token ids
# under a "labels" key alongside the encoded inputs; the target is tokenized
# again below so that we also get its attention mask
model_inputs = tokenizer(
input,
text_target=target,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
# tokenize the target on its own so that `input_ids` and `attention_mask`
# here refer to the target sequence (used for labels and decoder_attention_mask)
labels = tokenizer(
text_target=target,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
model_inputs["labels"] = labels["input_ids"]
decoder_input_ids = shift_tokens_right_fn(
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
)
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
# We need decoder_attention_mask so we can ignore pad tokens from loss
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
return model_inputs
# map maps function to each "row" in the dataset
# aka the data in the immediate nesting
test_dataset = test_dataset.map(
preprocess_function,
batched=True,
num_proc=1,
remove_columns=test_dataset.column_names,
)
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
"""
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
"""
if shuffle:
batch_idx = jax.random.permutation(rng, len(dataset))
batch_idx = np.asarray(batch_idx)
else:
batch_idx = np.arange(len(dataset))
if drop_last:
steps_per_epoch = len(dataset) // batch_size
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
else:
steps_per_epoch = math.ceil(len(dataset) / batch_size)
batch_idx = np.array_split(batch_idx, steps_per_epoch)
for idx in batch_idx:
batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()}
yield batch
# %% [markdown]
# # Model setup for prediction
# %%
seed = 117
num_epochs = 80
batch_size = 96
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(test_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 5e-5
weight_decay = 0.0
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = None
predict_with_generate = True
# Initialize our training
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
len(test_dataset),
train_batch_size,
num_train_epochs,
warmup_steps,
learning_rate,
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=adam_beta1,
b2=adam_beta2,
eps=adam_epsilon,
weight_decay=weight_decay,
mask=decay_mask_fn,
)
# %%
# reload model to prevent leakage of variables
# load model
model_name_or_path = "t5_80_1" # Replace with your specific model name
# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)
# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
model_name_or_path
)
# Training functions
class TrainState(train_state.TrainState):
dropout_rng: jnp.ndarray
def replicate(self):
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
# Ensure model.params is properly initialized (this is just an example)
# Normally you would get this from a model initialization call with dummy input
params = model.params
# Cast parameters to bfloat16 if desired
params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng)
# Define generation function
max_length = (
val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
def generate_step(params, batch):
model.params = params
output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
return output_ids.sequences
# Create parallel version of the train and eval step
p_generate_step = jax.pmap(generate_step, "batch")
# Replicate the train state on each device
state = state.replicate()
pred_metrics = []
pred_generations = []
pred_labels = []
rng, input_rng = jax.random.split(rng)
pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False)
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)
print("***** Running training *****")
print(f" Num examples = {len(test_dataset)}")
print(f" Num steps = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")
for _ in tqdm(range(pred_steps), desc="Predicting...", position=0, leave=False):
# Model forward
batch = next(pred_loader)
labels = batch["labels"]
# generation
if predict_with_generate:
generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
pred_labels.extend(labels)
# Print metrics
# desc = f"Predict Loss: {pred_metrics['loss']})"
# print(desc)
# %%
# save predictions to parquet
# decode generated ids and reference label ids back to text
def decode_preds(preds):
# In case the model returns more than the prediction logits
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
decoded_preds = [pred for pred in decoded_preds]
return decoded_preds
# Put the decoded predictions alongside their reference labels in a DataFrame
df = pd.DataFrame({
"prediction": decode_preds(pred_generations),
"label": decode_preds(pred_labels),
})
# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
df.to_parquet("exports/output_file.parquet", engine="pyarrow") # or use engine="fastparquet"
# %%

t5_jax_retrieval.py (new file, 624 lines added)

@@ -0,0 +1,624 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.4
# kernelspec:
# display_name: jax
# language: python
# name: python3
# ---
# %% [markdown]
# # T5 implementation using jax
# %% [markdown]
# ## import
# %% [raw]
# import json
# import logging
# import math
# import os
# import sys
# import time
# from dataclasses import asdict, dataclass, field
# from enum import Enum
# from functools import partial
# from pathlib import Path
# from typing import Callable, Optional
#
# import datasets
# import evaluate
# import jax
# import jax.numpy as jnp
# import nltk # Here to have a nice missing dependency error message early on
# import numpy as np
# import optax
# from datasets import Dataset, load_dataset
# from filelock import FileLock
# from flax import jax_utils, traverse_util
# from flax.jax_utils import pad_shard_unpad, unreplicate
# from flax.training import train_state
# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
# from tqdm import tqdm
#
# import transformers
# from transformers import (
# CONFIG_MAPPING,
# FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
# AutoConfig,
# AutoTokenizer,
# FlaxAutoModelForSeq2SeqLM,
# HfArgumentParser,
# is_tensorboard_available,
# )
# from transformers.utils import is_offline_mode, send_example_telemetry
#
#
# logger = logging.getLogger(__name__)
#
# try:
# nltk.data.find("tokenizers/punkt")
# except (LookupError, OSError):
# if is_offline_mode():
# raise LookupError(
# "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
# )
# with FileLock(".lock") as lock:
# nltk.download("punkt", quiet=True)
#
#
# MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
# MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
# %%
import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "high")
jax.config.update("jax_enable_x64", False)
from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
from datasets import Dataset, load_dataset
import evaluate
import nltk # Here to have a nice missing dependency error message early on
from filelock import FileLock # used by the punkt download guard below
from transformers.utils import is_offline_mode # used by the punkt download guard below
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
import time
# %%
import os
os.environ['XLA_FLAGS'] = (
'--xla_gpu_enable_triton_softmax_fusion=True '
'--xla_gpu_triton_gemm_any=True '
)
os.environ.update({
"CUDA_VISIBLE_DEVICES": "0, 1, 2, 3",
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
})
# %%
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
# %%
# nltk.download('punkt')
try:
nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
if is_offline_mode():
raise LookupError(
"Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
)
with FileLock(".lock") as lock:
nltk.download("punkt", quiet=True)
# %% [markdown]
# ## Prepare datasets
# %%
# load model
model_name_or_path = "t5-small" # Replace with your specific model name
# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)
# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
model_name_or_path
)
# %%
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
# %%
from tqdm import tqdm
from datasets import load_from_disk
# Path to saved combined_dataset
file_path = '/home/richard/Projects/learn_t5/retrieval/combined_data_t5'
save_path = 't5_80_1_retrieval'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
# prepare tokenizer
from transformers import T5TokenizerFast
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Define additional special tokens
# additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
# Define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>",
"<CONTEXT>", "<EXAMPLE>", "<INPUT>", "<OUTPUT>"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
max_length = 300
# In Flax, for seq2seq models we need to pass `decoder_input_ids`
# as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
# for that dynamically import the `shift_tokens_right` function from the model file
# given a dataset entry, run it through the tokenizer
# Setting padding="max_length" as we need fixed length inputs for jitted functions
def preprocess_function(example):
input = example['input']
target = example['output']
# passing text_target makes the tokenizer also return the target token ids
# under a "labels" key alongside the encoded inputs; the target is tokenized
# again below so that we also get its attention mask
model_inputs = tokenizer(
input,
text_target=target,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
# tokenize the target on its own so that `input_ids` and `attention_mask`
# here refer to the target sequence (used for labels and decoder_attention_mask)
labels = tokenizer(
text_target=target,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="np"
)
model_inputs["labels"] = labels["input_ids"]
decoder_input_ids = shift_tokens_right_fn(
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
)
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
# We need decoder_attention_mask so we can ignore pad tokens from loss
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
return model_inputs
# map maps function to each "row" in the dataset
# aka the data in the immediate nesting
tokenized_datasets = split_datasets.map(
preprocess_function,
batched=True,
num_proc=1,
remove_columns=split_datasets["train"].column_names,
)
# %%
tokenized_datasets
# %%
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"]
# %%
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
"""
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
"""
if shuffle:
batch_idx = jax.random.permutation(rng, len(dataset))
batch_idx = np.asarray(batch_idx)
else:
batch_idx = np.arange(len(dataset))
if drop_last:
steps_per_epoch = len(dataset) // batch_size
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
else:
steps_per_epoch = math.ceil(len(dataset) / batch_size)
batch_idx = np.array_split(batch_idx, steps_per_epoch)
for idx in batch_idx:
batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()}
yield batch
# %% [markdown]
# Now we have model inputs in terms of the variable tokenized_datasets
# %%
# metric
metric = evaluate.load("sacrebleu")
def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels]
# rougeLSum expects newline after each sentence
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
return preds, labels
# def compute_metrics(preds, labels):
# decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
# decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
#
# # Some simple post-processing
# decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
#
# result = metric.compute(predictions=decoded_preds, references=decoded_labels)
# result = {k: round(v * 100, 4) for k, v in result.items()}
# prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
# result["gen_len"] = np.mean(prediction_lens)
# return result
def compute_metrics(preds, labels):
# In case the model returns more than the prediction logits
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
# Replace -100s in the labels as we can't decode them
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
# Some simple post-processing
decoded_preds = [pred.strip() for pred in decoded_preds]
decoded_labels = [[label.strip()] for label in decoded_labels]
result = metric.compute(predictions=decoded_preds, references=decoded_labels)
return {"bleu": result["score"]}
# %% [markdown]
# # Model
# %%
# Store some constant
seed = 117
num_epochs = 80
batch_size = 36
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
warmup_steps = 0
learning_rate = 5e-5
weight_decay = 0.0
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0
num_beams = 1
val_max_target_length = None
predict_with_generate = True
# %%
# Initialize our training
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)
# %%
# optimization functions
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
len(train_dataset),
train_batch_size,
num_train_epochs,
warmup_steps,
learning_rate,
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=adam_beta1,
b2=adam_beta2,
eps=adam_epsilon,
weight_decay=weight_decay,
mask=decay_mask_fn,
)
# %%
# Training functions
class TrainState(train_state.TrainState):
dropout_rng: jnp.ndarray
def replicate(self):
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
# Ensure model.params is properly initialized (this is just an example)
# Normally you would get this from a model initialization call with dummy input
params = model.params
# Cast parameters to bfloat16 if desired
params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng)
# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss
loss = loss * padding_mask
loss = loss.sum()
num_labels = padding_mask.sum()
return loss, num_labels
# Define gradient update step fn
def train_step(state, batch, label_smoothing_factor=0.0):
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params):
labels = batch.pop("labels")
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
return loss, num_labels
# compute gradients through computational graph
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
(loss, num_labels), grad = grad_fn(state.params)
num_labels = jax.lax.psum(num_labels, "batch")
# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
# true grad = total grad / total samples
grad = jax.lax.psum(grad, "batch")
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
return new_state, metrics
# Define eval fn
def eval_step(params, batch, label_smoothing_factor=0.0):
labels = batch.pop("labels")
logits = model(**batch, params=params, train=False)[0]
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
num_labels = jax.lax.psum(num_labels, "batch")
# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
metrics = {"loss": loss}
return metrics
# Define generation function
max_length = (
val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
def generate_step(params, batch):
model.params = params
output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
return output_ids.sequences
# Create parallel version of the train and eval step
p_train_step = jax.pmap(
partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
)
p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
p_generate_step = jax.pmap(generate_step, "batch")
# Replicate the train state on each device
state = state.replicate()
# %%
print("***** Running training *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")
# %%
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
# epochs = range(num_epochs)
for epoch in epochs:
# ======================== Training ================================
train_start = time.time()
# Create sampling rng
rng, input_rng = jax.random.split(rng)
train_metrics = []
# Generate an epoch by shuffling sampling indices from the train dataset
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
steps_per_epoch = len(train_dataset) // train_batch_size
# train
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
batch = shard(batch)
state, train_metric = p_train_step(state, batch)
train_metrics.append(train_metric)
train_time += time.time() - train_start
train_metric = unreplicate(train_metric)
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
f" {train_metric['learning_rate']})"
)
# ======================== Evaluating ==============================
# eval_metrics = []
# eval_preds = []
# eval_labels = []
# eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
# eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
# for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
# # Model forward
# batch = next(eval_loader)
# labels = batch["labels"]
# metrics = pad_shard_unpad(p_eval_step, static_return=True)(
# state.params, batch, min_device_batch=per_device_eval_batch_size
# )
# eval_metrics.append(metrics)
# # generation
# if predict_with_generate:
# generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
# eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
# eval_labels.extend(labels)
# # normalize eval metrics
# eval_metrics = get_metrics(eval_metrics)
# eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# compute metrics
# rouge_desc = ""
# if predict_with_generate:
# rouge_metrics = compute_metrics(eval_preds, eval_labels)
# eval_metrics.update(rouge_metrics)
# rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
# # Print metrics and update progress bar
# desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
# epochs.write(desc)
# epochs.desc = desc
# Save metrics
# if has_tensorboard and jax.process_index() == 0:
# cur_step = epoch * (len(train_dataset) // train_batch_size)
# write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
output_dir = save_path
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(output_dir, params=params)
tokenizer.save_pretrained(output_dir)
# %% [markdown]
# #

t5_summarizer_flax.py (new file, 1021 lines added)

File diff suppressed because it is too large.