From f523560141887ab5155ba5b91077ef4968244d7e Mon Sep 17 00:00:00 2001 From: Richard Wong Date: Wed, 11 Sep 2024 08:17:02 +0900 Subject: [PATCH] Feat: jax implementation of t5 training and prediction --- .gitignore | 3 + requirements.yaml | 357 ++++++++++++++ t5_jax.py | 620 +++++++++++++++++++++++++ t5_jax_prediction.py | 386 ++++++++++++++++ t5_jax_retrieval.py | 624 +++++++++++++++++++++++++ t5_summarizer_flax.py | 1021 +++++++++++++++++++++++++++++++++++++++++ 6 files changed, 3011 insertions(+) create mode 100644 .gitignore create mode 100644 requirements.yaml create mode 100644 t5_jax.py create mode 100644 t5_jax_prediction.py create mode 100644 t5_jax_retrieval.py create mode 100644 t5_summarizer_flax.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bd49e38 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +*.ipynb +t5_*/ +exports/ diff --git a/requirements.yaml b/requirements.yaml new file mode 100644 index 0000000..de725da --- /dev/null +++ b/requirements.yaml @@ -0,0 +1,357 @@ +name: jax +channels: +- conda-forge +dependencies: +- _libgcc_mutex=0.1=conda_forge +- _openmp_mutex=4.5=2_gnu +- _sysroot_linux-64_curr_repodata_hack=3=h69a702a_16 +- absl-py=2.1.0=pyhd8ed1ab_0 +- aiohappyeyeballs=2.4.0=pyhd8ed1ab_0 +- aiohttp=3.10.5=py311h61187de_0 +- aiosignal=1.3.1=pyhd8ed1ab_0 +- alsa-lib=1.2.12=h4ab18f5_0 +- aom=3.9.1=hac33072_0 +- arrow=1.3.0=pyhd8ed1ab_0 +- asttokens=2.4.1=pyhd8ed1ab_0 +- attrs=24.2.0=pyh71513ae_0 +- aws-c-auth=0.7.29=h03582ad_1 +- aws-c-cal=0.7.4=hfd43aa1_1 +- aws-c-common=0.9.28=hb9d3cd8_0 +- aws-c-compression=0.2.19=h756ea98_1 +- aws-c-event-stream=0.4.3=h235a6dd_1 +- aws-c-http=0.8.8=h5e77a74_2 +- aws-c-io=0.14.18=hc2627b9_9 +- aws-c-mqtt=0.10.4=h01636a3_19 +- aws-c-s3=0.6.5=h191b246_2 +- aws-c-sdkutils=0.1.19=h756ea98_3 +- aws-checksums=0.1.18=h756ea98_11 +- aws-crt-cpp=0.28.2=h29c84ef_4 +- aws-sdk-cpp=1.11.379=h5a9005d_9 +- azure-core-cpp=1.13.0=h935415a_0 +- azure-identity-cpp=1.8.0=hd126650_2 +- azure-storage-blobs-cpp=12.12.0=hd2e3451_0 +- azure-storage-common-cpp=12.7.0=h10ac4d7_1 +- azure-storage-files-datalake-cpp=12.11.0=h325d260_1 +- binaryornot=0.4.4=py_1 +- binutils_impl_linux-64=2.40=ha1999f0_7 +- binutils_linux-64=2.40=hb3c18ed_2 +- blosc=1.21.6=hef167b5_0 +- brotli=1.1.0=hb9d3cd8_2 +- brotli-bin=1.1.0=hb9d3cd8_2 +- brotli-python=1.1.0=py311hfdbb021_2 +- bzip2=1.0.8=h4bc722e_7 +- c-ares=1.33.1=heb4867d_0 +- ca-certificates=2024.8.30=hbcca054_0 +- cairo=1.18.0=hebfffa5_3 +- certifi=2024.8.30=pyhd8ed1ab_0 +- cffi=1.17.1=py311hf29c0ef_0 +- chardet=5.2.0=py311h38be061_2 +- charset-normalizer=3.3.2=pyhd8ed1ab_0 +- chex=0.1.86=pyhd8ed1ab_0 +- click=8.1.7=unix_pyh707e725_0 +- colorama=0.4.6=pyhd8ed1ab_0 +- comm=0.2.2=pyhd8ed1ab_0 +- contourpy=1.3.0=py311hd18a35c_1 +- cookiecutter=2.6.0=pyhca7485f_0 +- cuda-cccl_linux-64=12.6.37=ha770c72_0 +- cuda-crt-dev_linux-64=12.6.68=ha770c72_0 +- cuda-crt-tools=12.6.68=ha770c72_0 +- cuda-cudart=12.6.68=h5888daf_0 +- cuda-cudart-dev=12.6.68=h5888daf_0 +- cuda-cudart-dev_linux-64=12.6.68=h3f2d84a_0 +- cuda-cudart-static=12.6.68=h5888daf_0 +- cuda-cudart-static_linux-64=12.6.68=h3f2d84a_0 +- cuda-cudart_linux-64=12.6.68=h3f2d84a_0 +- cuda-cupti=12.6.68=h5888daf_0 +- cuda-driver-dev_linux-64=12.6.68=h3f2d84a_0 +- cuda-nvcc=12.6.68=hcdd1206_0 +- cuda-nvcc-dev_linux-64=12.6.68=he91c749_0 +- cuda-nvcc-impl=12.6.68=h85509e4_0 +- cuda-nvcc-tools=12.6.68=he02047a_0 +- cuda-nvcc_linux-64=12.6.68=h8a487aa_0 +- cuda-nvrtc=12.6.68=h5888daf_0 +- cuda-nvtx=12.6.68=h5888daf_0 +- 
cuda-nvvm-dev_linux-64=12.6.68=ha770c72_0 +- cuda-nvvm-impl=12.6.68=he02047a_0 +- cuda-nvvm-tools=12.6.68=he02047a_0 +- cuda-version=12.6=h7480c83_3 +- cudnn=9.2.1.18=hbc370b7_0 +- cycler=0.12.1=pyhd8ed1ab_0 +- datasets=2.21.0=pyhd8ed1ab_0 +- dav1d=1.2.1=hd590300_0 +- dbus=1.13.6=h5008d03_3 +- debugpy=1.8.5=py311hfdbb021_1 +- decorator=5.1.1=pyhd8ed1ab_0 +- dill=0.3.8=pyhd8ed1ab_0 +- double-conversion=3.3.0=h59595ed_0 +- etils=1.9.4=pyhd8ed1ab_0 +- evaluate=0.4.1=pyhd8ed1ab_0 +- exceptiongroup=1.2.2=pyhd8ed1ab_0 +- executing=2.1.0=pyhd8ed1ab_0 +- expat=2.6.3=h5888daf_0 +- filelock=3.16.0=pyhd8ed1ab_0 +- flax=0.9.0=pyhd8ed1ab_0 +- font-ttf-dejavu-sans-mono=2.37=hab24e00_0 +- font-ttf-inconsolata=3.000=h77eed37_0 +- font-ttf-source-code-pro=2.038=h77eed37_0 +- font-ttf-ubuntu=0.83=h77eed37_2 +- fontconfig=2.14.2=h14ed4e7_0 +- fonts-conda-ecosystem=1=0 +- fonts-conda-forge=1=0 +- fonttools=4.53.1=py311h9ecbd09_1 +- freetype=2.12.1=h267a509_2 +- frozenlist=1.4.1=py311h9ecbd09_1 +- fsspec=2024.5.0=pyhff2d567_0 +- gcc_impl_linux-64=13.3.0=hfea6d02_1 +- gcc_linux-64=13.3.0=hc28eda2_2 +- gflags=2.2.2=he1b5a44_1004 +- glog=0.7.1=hbabe93e_0 +- graphite2=1.3.13=h59595ed_1003 +- gxx_impl_linux-64=13.3.0=hdbfa832_1 +- gxx_linux-64=13.3.0=h6834431_2 +- h2=4.1.0=pyhd8ed1ab_0 +- harfbuzz=9.0.0=hda332d3_1 +- hpack=4.0.0=pyh9f0ad1d_0 +- huggingface_hub=0.24.6=pyhd8ed1ab_0 +- hyperframe=6.0.1=pyhd8ed1ab_0 +- icu=75.1=he02047a_0 +- idna=3.8=pyhd8ed1ab_0 +- importlib-metadata=8.4.0=pyha770c72_0 +- importlib_metadata=8.4.0=hd8ed1ab_0 +- importlib_resources=6.4.5=pyhd8ed1ab_0 +- ipykernel=6.29.5=pyh3099207_0 +- ipython=8.27.0=pyh707e725_0 +- jax=0.4.31=pyhd8ed1ab_1 +- jaxlib=0.4.31=cuda120py311hd88f13b_201 +- jedi=0.19.1=pyhd8ed1ab_0 +- jinja2=3.1.4=pyhd8ed1ab_0 +- joblib=1.4.2=pyhd8ed1ab_0 +- jsonschema=4.23.0=pyhd8ed1ab_0 +- jsonschema-specifications=2023.12.1=pyhd8ed1ab_0 +- jupyter_client=8.6.2=pyhd8ed1ab_0 +- jupyter_core=5.7.2=py311h38be061_0 +- jupytext=1.16.4=pyh80e38bb_0 +- kernel-headers_linux-64=3.10.0=h4a8ded7_16 +- keyutils=1.6.1=h166bdaf_0 +- kiwisolver=1.4.7=py311hd18a35c_0 +- krb5=1.21.3=h659f571_0 +- lcms2=2.16=hb7c19ff_0 +- ld_impl_linux-64=2.40=hf3520f5_7 +- lerc=4.0.0=h27087fc_0 +- libabseil=20240116.2=cxx17_he02047a_1 +- libarrow=17.0.0=h8d2e343_13_cpu +- libarrow-acero=17.0.0=h5888daf_13_cpu +- libarrow-dataset=17.0.0=h5888daf_13_cpu +- libarrow-substrait=17.0.0=hf54134d_13_cpu +- libavif16=1.1.1=h104a339_1 +- libblas=3.9.0=23_linux64_openblas +- libbrotlicommon=1.1.0=hb9d3cd8_2 +- libbrotlidec=1.1.0=hb9d3cd8_2 +- libbrotlienc=1.1.0=hb9d3cd8_2 +- libcblas=3.9.0=23_linux64_openblas +- libclang-cpp18.1=18.1.8=default_hf981a13_4 +- libclang13=18.1.8=default_h9def88c_4 +- libcrc32c=1.1.2=h9c3ff4c_0 +- libcublas=12.6.1.4=h5888daf_0 +- libcufft=11.2.6.59=h5888daf_0 +- libcups=2.3.3=h4637d8d_4 +- libcurand=10.3.7.68=h5888daf_0 +- libcurl=8.9.1=hdb1bdb2_0 +- libcusolver=11.6.4.69=h5888daf_0 +- libcusparse=12.5.3.3=h5888daf_0 +- libdeflate=1.21=h4bc722e_0 +- libdrm=2.4.123=hb9d3cd8_0 +- libedit=3.1.20191231=he28a2e2_2 +- libegl=1.7.0=ha4b6fd6_0 +- libev=4.33=hd590300_2 +- libevent=2.1.12=hf998b51_1 +- libexpat=2.6.3=h5888daf_0 +- libffi=3.4.2=h7f98852_5 +- libgcc=14.1.0=h77fa898_1 +- libgcc-devel_linux-64=13.3.0=h84ea5a7_101 +- libgcc-ng=14.1.0=h69a702a_1 +- libgfortran=14.1.0=h69a702a_1 +- libgfortran-ng=14.1.0=h69a702a_1 +- libgfortran5=14.1.0=hc5f4f2c_1 +- libgl=1.7.0=ha4b6fd6_0 +- libglib=2.80.3=h315aac3_2 +- libglvnd=1.7.0=ha4b6fd6_0 +- libglx=1.7.0=ha4b6fd6_0 +- libgomp=14.1.0=h77fa898_1 +- 
libgoogle-cloud=2.28.0=h26d7fe4_0 +- libgoogle-cloud-storage=2.28.0=ha262f82_0 +- libgrpc=1.62.2=h15f2491_0 +- libiconv=1.17=hd590300_2 +- libjpeg-turbo=3.0.0=hd590300_1 +- liblapack=3.9.0=23_linux64_openblas +- libllvm18=18.1.8=h8b73ec9_2 +- libnghttp2=1.58.0=h47da74e_1 +- libnsl=2.0.1=hd590300_0 +- libnvjitlink=12.6.68=h5888daf_0 +- libopenblas=0.3.27=pthreads_hac2b453_1 +- libparquet=17.0.0=h39682fd_13_cpu +- libpciaccess=0.18=hd590300_0 +- libpng=1.6.43=h2797004_0 +- libpq=16.4=h2d7952a_1 +- libprotobuf=4.25.3=h08a7969_0 +- libre2-11=2023.09.01=h5a48ba9_2 +- libsanitizer=13.3.0=heb74ff8_1 +- libsodium=1.0.20=h4ab18f5_0 +- libsqlite=3.46.1=hadc24fc_0 +- libssh2=1.11.0=h0841786_0 +- libstdcxx=14.1.0=hc0a3c3a_1 +- libstdcxx-devel_linux-64=13.3.0=h84ea5a7_101 +- libstdcxx-ng=14.1.0=h4852527_1 +- libthrift=0.20.0=h0e7cc3e_1 +- libtiff=4.6.0=h46a8edc_4 +- libutf8proc=2.8.0=h166bdaf_0 +- libuuid=2.38.1=h0b41bf4_0 +- libwebp-base=1.4.0=hd590300_0 +- libxcb=1.16=hb9d3cd8_1 +- libxcrypt=4.4.36=hd590300_1 +- libxkbcommon=1.7.0=h2c5496b_1 +- libxml2=2.12.7=he7c6b58_4 +- libxslt=1.1.39=h76b75d6_0 +- libzlib=1.3.1=h4ab18f5_1 +- lxml=5.3.0=py311hcfaa980_1 +- lz4-c=1.9.4=hcb278e6_0 +- markdown-it-py=3.0.0=pyhd8ed1ab_0 +- markupsafe=2.1.5=py311h9ecbd09_1 +- matplotlib=3.9.2=py311h38be061_0 +- matplotlib-base=3.9.2=py311h74b4f7c_0 +- matplotlib-inline=0.1.7=pyhd8ed1ab_0 +- mdit-py-plugins=0.4.1=pyhd8ed1ab_0 +- mdurl=0.1.2=pyhd8ed1ab_0 +- ml_dtypes=0.4.0=py311h7db5c69_2 +- msgpack-python=1.0.8=py311hd18a35c_1 +- multidict=6.1.0=py311h9ecbd09_0 +- multiprocess=0.70.16=py311h9ecbd09_1 +- munkres=1.1.4=pyh9f0ad1d_0 +- mysql-common=9.0.1=h70512c7_0 +- mysql-libs=9.0.1=ha479ceb_0 +- nbformat=5.10.4=pyhd8ed1ab_0 +- nccl=2.22.3.1=hbc370b7_1 +- ncurses=6.5=he02047a_1 +- nest-asyncio=1.6.0=pyhd8ed1ab_0 +- nltk=3.9.1=pyhd8ed1ab_0 +- numpy=1.26.4=py311h64a7726_0 +- openjpeg=2.5.2=h488ebb8_0 +- openssl=3.3.2=hb9d3cd8_0 +- opt-einsum=3.3.0=hd8ed1ab_2 +- opt_einsum=3.3.0=pyhc1e730c_2 +- optax=0.2.2=pyhd8ed1ab_1 +- orbax-checkpoint=0.4.4=pyhd8ed1ab_0 +- orc=2.0.2=h669347b_0 +- packaging=24.1=pyhd8ed1ab_0 +- pandas=2.2.2=py311h14de704_1 +- parso=0.8.4=pyhd8ed1ab_0 +- pcre2=10.44=hba22ea6_2 +- pexpect=4.9.0=pyhd8ed1ab_0 +- pickleshare=0.7.5=py_1003 +- pillow=10.4.0=py311h82a398c_0 +- pip=24.2=pyh8b19718_1 +- pixman=0.43.2=h59595ed_0 +- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_1 +- platformdirs=4.3.2=pyhd8ed1ab_0 +- portalocker=2.10.1=py311h38be061_0 +- prompt-toolkit=3.0.47=pyha770c72_0 +- protobuf=4.25.3=py311hbffca5d_1 +- psutil=6.0.0=py311h9ecbd09_1 +- pthread-stubs=0.4=h36c2ea0_1001 +- ptyprocess=0.7.0=pyhd3deb0d_0 +- pure_eval=0.2.3=pyhd8ed1ab_0 +- pyarrow=17.0.0=py311hbd00459_1 +- pyarrow-core=17.0.0=py311h4510849_1_cpu +- pyarrow-hotfix=0.6=pyhd8ed1ab_0 +- pybind11-abi=4=hd8ed1ab_3 +- pycparser=2.22=pyhd8ed1ab_0 +- pygments=2.18.0=pyhd8ed1ab_0 +- pyparsing=3.1.4=pyhd8ed1ab_0 +- pyside6=6.7.2=py311hba19f1e_2 +- pysocks=1.7.1=pyha2e5f31_6 +- python=3.11.9=hb806964_0_cpython +- python-dateutil=2.9.0=pyhd8ed1ab_0 +- python-fastjsonschema=2.20.0=pyhd8ed1ab_0 +- python-slugify=8.0.4=pyhd8ed1ab_0 +- python-tzdata=2024.1=pyhd8ed1ab_0 +- python-xxhash=3.5.0=py311h9ecbd09_1 +- python_abi=3.11=5_cp311 +- pytz=2024.1=pyhd8ed1ab_0 +- pyyaml=6.0.2=py311h9ecbd09_1 +- pyzmq=26.2.0=py311h7deb3e3_2 +- qhull=2020.2=h434a139_5 +- qt6-main=6.7.2=hb12f9c5_5 +- rav1e=0.6.6=he8a937b_2 +- re2=2023.09.01=h7f4b329_2 +- readline=8.2=h8228510_1 +- referencing=0.35.1=pyhd8ed1ab_0 +- regex=2024.7.24=py311h9ecbd09_1 +- requests=2.32.3=pyhd8ed1ab_0 
+- responses=0.18.0=pyhd8ed1ab_0 +- rich=13.7.1=pyhd8ed1ab_0 +- rpds-py=0.20.0=py311h9e33e62_1 +- s2n=1.5.2=h7b32b05_0 +- sacrebleu=2.4.1=pyhd8ed1ab_1 +- safetensors=0.4.5=py311h9e33e62_0 +- scipy=1.14.1=py311he1f765f_0 +- setuptools=73.0.1=pyhd8ed1ab_0 +- six=1.16.0=pyh6c4a22f_0 +- snappy=1.2.1=ha2e4443_0 +- stack_data=0.6.2=pyhd8ed1ab_0 +- svt-av1=2.2.1=h5888daf_0 +- sysroot_linux-64=2.17=h4a8ded7_16 +- tabulate=0.9.0=pyhd8ed1ab_1 +- tensorstore=0.1.62=py311he109767_1 +- text-unidecode=1.3=pyhd8ed1ab_1 +- tk=8.6.13=noxft_h4845f30_101 +- tokenizers=0.19.1=py311h6640629_0 +- tomli=2.0.1=pyhd8ed1ab_0 +- toolz=0.12.1=pyhd8ed1ab_0 +- tornado=6.4.1=py311h9ecbd09_1 +- tqdm=4.66.5=pyhd8ed1ab_0 +- traitlets=5.14.3=pyhd8ed1ab_0 +- transformers=4.41.2=pyhd8ed1ab_0 +- types-python-dateutil=2.9.0.20240906=pyhd8ed1ab_0 +- typing=3.10.0.0=pyhd8ed1ab_1 +- typing-extensions=4.12.2=hd8ed1ab_0 +- typing_extensions=4.12.2=pyha770c72_0 +- tzdata=2024a=h8827d51_1 +- urllib3=2.2.2=pyhd8ed1ab_1 +- wayland=1.23.1=h3e06ad9_0 +- wcwidth=0.2.13=pyhd8ed1ab_0 +- wheel=0.44.0=pyhd8ed1ab_0 +- xcb-util=0.4.1=hb711507_2 +- xcb-util-cursor=0.1.4=h4ab18f5_2 +- xcb-util-image=0.4.0=hb711507_2 +- xcb-util-keysyms=0.4.1=hb711507_0 +- xcb-util-renderutil=0.3.10=hb711507_0 +- xcb-util-wm=0.4.2=hb711507_0 +- xkeyboard-config=2.42=h4ab18f5_0 +- xorg-fixesproto=5.0=h7f98852_1002 +- xorg-inputproto=2.3.2=h7f98852_1002 +- xorg-kbproto=1.0.7=h7f98852_1002 +- xorg-libice=1.1.1=hd590300_0 +- xorg-libsm=1.2.4=h7391055_0 +- xorg-libx11=1.8.9=hb711507_1 +- xorg-libxau=1.0.11=hd590300_0 +- xorg-libxdmcp=1.1.3=h7f98852_0 +- xorg-libxext=1.3.4=h0b41bf4_2 +- xorg-libxfixes=5.0.3=h7f98852_1004 +- xorg-libxi=1.7.10=h4bc722e_1 +- xorg-libxrender=0.9.11=hd590300_0 +- xorg-libxtst=1.2.5=h4bc722e_0 +- xorg-libxxf86vm=1.1.5=h4bc722e_1 +- xorg-recordproto=1.14.2=h7f98852_1002 +- xorg-renderproto=0.11.1=h7f98852_1002 +- xorg-xextproto=7.3.0=h0b41bf4_1003 +- xorg-xproto=7.0.31=h7f98852_1007 +- xxhash=0.8.2=hd590300_0 +- xz=5.2.6=h166bdaf_0 +- yaml=0.2.5=h7f98852_2 +- yarl=1.11.1=py311h9ecbd09_0 +- zeromq=4.3.5=ha4adb4c_5 +- zipp=3.20.1=pyhd8ed1ab_0 +- zlib=1.3.1=h4ab18f5_1 +- zstandard=0.23.0=py311hbc35293_1 +- zstd=1.5.6=ha6fb4c9_0 + diff --git a/t5_jax.py b/t5_jax.py new file mode 100644 index 0000000..240b005 --- /dev/null +++ b/t5_jax.py @@ -0,0 +1,620 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.16.4 +# kernelspec: +# display_name: jax +# language: python +# name: python3 +# --- + +# %% [markdown] +# # T5 implementation using jax + +# %% [markdown] +# ## import + +# %% [raw] +# import json +# import logging +# import math +# import os +# import sys +# import time +# from dataclasses import asdict, dataclass, field +# from enum import Enum +# from functools import partial +# from pathlib import Path +# from typing import Callable, Optional +# +# import datasets +# import evaluate +# import jax +# import jax.numpy as jnp +# import nltk # Here to have a nice missing dependency error message early on +# import numpy as np +# import optax +# from datasets import Dataset, load_dataset +# from filelock import FileLock +# from flax import jax_utils, traverse_util +# from flax.jax_utils import pad_shard_unpad, unreplicate +# from flax.training import train_state +# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +# from tqdm import tqdm +# +# import transformers +# from transformers import ( 
+#     CONFIG_MAPPING,
+#     FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+#     AutoConfig,
+#     AutoTokenizer,
+#     FlaxAutoModelForSeq2SeqLM,
+#     HfArgumentParser,
+#     is_tensorboard_available,
+# )
+# from transformers.utils import is_offline_mode, send_example_telemetry
+#
+#
+# logger = logging.getLogger(__name__)
+#
+# try:
+#     nltk.data.find("tokenizers/punkt")
+# except (LookupError, OSError):
+#     if is_offline_mode():
+#         raise LookupError(
+#             "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
+#         )
+#     with FileLock(".lock") as lock:
+#         nltk.download("punkt", quiet=True)
+#
+#
+# MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
+# MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+# %%
+import jax
+import jax.numpy as jnp
+import optax
+import numpy as np
+from functools import partial
+from typing import Callable, Optional
+import math
+
+# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
+jax.config.update("jax_default_matmul_precision", "high")
+
+jax.config.update("jax_enable_x64", False)
+
+
+from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
+from transformers.utils import is_offline_mode
+
+
+import datasets
+from datasets import Dataset, load_dataset
+import evaluate
+from tqdm import tqdm
+from datasets import load_from_disk
+
+
+import nltk  # Here to have a nice missing dependency error message early on
+from filelock import FileLock
+
+from flax import jax_utils, traverse_util
+from flax.jax_utils import pad_shard_unpad, unreplicate
+from flax.training import train_state
+from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+
+
+import time
+
+
+# %%
+import os
+os.environ['XLA_FLAGS'] = (
+    '--xla_gpu_enable_triton_softmax_fusion=True '
+    '--xla_gpu_triton_gemm_any=True '
+)
+
+os.environ.update({
+    "NCCL_LL128_BUFFSIZE": "-2",
+    "NCCL_LL_BUFFSIZE": "-2",
+    "NCCL_PROTO": "SIMPLE,LL,LL128",
+})
+
+# %%
+from jax.lib import xla_bridge
+print(xla_bridge.get_backend().platform)
+
+
+# %%
+# download the nltk punkt tokenizer if it is not already present
+try:
+    nltk.data.find("tokenizers/punkt")
+except (LookupError, OSError):
+    if is_offline_mode():
+        raise LookupError(
+            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
+        )
+    with FileLock(".lock") as lock:
+        nltk.download("punkt", quiet=True)
+
+
+# %% [markdown]
+# ## Prepare datasets
+
+# %%
+# load model
+model_name_or_path = "t5-small"  # Replace with your specific model name
+
+# Load configuration
+config = AutoConfig.from_pretrained(model_name_or_path)
+
+# Load model
+model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
+    model_name_or_path
+)
+
+
+# %%
+model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
+shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
+
+
+# %%
+# Path to saved combined_dataset
+file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
+save_path = 't5_80_1'
+# file_path = 'combined_data'
+split_datasets = load_from_disk(file_path)
+
+# prepare tokenizer
+from transformers import T5TokenizerFast
+tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
+# Define additional special tokens
+additional_special_tokens = ["", "", "", "", "", "", "SIG", "UNIT", "DATA_TYPE"]
+# Add the additional special tokens to the tokenizer
+tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
+
+max_length = 86
+
+# In Flax, seq2seq models do not accept `labels` directly, so we must pass
+# `decoder_input_ids` ourselves; we build them below with the
+# `shift_tokens_right` function imported dynamically from the model file above.
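+# For reference, `shift_tokens_right` prepends the decoder start token and
+# drops the final position. A sketch with T5, whose decoder_start_token_id is
+# the pad id 0 (the token ids here are made up for illustration):
+#   labels:            [[ 37, 423,   1,   0]]
+#   decoder_input_ids: [[  0,  37, 423,   1]]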
+
+
+# given a dataset entry, run it through the tokenizer
+# Setting padding="max_length" as we need fixed length inputs for jitted functions
+def preprocess_function(example):
+    input = example['input']
+    target = example['output']
+    # tokenize the inputs
+    model_inputs = tokenizer(
+        input,
+        max_length=max_length,
+        padding="max_length",
+        truncation=True,
+        return_tensors="np"
+    )
+    # tokenize the targets; passing them via `text_target` switches the
+    # tokenizer into target mode so the label ids are produced correctly
+    labels = tokenizer(
+        text_target=target,
+        max_length=max_length,
+        padding="max_length",
+        truncation=True,
+        return_tensors="np"
+    )
+
+    model_inputs["labels"] = labels["input_ids"]
+    decoder_input_ids = shift_tokens_right_fn(
+        labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
+    )
+    model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
+
+    # We need decoder_attention_mask so we can ignore pad tokens in the loss
+    model_inputs["decoder_attention_mask"] = labels["attention_mask"]
+
+    return model_inputs
+
+# `map` applies the function to each row of every split in the dataset
+tokenized_datasets = split_datasets.map(
+    preprocess_function,
+    batched=True,
+    num_proc=1,
+    remove_columns=split_datasets["train"].column_names,
+)
+
+
+# %%
+tokenized_datasets
+
+# %%
+train_dataset = tokenized_datasets["train"]
+eval_dataset = tokenized_datasets["validation"]
+
+
+# %%
+def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
+    """
+    Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be
+    incomplete and range in size from 1 to `batch_size`. Batches are shuffled if `shuffle` is `True`.
+    """
+    if shuffle:
+        batch_idx = jax.random.permutation(rng, len(dataset))
+        batch_idx = np.asarray(batch_idx)
+    else:
+        batch_idx = np.arange(len(dataset))
+
+    if drop_last:
+        steps_per_epoch = len(dataset) // batch_size
+        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
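+        # Truncating to a multiple of batch_size first makes the reshape on the
+        # next line exact, so every jitted step sees the same fixed batch shape.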
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) + else: + steps_per_epoch = math.ceil(len(dataset) / batch_size) + batch_idx = np.array_split(batch_idx, steps_per_epoch) + + for idx in batch_idx: + batch = dataset[idx] + batch = {k: np.array(v) for k, v in batch.items()} + + yield batch + + + +# %% [markdown] +# Now we have model inputs in terms of the variable tokenized_datasets + +# %% +# metric +metric = evaluate.load("sacrebleu") + +def postprocess_text(preds, labels): + preds = [pred.strip() for pred in preds] + labels = [label.strip() for label in labels] + + # rougeLSum expects newline after each sentence + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] + + return preds, labels + +# def compute_metrics(preds, labels): +# decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) +# decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) +# +# # Some simple post-processing +# decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) +# +# result = metric.compute(predictions=decoded_preds, references=decoded_labels) +# result = {k: round(v * 100, 4) for k, v in result.items()} +# prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] +# result["gen_len"] = np.mean(prediction_lens) +# return result + +def compute_metrics(preds, labels): + # In case the model returns more than the prediction logits + if isinstance(preds, tuple): + preds = preds[0] + + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + + # Replace -100s in the labels as we can't decode them + labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + + # Some simple post-processing + decoded_preds = [pred.strip() for pred in decoded_preds] + decoded_labels = [[label.strip()] for label in decoded_labels] + + result = metric.compute(predictions=decoded_preds, references=decoded_labels) + return {"bleu": result["score"]} + + + +# %% [markdown] +# # Model + +# %% +# Store some constant +seed = 117 +num_epochs = 80 +batch_size = 96 +num_train_epochs = num_epochs +per_device_train_batch_size = batch_size +train_batch_size = per_device_train_batch_size * jax.device_count() +per_device_eval_batch_size = batch_size +eval_batch_size = per_device_eval_batch_size * jax.device_count() +steps_per_epoch = len(train_dataset) // train_batch_size +total_train_steps = steps_per_epoch * num_epochs + +warmup_steps = 0 +learning_rate = 5e-5 + +weight_decay = 0.0 +adam_beta1 = 0.9 +adam_beta2 = 0.999 +adam_epsilon = 1e-8 +label_smoothing_factor = 0.0 + +num_beams = 1 +val_max_target_length = None + +predict_with_generate = True + + +# %% + +# Initialize our training +rng = jax.random.PRNGKey(seed) +rng, dropout_rng = jax.random.split(rng) + + +# %% +# optimization functions + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.ndarray]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - 
num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +# Create learning rate schedule +linear_decay_lr_schedule_fn = create_learning_rate_fn( + len(train_dataset), + train_batch_size, + num_train_epochs, + warmup_steps, + learning_rate, +) + +# We use Optax's "masking" functionality to not apply weight decay +# to bias and LayerNorm scale parameters. decay_mask_fn returns a +# mask boolean with the same structure as the parameters. +# The mask is True for parameters that should be decayed. +def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = { + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + } + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + +# create adam optimizer +adamw = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=adam_beta1, + b2=adam_beta2, + eps=adam_epsilon, + weight_decay=weight_decay, + mask=decay_mask_fn, +) + + +# %% +# Training functions +class TrainState(train_state.TrainState): + dropout_rng: jnp.ndarray + + def replicate(self): + return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) + +# Ensure model.params is properly initialized (this is just an example) +# Normally you would get this from a model initialization call with dummy input +params = model.params +# Cast parameters to bfloat16 if desired +params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) + + +# Setup train state +state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng) + +# label smoothed cross entropy +def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): + """ + The label smoothing implementation is adapted from Flax's official example: + https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 + """ + vocab_size = logits.shape[-1] + confidence = 1.0 - label_smoothing_factor + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + ) + soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) + + loss = optax.softmax_cross_entropy(logits, soft_labels) + loss = loss - normalizing_constant + + # ignore padded tokens from loss + loss = loss * padding_mask + loss = loss.sum() + num_labels = padding_mask.sum() + return loss, num_labels + +# Define gradient update step fn +def train_step(state, batch, label_smoothing_factor=0.0): + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def compute_loss(params): + labels = batch.pop("labels") + logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] + loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) + return loss, num_labels + + # compute gradients through computational graph + grad_fn = jax.value_and_grad(compute_loss, has_aux=True) + (loss, num_labels), grad = grad_fn(state.params) + num_labels = jax.lax.psum(num_labels, "batch") + + # true loss = total loss / total 
samples + loss = jax.lax.psum(loss, "batch") + loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) + + # true grad = total grad / total samples + grad = jax.lax.psum(grad, "batch") + grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) + new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + return new_state, metrics + +# Define eval fn +def eval_step(params, batch, label_smoothing_factor=0.0): + labels = batch.pop("labels") + logits = model(**batch, params=params, train=False)[0] + + loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) + num_labels = jax.lax.psum(num_labels, "batch") + + # true loss = total loss / total samples + loss = jax.lax.psum(loss, "batch") + loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) + + metrics = {"loss": loss} + return metrics + +# Define generation function +max_length = ( + val_max_target_length if val_max_target_length is not None else model.config.max_length +) +num_beams = num_beams if num_beams is not None else model.config.num_beams +gen_kwargs = {"max_length": max_length, "num_beams": num_beams} + +def generate_step(params, batch): + model.params = params + output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs) + return output_ids.sequences + +# Create parallel version of the train and eval step +p_train_step = jax.pmap( + partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,) +) +p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch") +p_generate_step = jax.pmap(generate_step, "batch") + +# Replicate the train state on each device +state = state.replicate() + + + +# %% + + +print("***** Running training *****") +print(f" Num examples = {len(train_dataset)}") +print(f" Num Epochs = {num_epochs}") +print(f" Instantaneous batch size per device = {per_device_train_batch_size}") +print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") +print(f" Total optimization steps = {total_train_steps}") + + +# %% + +train_time = 0 +epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) +# epochs = range(num_epochs) +for epoch in epochs: + # ======================== Training ================================ + train_start = time.time() + + # Create sampling rng + rng, input_rng = jax.random.split(rng) + train_metrics = [] + + # Generate an epoch by shuffling sampling indices from the train dataset + train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True) + steps_per_epoch = len(train_dataset) // train_batch_size + # train + for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): + batch = next(train_loader) + batch = shard(batch) + state, train_metric = p_train_step(state, batch) + train_metrics.append(train_metric) + + train_time += time.time() - train_start + + train_metric = unreplicate(train_metric) + + epochs.write( + f"Epoch... 
({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:" + f" {train_metric['learning_rate']})" + ) + + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False) + eval_steps = math.ceil(len(eval_dataset) / eval_batch_size) + for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): + # Model forward + batch = next(eval_loader) + labels = batch["labels"] + + metrics = pad_shard_unpad(p_eval_step, static_return=True)( + state.params, batch, min_device_batch=per_device_eval_batch_size + ) + eval_metrics.append(metrics) + + # generation + if predict_with_generate: + generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch) + eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + + # compute metrics + rouge_desc = "" + if predict_with_generate: + rouge_metrics = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(rouge_metrics) + rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()]) + + # Print metrics and update progress bar + desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + # if has_tensorboard and jax.process_index() == 0: + # cur_step = epoch * (len(train_dataset) // train_batch_size) + # write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) + + output_dir = save_path + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(output_dir, params=params) + tokenizer.save_pretrained(output_dir) + + + +# %% [markdown] +# # diff --git a/t5_jax_prediction.py b/t5_jax_prediction.py new file mode 100644 index 0000000..217372b --- /dev/null +++ b/t5_jax_prediction.py @@ -0,0 +1,386 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.16.4 +# --- + +# %% [markdown] +# # prediction code +# ## import and process test data + + +# %% +# import libraries +import pandas as pd +import matplotlib.pyplot as plt + +from datasets import Dataset, DatasetDict + +import jax +import jax.numpy as jnp +import optax +import numpy as np +from functools import partial +from typing import Callable, Optional +import math + +# jax.config.update("jax_default_matmul_precision", "tensorfloat32") +jax.config.update("jax_default_matmul_precision", "high") + +jax.config.update("jax_enable_x64", False) + + +from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig + + +import datasets +from datasets import Dataset, load_dataset +import evaluate +from tqdm import tqdm +from datasets import load_from_disk + + +import nltk # Here to have a nice missing dependency error message early on + +from flax import jax_utils, traverse_util +from flax.jax_utils import pad_shard_unpad, unreplicate +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key + + +import time + + +# %% + +# data_path = 
f"../make_data/select_db/data_mapping_filtered.csv" +# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv" +data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv' +# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv' + +# Ensure to include 'ships_idx' in the fields list +fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit'] + +# Load the dataset +df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields) + +def process_df(df): + output_list = [{ + 'input': f"{row['tag_name']}{row['tag_description']}", + # 'input': f"{row['tag_description']}", + # 'input': f"{row['tag_name']}{row['tag_description']}{row['unit']}", + # 'input': f"{row['tag_description']}{row['unit']}", + 'output': f"{row['thing']}{row['property']}", + 'answer': f"{row['thing']} {row['property']}", + 'answer_thing': row['thing'], + 'answer_property': row['property'], + } for _, row in df.iterrows()] + + return output_list + + +# takes 1 minute to run without batching +test_dataset = Dataset.from_list(process_df(df)) + + +# %% [markdown] +# ## Load model for attributes + +# %% +# load model +model_name_or_path = "t5_80_1" # Replace with your specific model name + +# Load configuration +config = AutoConfig.from_pretrained(model_name_or_path) + +# Load model +model = FlaxAutoModelForSeq2SeqLM.from_pretrained( + model_name_or_path +) + + +# %% [markdown] +# ## Tokenizer + +# %% +# prepare tokenizer +from transformers import T5TokenizerFast +tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True) +# Define additional special tokens +additional_special_tokens = ["", "", "", "", "", "", "SIG", "UNIT", "DATA_TYPE"] +# Add the additional special tokens to the tokenizer +tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) + +max_length = 86 + +model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"]) +shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") + +# given a dataset entry, run it through the tokenizer +# Setting padding="max_length" as we need fixed length inputs for jitted functions +def preprocess_function(example): + input = example['input'] + target = example['output'] + # text_target sets the corresponding label to inputs + # there is no need to create a separate 'labels' + model_inputs = tokenizer( + input, + text_target=target, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="np" + ) + labels = tokenizer( + input, + text_target=target, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="np" + ) + + model_inputs["labels"] = labels["input_ids"] + decoder_input_ids = shift_tokens_right_fn( + labels["input_ids"], config.pad_token_id, config.decoder_start_token_id + ) + model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids) + + # We need decoder_attention_mask so we can ignore pad tokens from loss + model_inputs["decoder_attention_mask"] = labels["attention_mask"] + + return model_inputs + +# map maps function to each "row" in the dataset +# aka the data in the immediate nesting +test_dataset = test_dataset.map( + preprocess_function, + batched=True, + num_proc=1, + remove_columns=test_dataset.column_names, +) + +def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True): + """ + Returns batches of size `batch_size` 
from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete, + and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`. + """ + if shuffle: + batch_idx = jax.random.permutation(rng, len(dataset)) + batch_idx = np.asarray(batch_idx) + else: + batch_idx = np.arange(len(dataset)) + + if drop_last: + steps_per_epoch = len(dataset) // batch_size + batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. + batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) + else: + steps_per_epoch = math.ceil(len(dataset) / batch_size) + batch_idx = np.array_split(batch_idx, steps_per_epoch) + + for idx in batch_idx: + batch = dataset[idx] + batch = {k: np.array(v) for k, v in batch.items()} + + yield batch + +# %% [markdown] +# # Model Training + +# %% +seed = 117 +num_epochs = 80 +batch_size = 96 +num_train_epochs = num_epochs +per_device_train_batch_size = batch_size +train_batch_size = per_device_train_batch_size * jax.device_count() +per_device_eval_batch_size = batch_size +eval_batch_size = per_device_eval_batch_size * jax.device_count() +steps_per_epoch = len(test_dataset) // train_batch_size +total_train_steps = steps_per_epoch * num_epochs + +warmup_steps = 0 +learning_rate = 5e-5 + +weight_decay = 0.0 +adam_beta1 = 0.9 +adam_beta2 = 0.999 +adam_epsilon = 1e-8 +label_smoothing_factor = 0.0 + +num_beams = 1 +val_max_target_length = None + +predict_with_generate = True + + +# Initialize our training +rng = jax.random.PRNGKey(seed) +rng, dropout_rng = jax.random.split(rng) + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.ndarray]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +# Create learning rate schedule +linear_decay_lr_schedule_fn = create_learning_rate_fn( + len(test_dataset), + train_batch_size, + num_train_epochs, + warmup_steps, + learning_rate, +) + +# We use Optax's "masking" functionality to not apply weight decay +# to bias and LayerNorm scale parameters. decay_mask_fn returns a +# mask boolean with the same structure as the parameters. +# The mask is True for parameters that should be decayed. 
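+# For example (illustrative parameter paths, not exact T5 names): a flattened
+# path ending in ('layer_norm', 'weight') maps to False (not decayed), while
+# one ending in ('SelfAttention', 'q', 'kernel') maps to True (decayed).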
+def decay_mask_fn(params):
+    flat_params = traverse_util.flatten_dict(params)
+    # find out all LayerNorm parameters
+    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+    layer_norm_named_params = {
+        layer[-2:]
+        for layer_norm_name in layer_norm_candidates
+        for layer in flat_params.keys()
+        if layer_norm_name in "".join(layer).lower()
+    }
+    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
+    return traverse_util.unflatten_dict(flat_mask)
+
+# create adam optimizer
+adamw = optax.adamw(
+    learning_rate=linear_decay_lr_schedule_fn,
+    b1=adam_beta1,
+    b2=adam_beta2,
+    eps=adam_epsilon,
+    weight_decay=weight_decay,
+    mask=decay_mask_fn,
+)
+
+# %%
+
+# reload model to prevent leakage of variables
+# load model
+model_name_or_path = "t5_80_1"  # Replace with your specific model name
+
+# Load configuration
+config = AutoConfig.from_pretrained(model_name_or_path)
+
+# Load model
+model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
+    model_name_or_path
+)
+
+# Training functions
+class TrainState(train_state.TrainState):
+    dropout_rng: jnp.ndarray
+
+    def replicate(self):
+        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
+
+# Ensure model.params is properly initialized (this is just an example)
+# Normally you would get this from a model initialization call with dummy input
+params = model.params
+# Cast parameters to bfloat16 if desired
+params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
+
+
+# Setup train state
+state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng)
+
+
+# Define generation function
+max_length = (
+    val_max_target_length if val_max_target_length is not None else model.config.max_length
+)
+num_beams = num_beams if num_beams is not None else model.config.num_beams
+gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+
+def generate_step(params, batch):
+    model.params = params
+    output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
+    return output_ids.sequences
+
+# Create parallel version of the generate step
+p_generate_step = jax.pmap(generate_step, "batch")
+
+# Replicate the train state on each device
+state = state.replicate()
+
+
+pred_metrics = []
+pred_generations = []
+pred_labels = []
+
+rng, input_rng = jax.random.split(rng)
+
+pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False)
+pred_steps = math.ceil(len(test_dataset) / eval_batch_size)
+
+print("***** Running prediction *****")
+print(f"  Num examples = {len(test_dataset)}")
+print(f"  Num steps = {pred_steps}")
+print(f"  Instantaneous batch size per device = {per_device_eval_batch_size}")
+print(f"  Total test batch size (w. parallel & distributed) = {eval_batch_size}")
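+# `pad_shard_unpad` (used below) pads a final, smaller-than-usual batch up to
+# a shardable size before the pmapped call and strips the padding from the
+# outputs again, which is why `drop_last=False` is safe here.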
parallel & distributed) = {train_batch_size}") + + +for _ in tqdm(range(pred_steps), desc="Predicting...", position=0, leave=False): + # Model forward + batch = next(pred_loader) + labels = batch["labels"] + + # generation + if predict_with_generate: + generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch) + pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) + pred_labels.extend(labels) + + + +# Print metrics +# desc = f"Predict Loss: {pred_metrics['loss']})" +# print(desc) + +# %% +# save predictions to parquet + +# decode prediction labels +def decode_preds(preds): + # In case the model returns more than the prediction logits + if isinstance(preds, tuple): + preds = preds[0] + + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + + decoded_preds = [pred for pred in decoded_preds] + + return decoded_preds + + +# Convert the list to a Pandas DataFrame +df = pd.DataFrame(decode_preds(pred_labels)) + +# Save the DataFrame as a Parquet file (using pyarrow or fastparquet) +df.to_parquet("exports/output_file.parquet", engine="pyarrow") # or use engine="fastparquet" + + + +# %% diff --git a/t5_jax_retrieval.py b/t5_jax_retrieval.py new file mode 100644 index 0000000..36b720c --- /dev/null +++ b/t5_jax_retrieval.py @@ -0,0 +1,624 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.16.4 +# kernelspec: +# display_name: jax +# language: python +# name: python3 +# --- + +# %% [markdown] +# # T5 implementation using jax + +# %% [markdown] +# ## import + +# %% [raw] +# import json +# import logging +# import math +# import os +# import sys +# import time +# from dataclasses import asdict, dataclass, field +# from enum import Enum +# from functools import partial +# from pathlib import Path +# from typing import Callable, Optional +# +# import datasets +# import evaluate +# import jax +# import jax.numpy as jnp +# import nltk # Here to have a nice missing dependency error message early on +# import numpy as np +# import optax +# from datasets import Dataset, load_dataset +# from filelock import FileLock +# from flax import jax_utils, traverse_util +# from flax.jax_utils import pad_shard_unpad, unreplicate +# from flax.training import train_state +# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +# from tqdm import tqdm +# +# import transformers +# from transformers import ( +# CONFIG_MAPPING, +# FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, +# AutoConfig, +# AutoTokenizer, +# FlaxAutoModelForSeq2SeqLM, +# HfArgumentParser, +# is_tensorboard_available, +# ) +# from transformers.utils import is_offline_mode, send_example_telemetry +# +# +# logger = logging.getLogger(__name__) +# +# try: +# nltk.data.find("tokenizers/punkt") +# except (LookupError, OSError): +# if is_offline_mode(): +# raise LookupError( +# "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" +# ) +# with FileLock(".lock") as lock: +# nltk.download("punkt", quiet=True) +# +# +# MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys()) +# MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +# %% +import jax +import jax.numpy as jnp +import optax +import numpy as np +from functools import partial +from typing import Callable, Optional +import math + +# jax.config.update("jax_default_matmul_precision", "tensorfloat32") 
+jax.config.update("jax_default_matmul_precision", "high")
+
+jax.config.update("jax_enable_x64", False)
+
+
+from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
+from transformers.utils import is_offline_mode
+
+
+import datasets
+from datasets import Dataset, load_dataset
+import evaluate
+
+
+import nltk  # Here to have a nice missing dependency error message early on
+from filelock import FileLock
+
+from flax import jax_utils, traverse_util
+from flax.jax_utils import pad_shard_unpad, unreplicate
+from flax.training import train_state
+from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+
+
+import time
+
+
+# %%
+import os
+os.environ['XLA_FLAGS'] = (
+    '--xla_gpu_enable_triton_softmax_fusion=True '
+    '--xla_gpu_triton_gemm_any=True '
+)
+
+os.environ.update({
+    "CUDA_VISIBLE_DEVICES": "0, 1, 2, 3",
+    "NCCL_LL128_BUFFSIZE": "-2",
+    "NCCL_LL_BUFFSIZE": "-2",
+    "NCCL_PROTO": "SIMPLE,LL,LL128",
+})
+
+# %%
+from jax.lib import xla_bridge
+print(xla_bridge.get_backend().platform)
+
+
+# %%
+# download the nltk punkt tokenizer if it is not already present
+try:
+    nltk.data.find("tokenizers/punkt")
+except (LookupError, OSError):
+    if is_offline_mode():
+        raise LookupError(
+            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
+        )
+    with FileLock(".lock") as lock:
+        nltk.download("punkt", quiet=True)
+
+
+# %% [markdown]
+# ## Prepare datasets
+
+# %%
+# load model
+model_name_or_path = "t5-small"  # Replace with your specific model name
+
+# Load configuration
+config = AutoConfig.from_pretrained(model_name_or_path)
+
+# Load model
+model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
+    model_name_or_path
+)
+
+
+# %%
+model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
+shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
+
+
+# %%
+from tqdm import tqdm
+from datasets import load_from_disk
+# Path to saved combined_dataset
+file_path = '/home/richard/Projects/learn_t5/retrieval/combined_data_t5'
+save_path = 't5_80_1_retrieval'
+# file_path = 'combined_data'
+split_datasets = load_from_disk(file_path)
+
+# prepare tokenizer
+from transformers import T5TokenizerFast
+tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
+# additional_special_tokens = ["", "", "", "", "", "", "SIG", "UNIT", "DATA_TYPE"]
+# Define additional special tokens
+additional_special_tokens = ["", "", "", "", "", "",
+                             "", "", "", ""]
+# Add the additional special tokens to the tokenizer
+tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
+
+max_length = 300
+
+# In Flax, seq2seq models do not accept `labels` directly, so we must pass
+# `decoder_input_ids` ourselves; we build them below with the
+# `shift_tokens_right` function imported dynamically from the model file.
+
+
+# given a dataset entry, run it through the tokenizer
+# Setting padding="max_length" as we need fixed length inputs for jitted functions
+def preprocess_function(example):
+    input = example['input']
+    target = example['output']
+    # tokenize the inputs
+    model_inputs = tokenizer(
+        input,
+        max_length=max_length,
+        padding="max_length",
+        truncation=True,
+        return_tensors="np"
+    )
+    # tokenize the targets; passing them via `text_target` switches the
+    # tokenizer into target mode so the label ids are produced correctly
+    labels = tokenizer(
+        text_target=target,
+        max_length=max_length,
+        padding="max_length",
+        truncation=True,
+        return_tensors="np"
+    )
+
+    model_inputs["labels"] = labels["input_ids"]
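+    # Build the decoder inputs by right-shifting the label ids: the decoder
+    # is conditioned on the previous target token at each position.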
decoder_input_ids = shift_tokens_right_fn( + labels["input_ids"], config.pad_token_id, config.decoder_start_token_id + ) + model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids) + + # We need decoder_attention_mask so we can ignore pad tokens from loss + model_inputs["decoder_attention_mask"] = labels["attention_mask"] + + return model_inputs + +# map maps function to each "row" in the dataset +# aka the data in the immediate nesting +tokenized_datasets = split_datasets.map( + preprocess_function, + batched=True, + num_proc=1, + remove_columns=split_datasets["train"].column_names, +) + + + + +# %% +tokenized_datasets + +# %% +train_dataset = tokenized_datasets["train"] +eval_dataset = tokenized_datasets["validation"] + + +# %% +def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True): + """ + Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete, + and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`. + """ + if shuffle: + batch_idx = jax.random.permutation(rng, len(dataset)) + batch_idx = np.asarray(batch_idx) + else: + batch_idx = np.arange(len(dataset)) + + if drop_last: + steps_per_epoch = len(dataset) // batch_size + batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. + batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) + else: + steps_per_epoch = math.ceil(len(dataset) / batch_size) + batch_idx = np.array_split(batch_idx, steps_per_epoch) + + for idx in batch_idx: + batch = dataset[idx] + batch = {k: np.array(v) for k, v in batch.items()} + + yield batch + + + +# %% [markdown] +# Now we have model inputs in terms of the variable tokenized_datasets + +# %% +# metric +metric = evaluate.load("sacrebleu") + +def postprocess_text(preds, labels): + preds = [pred.strip() for pred in preds] + labels = [label.strip() for label in labels] + + # rougeLSum expects newline after each sentence + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] + + return preds, labels + +# def compute_metrics(preds, labels): +# decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) +# decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) +# +# # Some simple post-processing +# decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) +# +# result = metric.compute(predictions=decoded_preds, references=decoded_labels) +# result = {k: round(v * 100, 4) for k, v in result.items()} +# prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] +# result["gen_len"] = np.mean(prediction_lens) +# return result + +def compute_metrics(preds, labels): + # In case the model returns more than the prediction logits + if isinstance(preds, tuple): + preds = preds[0] + + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + + # Replace -100s in the labels as we can't decode them + labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + + # Some simple post-processing + decoded_preds = [pred.strip() for pred in decoded_preds] + decoded_labels = [[label.strip()] for label in decoded_labels] + + result = metric.compute(predictions=decoded_preds, references=decoded_labels) + return {"bleu": result["score"]} + + + +# %% [markdown] +# # Model + +# %% +# Store some 
constant +seed = 117 +num_epochs = 80 +batch_size = 36 +num_train_epochs = num_epochs +per_device_train_batch_size = batch_size +train_batch_size = per_device_train_batch_size * jax.device_count() +per_device_eval_batch_size = batch_size +eval_batch_size = per_device_eval_batch_size * jax.device_count() +steps_per_epoch = len(train_dataset) // train_batch_size +total_train_steps = steps_per_epoch * num_epochs + +warmup_steps = 0 +learning_rate = 5e-5 + +weight_decay = 0.0 +adam_beta1 = 0.9 +adam_beta2 = 0.999 +adam_epsilon = 1e-8 +label_smoothing_factor = 0.0 + +num_beams = 1 +val_max_target_length = None + +predict_with_generate = True + + +# %% + +# Initialize our training +rng = jax.random.PRNGKey(seed) +rng, dropout_rng = jax.random.split(rng) + + +# %% +# optimization functions + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.ndarray]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +# Create learning rate schedule +linear_decay_lr_schedule_fn = create_learning_rate_fn( + len(train_dataset), + train_batch_size, + num_train_epochs, + warmup_steps, + learning_rate, +) + +# We use Optax's "masking" functionality to not apply weight decay +# to bias and LayerNorm scale parameters. decay_mask_fn returns a +# mask boolean with the same structure as the parameters. +# The mask is True for parameters that should be decayed. 
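+# Note: with `optax.adamw(..., mask=decay_mask_fn)` the mask gates only the
+# weight-decay term; parameters whose mask entry is False still receive the
+# regular Adam update.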
+def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = { + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + } + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + +# create adam optimizer +adamw = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=adam_beta1, + b2=adam_beta2, + eps=adam_epsilon, + weight_decay=weight_decay, + mask=decay_mask_fn, +) + + +# %% +# Training functions +class TrainState(train_state.TrainState): + dropout_rng: jnp.ndarray + + def replicate(self): + return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) + +# Ensure model.params is properly initialized (this is just an example) +# Normally you would get this from a model initialization call with dummy input +params = model.params +# Cast parameters to bfloat16 if desired +params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) + + +# Setup train state +state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng) + +# label smoothed cross entropy +def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): + """ + The label smoothing implementation is adapted from Flax's official example: + https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 + """ + vocab_size = logits.shape[-1] + confidence = 1.0 - label_smoothing_factor + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + ) + soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) + + loss = optax.softmax_cross_entropy(logits, soft_labels) + loss = loss - normalizing_constant + + # ignore padded tokens from loss + loss = loss * padding_mask + loss = loss.sum() + num_labels = padding_mask.sum() + return loss, num_labels + +# Define gradient update step fn +def train_step(state, batch, label_smoothing_factor=0.0): + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def compute_loss(params): + labels = batch.pop("labels") + logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] + loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) + return loss, num_labels + + # compute gradients through computational graph + grad_fn = jax.value_and_grad(compute_loss, has_aux=True) + (loss, num_labels), grad = grad_fn(state.params) + num_labels = jax.lax.psum(num_labels, "batch") + + # true loss = total loss / total samples + loss = jax.lax.psum(loss, "batch") + loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) + + # true grad = total grad / total samples + grad = jax.lax.psum(grad, "batch") + grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) + new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + return new_state, metrics + +# Define eval fn +def eval_step(params, batch, label_smoothing_factor=0.0): + labels = batch.pop("labels") + logits = model(**batch, 
params=params, train=False)[0] + + loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) + num_labels = jax.lax.psum(num_labels, "batch") + + # true loss = total loss / total samples + loss = jax.lax.psum(loss, "batch") + loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) + + metrics = {"loss": loss} + return metrics + +# Define generation function +max_length = ( + val_max_target_length if val_max_target_length is not None else model.config.max_length +) +num_beams = num_beams if num_beams is not None else model.config.num_beams +gen_kwargs = {"max_length": max_length, "num_beams": num_beams} + +def generate_step(params, batch): + model.params = params + output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs) + return output_ids.sequences + +# Create parallel version of the train and eval step +p_train_step = jax.pmap( + partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,) +) +p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch") +p_generate_step = jax.pmap(generate_step, "batch") + +# Replicate the train state on each device +state = state.replicate() + + + +# %% + + +print("***** Running training *****") +print(f" Num examples = {len(train_dataset)}") +print(f" Num Epochs = {num_epochs}") +print(f" Instantaneous batch size per device = {per_device_train_batch_size}") +print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") +print(f" Total optimization steps = {total_train_steps}") + + +# %% + +train_time = 0 +epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) +# epochs = range(num_epochs) +for epoch in epochs: + # ======================== Training ================================ + train_start = time.time() + + # Create sampling rng + rng, input_rng = jax.random.split(rng) + train_metrics = [] + + # Generate an epoch by shuffling sampling indices from the train dataset + train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True) + steps_per_epoch = len(train_dataset) // train_batch_size + # train + for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): + batch = next(train_loader) + batch = shard(batch) + state, train_metric = p_train_step(state, batch) + train_metrics.append(train_metric) + + train_time += time.time() - train_start + + train_metric = unreplicate(train_metric) + + epochs.write( + f"Epoch... 
({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:" + f" {train_metric['learning_rate']})" + ) + + # ======================== Evaluating ============================== + # eval_metrics = [] + # eval_preds = [] + # eval_labels = [] + + # eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False) + # eval_steps = math.ceil(len(eval_dataset) / eval_batch_size) + # for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): + # # Model forward + # batch = next(eval_loader) + # labels = batch["labels"] + + # metrics = pad_shard_unpad(p_eval_step, static_return=True)( + # state.params, batch, min_device_batch=per_device_eval_batch_size + # ) + # eval_metrics.append(metrics) + + # # generation + # if predict_with_generate: + # generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch) + # eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) + # eval_labels.extend(labels) + + # # normalize eval metrics + # eval_metrics = get_metrics(eval_metrics) + # eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + + # compute metrics + # rouge_desc = "" + # if predict_with_generate: + # rouge_metrics = compute_metrics(eval_preds, eval_labels) + # eval_metrics.update(rouge_metrics) + # rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()]) + + # # Print metrics and update progress bar + # desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})" + # epochs.write(desc) + # epochs.desc = desc + + # Save metrics + # if has_tensorboard and jax.process_index() == 0: + # cur_step = epoch * (len(train_dataset) // train_batch_size) + # write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) + + output_dir = save_path + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(output_dir, params=params) + tokenizer.save_pretrained(output_dir) + + + +# %% [markdown] +# # diff --git a/t5_summarizer_flax.py b/t5_summarizer_flax.py new file mode 100644 index 0000000..35dfccf --- /dev/null +++ b/t5_summarizer_flax.py @@ -0,0 +1,1021 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for summarization. +""" +# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 
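+#
+# A rough usage sketch (every path and value below is a placeholder, not a
+# default of this script):
+#
+#   python t5_summarizer_flax.py \
+#       --model_name_or_path t5-small \
+#       --do_train --do_eval \
+#       --train_file data/train.json \
+#       --validation_file data/valid.json \
+#       --text_column input --summary_column output \
+#       --source_prefix "summarize: " \
+#       --output_dir exports/t5_summarizer \
+#       --predict_with_generate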
+ +import json +import logging +import math +import os +import sys +import time +from dataclasses import asdict, dataclass, field +from enum import Enum +from functools import partial +from pathlib import Path +from typing import Callable, Optional + +import datasets +import evaluate +import jax +import jax.numpy as jnp +import nltk # Here to have a nice missing dependency error message early on +import numpy as np +import optax +from datasets import Dataset, load_dataset +from filelock import FileLock +from flax import jax_utils, traverse_util +from flax.jax_utils import pad_shard_unpad, unreplicate +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +from huggingface_hub import HfApi +from tqdm import tqdm + +import transformers +from transformers import ( + CONFIG_MAPPING, + FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + AutoConfig, + AutoTokenizer, + FlaxAutoModelForSeq2SeqLM, + HfArgumentParser, + is_tensorboard_available, +) +from transformers.utils import is_offline_mode, send_example_telemetry + + +logger = logging.getLogger(__name__) + +try: + nltk.data.find("tokenizers/punkt") +except (LookupError, OSError): + if is_offline_mode(): + raise LookupError( + "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" + ) + with FileLock(".lock") as lock: + nltk.download("punkt", quiet=True) + + +MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +@dataclass +class TrainingArguments: + output_dir: str = field( + metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, + ) + overwrite_output_dir: bool = field( + default=False, + metadata={ + "help": ( + "Overwrite the content of the output directory. " + "Use this to continue training if output_dir points to a checkpoint directory." 
+            )
+        },
+    )
+    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
+    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
+    do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
+    per_device_train_batch_size: int = field(
+        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
+    )
+    per_device_eval_batch_size: int = field(
+        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
+    )
+    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
+    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
+    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
+    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
+    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
+    label_smoothing_factor: float = field(
+        default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
+    )
+    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
+    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
+    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
+    logging_steps: int = field(default=500, metadata={"help": "Log every X update steps."})
+    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X update steps."})
+    eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
+    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
+    push_to_hub: bool = field(
+        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
+    )
+    hub_model_id: str = field(
+        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
+    )
+    hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+    gradient_checkpointing: bool = field(
+        default=False,
+        metadata={
+            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
+        },
+    )
+
+    def __post_init__(self):
+        if self.output_dir is not None:
+            self.output_dir = os.path.expanduser(self.output_dir)
+
+    def to_dict(self):
+        """
+        Serializes this instance while replacing `Enum` fields with their values (for JSON serialization support).
+        It obfuscates the token values by removing their value.
+        """
+        d = asdict(self)
+        for k, v in d.items():
+            if isinstance(v, Enum):
+                d[k] = v.value
+            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
+                d[k] = [x.value for x in v]
+            if k.endswith("_token"):
+                d[k] = f"<{k.upper()}>"
+        return d
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+    """
+
+    model_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
+ ) + }, + ) + model_type: Optional[str] = field( + default=None, + metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + dtype: Optional[str] = field( + default="float32", + metadata={ + "help": ( + "Floating-point format in which the model weights should be initialized and trained. Choose one of" + " `[float32, float16, bfloat16]`." + ) + }, + ) + token: str = field( + default=None, + metadata={ + "help": ( + "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token " + "generated when running `huggingface-cli login` (stored in `~/.huggingface`)." + ) + }, + ) + trust_remote_code: bool = field( + default=False, + metadata={ + "help": ( + "Whether to trust the execution of code from datasets/models defined on the Hub." + " This option should only be set to `True` for repositories you trust and in which you have read the" + " code, as it will execute code present on the Hub on your local machine." + ) + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + text_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, + ) + summary_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + ) + test_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input predict data file to do prediction on (a text file)."}, + ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": ( + "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + ) + }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": ( + "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + ) + }, + ) + val_max_target_length: Optional[int] = field( + default=None, + metadata={ + "help": ( + "The maximum total sequence length for validation target text after tokenization. 
Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. " + "This argument is also used to override the `max_length` param of `model.generate`, which is used " + "during evaluation." + ) + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ) + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + ) + }, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + source_prefix: Optional[str] = field( + default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} + ) + predict_with_generate: bool = field( + default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} + ) + num_beams: Optional[int] = field( + default=1, + metadata={ + "help": ( + "Number of beams to use for evaluation. This argument will be passed to `model.generate`, " + "which is used during evaluation." + ) + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + + def __post_init__(self): + if ( + self.dataset_name is None + and self.train_file is None + and self.validation_file is None + and self.test_file is None + ): + raise ValueError("Need either a dataset name or a training, validation, or test file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + if self.test_file is not None: + extension = self.test_file.split(".")[-1] + assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." + if self.val_max_target_length is None: + self.val_max_target_length = self.max_target_length + + +summarization_name_mapping = { + "amazon_reviews_multi": ("review_body", "review_title"), + "big_patent": ("description", "abstract"), + "cnn_dailymail": ("article", "highlights"), + "orange_sum": ("text", "summary"), + "pn_summary": ("article", "summary"), + "psc": ("extract_text", "summary_text"), + "samsum": ("dialogue", "summary"), + "thaisum": ("body", "summary"), + "xglue": ("news_body", "news_title"), + "xsum": ("document", "summary"), + "wiki_summary": ("article", "highlights"), +} + + +class TrainState(train_state.TrainState): + dropout_rng: jnp.ndarray + + def replicate(self): + return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) + + +def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True): + """ + Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete, + and range in size from 1 to `batch_size`. 
Shuffle batches if `shuffle` is `True`. + """ + if shuffle: + batch_idx = jax.random.permutation(rng, len(dataset)) + batch_idx = np.asarray(batch_idx) + else: + batch_idx = np.arange(len(dataset)) + + if drop_last: + steps_per_epoch = len(dataset) // batch_size + batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. + batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) + else: + steps_per_epoch = math.ceil(len(dataset) / batch_size) + batch_idx = np.array_split(batch_idx, steps_per_epoch) + + for idx in batch_idx: + batch = dataset[idx] + batch = {k: np.array(v) for k, v in batch.items()} + + yield batch + + +def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.ndarray]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_summarization", model_args, data_args, framework="flax") + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + # Setup logging, we only want one process per machine to log things on the screen. 
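+    # (jax.process_index() is 0 only on the first host, so in a multi-host run
+    # just that host logs at INFO level; the rest are silenced to ERROR.)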
+    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+    if jax.process_index() == 0:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+
+    # Set the verbosity of the Transformers logger to info (on the main process only):
+    logger.info(f"Training/evaluation parameters {training_args}")
+
+    # Handle the repository creation
+    if training_args.push_to_hub:
+        # Retrieve or infer repo_name
+        repo_name = training_args.hub_model_id
+        if repo_name is None:
+            repo_name = Path(training_args.output_dir).absolute().name
+        # Create repo and retrieve repo_id
+        api = HfApi()
+        repo_id = api.create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
+
+    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
+    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+    # (the dataset will be downloaded automatically from the datasets Hub).
+    #
+    # For CSV/JSON files this script will use the first column for the full texts and the second column for the
+    # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
+    #
+    if data_args.dataset_name is not None:
+        # Downloading and loading a dataset from the hub.
+        dataset = load_dataset(
+            data_args.dataset_name,
+            data_args.dataset_config_name,
+            cache_dir=model_args.cache_dir,
+            keep_in_memory=False,
+            token=model_args.token,
+            trust_remote_code=model_args.trust_remote_code,
+        )
+    else:
+        data_files = {}
+        if data_args.train_file is not None:
+            data_files["train"] = data_args.train_file
+            extension = data_args.train_file.split(".")[-1]
+        if data_args.validation_file is not None:
+            data_files["validation"] = data_args.validation_file
+            extension = data_args.validation_file.split(".")[-1]
+        if data_args.test_file is not None:
+            data_files["test"] = data_args.test_file
+            extension = data_args.test_file.split(".")[-1]
+        dataset = load_dataset(
+            extension,
+            data_files=data_files,
+            cache_dir=model_args.cache_dir,
+            token=model_args.token,
+        )
+    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+    # https://huggingface.co/docs/datasets/loading_datasets.
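+    # Note: for local JSON files, the `json` loader expects JSON Lines input by
+    # default, one record per line, e.g. (column names here are placeholders):
+    #   {"text": "long source document ...", "summary": "short target summary ..."}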
+ + # Load pretrained model and tokenizer + + if model_args.config_name: + config = AutoConfig.from_pretrained( + model_args.config_name, + cache_dir=model_args.cache_dir, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + elif model_args.model_name_or_path: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + else: + config = CONFIG_MAPPING[model_args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + if model_args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + elif model_args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script. " + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + if model_args.model_name_or_path: + model = FlaxAutoModelForSeq2SeqLM.from_pretrained( + model_args.model_name_or_path, + config=config, + seed=training_args.seed, + dtype=getattr(jnp, model_args.dtype), + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + else: + model = FlaxAutoModelForSeq2SeqLM.from_config( + config, + seed=training_args.seed, + dtype=getattr(jnp, model_args.dtype), + trust_remote_code=model_args.trust_remote_code, + ) + + if training_args.gradient_checkpointing: + model.enable_gradient_checkpointing() + + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + + prefix = data_args.source_prefix if data_args.source_prefix is not None else "" + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + if training_args.do_train: + if "train" not in dataset: + raise ValueError("--do_train requires a train dataset") + column_names = dataset["train"].column_names + elif training_args.do_eval: + if "validation" not in dataset: + raise ValueError("--do_eval requires a validation dataset") + column_names = dataset["validation"].column_names + elif training_args.do_predict: + if "test" not in dataset: + raise ValueError("--do_predict requires a test dataset") + column_names = dataset["test"].column_names + else: + logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") + return + + # Get the column names for input/target. 
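+    # (For example, `--dataset_name xsum` resolves through the mapping above to
+    # text_column="document" and summary_column="summary".)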
+    dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
+    if data_args.text_column is None:
+        text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+    else:
+        text_column = data_args.text_column
+        if text_column not in column_names:
+            raise ValueError(
+                f"'--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
+            )
+    if data_args.summary_column is None:
+        summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+    else:
+        summary_column = data_args.summary_column
+        if summary_column not in column_names:
+            raise ValueError(
+                f"'--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
+            )
+
+    # Temporarily set max_target_length for training.
+    max_target_length = data_args.max_target_length
+
+    # In Flax, seq2seq models don't accept `labels`, so we need to pass `decoder_input_ids`
+    # ourselves. We prepare them here, dynamically importing the `shift_tokens_right`
+    # function from the model file.
+    model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
+    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
+
+    # Setting padding="max_length" as we need fixed length inputs for jitted functions
+    def preprocess_function(examples):
+        inputs = examples[text_column]
+        targets = examples[summary_column]
+        inputs = [prefix + inp for inp in inputs]
+        model_inputs = tokenizer(
+            inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
+        )
+
+        # Setup the tokenizer for targets
+        labels = tokenizer(
+            text_target=targets,
+            max_length=max_target_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="np",
+        )
+
+        model_inputs["labels"] = labels["input_ids"]
+        decoder_input_ids = shift_tokens_right_fn(
+            labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
+        )
+        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
+
+        # We need decoder_attention_mask so we can ignore pad tokens from loss
+        model_inputs["decoder_attention_mask"] = labels["attention_mask"]
+
+        return model_inputs
+
+    if training_args.do_train:
+        train_dataset = dataset["train"]
+        if data_args.max_train_samples is not None:
+            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+            train_dataset = train_dataset.select(range(max_train_samples))
+        train_dataset = train_dataset.map(
+            preprocess_function,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on train dataset",
+        )
+
+    if training_args.do_eval:
+        max_target_length = data_args.val_max_target_length
+        eval_dataset = dataset["validation"]
+        if data_args.max_eval_samples is not None:
+            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+            eval_dataset = eval_dataset.select(range(max_eval_samples))
+        eval_dataset = eval_dataset.map(
+            preprocess_function,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on validation dataset",
+        )
+
+    if training_args.do_predict:
+        max_target_length = data_args.val_max_target_length
+        predict_dataset = dataset["test"]
+        if data_args.max_predict_samples is not None:
+            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
+            predict_dataset = predict_dataset.select(range(max_predict_samples))
+        predict_dataset = predict_dataset.map(
+            preprocess_function,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on prediction dataset",
+        )
+
+    # Metric
+    metric = evaluate.load("rouge", cache_dir=model_args.cache_dir)
+
+    def postprocess_text(preds, labels):
+        preds = [pred.strip() for pred in preds]
+        labels = [label.strip() for label in labels]
+
+        # rougeLSum expects newline after each sentence
+        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
+        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
+
+        return preds, labels
+
+    def compute_metrics(preds, labels):
+        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+        # Some simple post-processing
+        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
+
+        result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
+        result = {k: round(v * 100, 4) for k, v in result.items()}
+        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
+        result["gen_len"] = np.mean(prediction_lens)
+        return result
+
+    # Enable tensorboard only on the master node
+    has_tensorboard = is_tensorboard_available()
+    if has_tensorboard and jax.process_index() == 0:
+        try:
+            from flax.metrics.tensorboard import SummaryWriter
+
+            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+        except ImportError as ie:
+            has_tensorboard = False
+            logger.warning(
+                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
+            )
+    else:
+        logger.warning(
+            "Unable to display metrics through TensorBoard because the package is not installed. "
+            "Please run `pip install tensorboard` to enable."
+        )
+
+    # Initialize our training
+    rng = jax.random.PRNGKey(training_args.seed)
+    rng, dropout_rng = jax.random.split(rng)
+
+    # Store some constants
+    num_epochs = int(training_args.num_train_epochs)
+    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+    eval_batch_size = per_device_eval_batch_size * jax.device_count()
+    steps_per_epoch = len(train_dataset) // train_batch_size
+    total_train_steps = steps_per_epoch * num_epochs
+
+    # Create learning rate schedule
+    linear_decay_lr_schedule_fn = create_learning_rate_fn(
+        len(train_dataset),
+        train_batch_size,
+        training_args.num_train_epochs,
+        training_args.warmup_steps,
+        training_args.learning_rate,
+    )
+
+    # We use Optax's "masking" functionality to not apply weight decay
+    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+    # mask boolean with the same structure as the parameters.
+    # The mask is True for parameters that should be decayed.
+ def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + # find out all LayerNorm parameters + layer_norm_candidates = ["layernorm", "layer_norm", "ln"] + layer_norm_named_params = { + layer[-2:] + for layer_norm_name in layer_norm_candidates + for layer in flat_params.keys() + if layer_norm_name in "".join(layer).lower() + } + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + + # create adam optimizer + adamw = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) + + # Setup train state + state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng) + + # label smoothed cross entropy + def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): + """ + The label smoothing implementation is adapted from Flax's official example: + https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 + """ + vocab_size = logits.shape[-1] + confidence = 1.0 - label_smoothing_factor + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + ) + soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) + + loss = optax.softmax_cross_entropy(logits, soft_labels) + loss = loss - normalizing_constant + + # ignore padded tokens from loss + loss = loss * padding_mask + loss = loss.sum() + num_labels = padding_mask.sum() + return loss, num_labels + + # Define gradient update step fn + def train_step(state, batch, label_smoothing_factor=0.0): + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def compute_loss(params): + labels = batch.pop("labels") + logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] + loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) + return loss, num_labels + + grad_fn = jax.value_and_grad(compute_loss, has_aux=True) + (loss, num_labels), grad = grad_fn(state.params) + num_labels = jax.lax.psum(num_labels, "batch") + + # true loss = total loss / total samples + loss = jax.lax.psum(loss, "batch") + loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) + + # true grad = total grad / total samples + grad = jax.lax.psum(grad, "batch") + grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) + new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + return new_state, metrics + + # Define eval fn + def eval_step(params, batch, label_smoothing_factor=0.0): + labels = batch.pop("labels") + logits = model(**batch, params=params, train=False)[0] + + loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) + num_labels = jax.lax.psum(num_labels, "batch") + + # true loss = total loss / total samples + loss = jax.lax.psum(loss, "batch") + loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) + + metrics = {"loss": loss} + return metrics + + # Define generation function + max_length = ( + data_args.val_max_target_length if data_args.val_max_target_length is not None else 
model.config.max_length + ) + num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams + gen_kwargs = {"max_length": max_length, "num_beams": num_beams} + + def generate_step(params, batch): + model.params = params + output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs) + return output_ids.sequences + + # Create parallel version of the train and eval step + p_train_step = jax.pmap( + partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,) + ) + p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch") + p_generate_step = jax.pmap(generate_step, "batch") + + # Replicate the train state on each device + state = state.replicate() + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_epochs}") + logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") + logger.info(f" Total optimization steps = {total_train_steps}") + + train_time = 0 + epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) + for epoch in epochs: + # ======================== Training ================================ + train_start = time.time() + + # Create sampling rng + rng, input_rng = jax.random.split(rng) + train_metrics = [] + + # Generate an epoch by shuffling sampling indices from the train dataset + train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True) + steps_per_epoch = len(train_dataset) // train_batch_size + # train + for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): + batch = next(train_loader) + batch = shard(batch) + state, train_metric = p_train_step(state, batch) + train_metrics.append(train_metric) + + train_time += time.time() - train_start + + train_metric = unreplicate(train_metric) + + epochs.write( + f"Epoch... 
({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:" + f" {train_metric['learning_rate']})" + ) + + # ======================== Evaluating ============================== + eval_metrics = [] + eval_preds = [] + eval_labels = [] + + eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False) + eval_steps = math.ceil(len(eval_dataset) / eval_batch_size) + for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): + # Model forward + batch = next(eval_loader) + labels = batch["labels"] + + metrics = pad_shard_unpad(p_eval_step, static_return=True)( + state.params, batch, min_device_batch=per_device_eval_batch_size + ) + eval_metrics.append(metrics) + + # generation + if data_args.predict_with_generate: + generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch) + eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) + eval_labels.extend(labels) + + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) + + # compute ROUGE metrics + rouge_desc = "" + if data_args.predict_with_generate: + rouge_metrics = compute_metrics(eval_preds, eval_labels) + eval_metrics.update(rouge_metrics) + rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()]) + + # Print metrics and update progress bar + desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})" + epochs.write(desc) + epochs.desc = desc + + # Save metrics + if has_tensorboard and jax.process_index() == 0: + cur_step = epoch * (len(train_dataset) // train_batch_size) + write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) + + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + api.upload_folder( + commit_message=f"Saving weights and logs of epoch {epoch}", + folder_path=training_args.output_dir, + repo_id=repo_id, + repo_type="model", + token=training_args.hub_token, + ) + + # ======================== Prediction loop ============================== + if training_args.do_predict: + logger.info("*** Predict ***") + + pred_metrics = [] + pred_generations = [] + pred_labels = [] + + pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size, drop_last=False) + pred_steps = math.ceil(len(predict_dataset) / eval_batch_size) + for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False): + # Model forward + batch = next(pred_loader) + labels = batch["labels"] + + metrics = pad_shard_unpad(p_eval_step, static_return=True)( + state.params, batch, min_device_batch=per_device_eval_batch_size + ) + pred_metrics.append(metrics) + + # generation + if data_args.predict_with_generate: + generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch) + pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) + pred_labels.extend(labels) + + # normalize prediction metrics + pred_metrics = get_metrics(pred_metrics) + pred_metrics = jax.tree_util.tree_map(jnp.mean, pred_metrics) + + # compute ROUGE metrics + rouge_desc = "" + if data_args.predict_with_generate: + rouge_metrics = compute_metrics(pred_generations, pred_labels) + 
pred_metrics.update(rouge_metrics)
+            rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
+
+        # Print metrics
+        desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc}"
+        logger.info(desc)
+
+        # save final metrics in json
+        if jax.process_index() == 0:
+            rouge_metrics = {f"test_{metric_name}": value for metric_name, value in rouge_metrics.items()}
+            path = os.path.join(training_args.output_dir, "test_results.json")
+            with open(path, "w") as f:
+                json.dump(rouge_metrics, f, indent=4, sort_keys=True)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file