Feat: JAX implementation of T5 training and prediction
commit f523560141
| @@ -0,0 +1,3 @@ | |||
| *.ipynb | ||||
| t5_*/ | ||||
| exports/ | ||||
| @@ -0,0 +1,357 @@ | |||
| name: jax | ||||
| channels: | ||||
| - conda-forge | ||||
| dependencies: | ||||
| - _libgcc_mutex=0.1=conda_forge | ||||
| - _openmp_mutex=4.5=2_gnu | ||||
| - _sysroot_linux-64_curr_repodata_hack=3=h69a702a_16 | ||||
| - absl-py=2.1.0=pyhd8ed1ab_0 | ||||
| - aiohappyeyeballs=2.4.0=pyhd8ed1ab_0 | ||||
| - aiohttp=3.10.5=py311h61187de_0 | ||||
| - aiosignal=1.3.1=pyhd8ed1ab_0 | ||||
| - alsa-lib=1.2.12=h4ab18f5_0 | ||||
| - aom=3.9.1=hac33072_0 | ||||
| - arrow=1.3.0=pyhd8ed1ab_0 | ||||
| - asttokens=2.4.1=pyhd8ed1ab_0 | ||||
| - attrs=24.2.0=pyh71513ae_0 | ||||
| - aws-c-auth=0.7.29=h03582ad_1 | ||||
| - aws-c-cal=0.7.4=hfd43aa1_1 | ||||
| - aws-c-common=0.9.28=hb9d3cd8_0 | ||||
| - aws-c-compression=0.2.19=h756ea98_1 | ||||
| - aws-c-event-stream=0.4.3=h235a6dd_1 | ||||
| - aws-c-http=0.8.8=h5e77a74_2 | ||||
| - aws-c-io=0.14.18=hc2627b9_9 | ||||
| - aws-c-mqtt=0.10.4=h01636a3_19 | ||||
| - aws-c-s3=0.6.5=h191b246_2 | ||||
| - aws-c-sdkutils=0.1.19=h756ea98_3 | ||||
| - aws-checksums=0.1.18=h756ea98_11 | ||||
| - aws-crt-cpp=0.28.2=h29c84ef_4 | ||||
| - aws-sdk-cpp=1.11.379=h5a9005d_9 | ||||
| - azure-core-cpp=1.13.0=h935415a_0 | ||||
| - azure-identity-cpp=1.8.0=hd126650_2 | ||||
| - azure-storage-blobs-cpp=12.12.0=hd2e3451_0 | ||||
| - azure-storage-common-cpp=12.7.0=h10ac4d7_1 | ||||
| - azure-storage-files-datalake-cpp=12.11.0=h325d260_1 | ||||
| - binaryornot=0.4.4=py_1 | ||||
| - binutils_impl_linux-64=2.40=ha1999f0_7 | ||||
| - binutils_linux-64=2.40=hb3c18ed_2 | ||||
| - blosc=1.21.6=hef167b5_0 | ||||
| - brotli=1.1.0=hb9d3cd8_2 | ||||
| - brotli-bin=1.1.0=hb9d3cd8_2 | ||||
| - brotli-python=1.1.0=py311hfdbb021_2 | ||||
| - bzip2=1.0.8=h4bc722e_7 | ||||
| - c-ares=1.33.1=heb4867d_0 | ||||
| - ca-certificates=2024.8.30=hbcca054_0 | ||||
| - cairo=1.18.0=hebfffa5_3 | ||||
| - certifi=2024.8.30=pyhd8ed1ab_0 | ||||
| - cffi=1.17.1=py311hf29c0ef_0 | ||||
| - chardet=5.2.0=py311h38be061_2 | ||||
| - charset-normalizer=3.3.2=pyhd8ed1ab_0 | ||||
| - chex=0.1.86=pyhd8ed1ab_0 | ||||
| - click=8.1.7=unix_pyh707e725_0 | ||||
| - colorama=0.4.6=pyhd8ed1ab_0 | ||||
| - comm=0.2.2=pyhd8ed1ab_0 | ||||
| - contourpy=1.3.0=py311hd18a35c_1 | ||||
| - cookiecutter=2.6.0=pyhca7485f_0 | ||||
| - cuda-cccl_linux-64=12.6.37=ha770c72_0 | ||||
| - cuda-crt-dev_linux-64=12.6.68=ha770c72_0 | ||||
| - cuda-crt-tools=12.6.68=ha770c72_0 | ||||
| - cuda-cudart=12.6.68=h5888daf_0 | ||||
| - cuda-cudart-dev=12.6.68=h5888daf_0 | ||||
| - cuda-cudart-dev_linux-64=12.6.68=h3f2d84a_0 | ||||
| - cuda-cudart-static=12.6.68=h5888daf_0 | ||||
| - cuda-cudart-static_linux-64=12.6.68=h3f2d84a_0 | ||||
| - cuda-cudart_linux-64=12.6.68=h3f2d84a_0 | ||||
| - cuda-cupti=12.6.68=h5888daf_0 | ||||
| - cuda-driver-dev_linux-64=12.6.68=h3f2d84a_0 | ||||
| - cuda-nvcc=12.6.68=hcdd1206_0 | ||||
| - cuda-nvcc-dev_linux-64=12.6.68=he91c749_0 | ||||
| - cuda-nvcc-impl=12.6.68=h85509e4_0 | ||||
| - cuda-nvcc-tools=12.6.68=he02047a_0 | ||||
| - cuda-nvcc_linux-64=12.6.68=h8a487aa_0 | ||||
| - cuda-nvrtc=12.6.68=h5888daf_0 | ||||
| - cuda-nvtx=12.6.68=h5888daf_0 | ||||
| - cuda-nvvm-dev_linux-64=12.6.68=ha770c72_0 | ||||
| - cuda-nvvm-impl=12.6.68=he02047a_0 | ||||
| - cuda-nvvm-tools=12.6.68=he02047a_0 | ||||
| - cuda-version=12.6=h7480c83_3 | ||||
| - cudnn=9.2.1.18=hbc370b7_0 | ||||
| - cycler=0.12.1=pyhd8ed1ab_0 | ||||
| - datasets=2.21.0=pyhd8ed1ab_0 | ||||
| - dav1d=1.2.1=hd590300_0 | ||||
| - dbus=1.13.6=h5008d03_3 | ||||
| - debugpy=1.8.5=py311hfdbb021_1 | ||||
| - decorator=5.1.1=pyhd8ed1ab_0 | ||||
| - dill=0.3.8=pyhd8ed1ab_0 | ||||
| - double-conversion=3.3.0=h59595ed_0 | ||||
| - etils=1.9.4=pyhd8ed1ab_0 | ||||
| - evaluate=0.4.1=pyhd8ed1ab_0 | ||||
| - exceptiongroup=1.2.2=pyhd8ed1ab_0 | ||||
| - executing=2.1.0=pyhd8ed1ab_0 | ||||
| - expat=2.6.3=h5888daf_0 | ||||
| - filelock=3.16.0=pyhd8ed1ab_0 | ||||
| - flax=0.9.0=pyhd8ed1ab_0 | ||||
| - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 | ||||
| - font-ttf-inconsolata=3.000=h77eed37_0 | ||||
| - font-ttf-source-code-pro=2.038=h77eed37_0 | ||||
| - font-ttf-ubuntu=0.83=h77eed37_2 | ||||
| - fontconfig=2.14.2=h14ed4e7_0 | ||||
| - fonts-conda-ecosystem=1=0 | ||||
| - fonts-conda-forge=1=0 | ||||
| - fonttools=4.53.1=py311h9ecbd09_1 | ||||
| - freetype=2.12.1=h267a509_2 | ||||
| - frozenlist=1.4.1=py311h9ecbd09_1 | ||||
| - fsspec=2024.5.0=pyhff2d567_0 | ||||
| - gcc_impl_linux-64=13.3.0=hfea6d02_1 | ||||
| - gcc_linux-64=13.3.0=hc28eda2_2 | ||||
| - gflags=2.2.2=he1b5a44_1004 | ||||
| - glog=0.7.1=hbabe93e_0 | ||||
| - graphite2=1.3.13=h59595ed_1003 | ||||
| - gxx_impl_linux-64=13.3.0=hdbfa832_1 | ||||
| - gxx_linux-64=13.3.0=h6834431_2 | ||||
| - h2=4.1.0=pyhd8ed1ab_0 | ||||
| - harfbuzz=9.0.0=hda332d3_1 | ||||
| - hpack=4.0.0=pyh9f0ad1d_0 | ||||
| - huggingface_hub=0.24.6=pyhd8ed1ab_0 | ||||
| - hyperframe=6.0.1=pyhd8ed1ab_0 | ||||
| - icu=75.1=he02047a_0 | ||||
| - idna=3.8=pyhd8ed1ab_0 | ||||
| - importlib-metadata=8.4.0=pyha770c72_0 | ||||
| - importlib_metadata=8.4.0=hd8ed1ab_0 | ||||
| - importlib_resources=6.4.5=pyhd8ed1ab_0 | ||||
| - ipykernel=6.29.5=pyh3099207_0 | ||||
| - ipython=8.27.0=pyh707e725_0 | ||||
| - jax=0.4.31=pyhd8ed1ab_1 | ||||
| - jaxlib=0.4.31=cuda120py311hd88f13b_201 | ||||
| - jedi=0.19.1=pyhd8ed1ab_0 | ||||
| - jinja2=3.1.4=pyhd8ed1ab_0 | ||||
| - joblib=1.4.2=pyhd8ed1ab_0 | ||||
| - jsonschema=4.23.0=pyhd8ed1ab_0 | ||||
| - jsonschema-specifications=2023.12.1=pyhd8ed1ab_0 | ||||
| - jupyter_client=8.6.2=pyhd8ed1ab_0 | ||||
| - jupyter_core=5.7.2=py311h38be061_0 | ||||
| - jupytext=1.16.4=pyh80e38bb_0 | ||||
| - kernel-headers_linux-64=3.10.0=h4a8ded7_16 | ||||
| - keyutils=1.6.1=h166bdaf_0 | ||||
| - kiwisolver=1.4.7=py311hd18a35c_0 | ||||
| - krb5=1.21.3=h659f571_0 | ||||
| - lcms2=2.16=hb7c19ff_0 | ||||
| - ld_impl_linux-64=2.40=hf3520f5_7 | ||||
| - lerc=4.0.0=h27087fc_0 | ||||
| - libabseil=20240116.2=cxx17_he02047a_1 | ||||
| - libarrow=17.0.0=h8d2e343_13_cpu | ||||
| - libarrow-acero=17.0.0=h5888daf_13_cpu | ||||
| - libarrow-dataset=17.0.0=h5888daf_13_cpu | ||||
| - libarrow-substrait=17.0.0=hf54134d_13_cpu | ||||
| - libavif16=1.1.1=h104a339_1 | ||||
| - libblas=3.9.0=23_linux64_openblas | ||||
| - libbrotlicommon=1.1.0=hb9d3cd8_2 | ||||
| - libbrotlidec=1.1.0=hb9d3cd8_2 | ||||
| - libbrotlienc=1.1.0=hb9d3cd8_2 | ||||
| - libcblas=3.9.0=23_linux64_openblas | ||||
| - libclang-cpp18.1=18.1.8=default_hf981a13_4 | ||||
| - libclang13=18.1.8=default_h9def88c_4 | ||||
| - libcrc32c=1.1.2=h9c3ff4c_0 | ||||
| - libcublas=12.6.1.4=h5888daf_0 | ||||
| - libcufft=11.2.6.59=h5888daf_0 | ||||
| - libcups=2.3.3=h4637d8d_4 | ||||
| - libcurand=10.3.7.68=h5888daf_0 | ||||
| - libcurl=8.9.1=hdb1bdb2_0 | ||||
| - libcusolver=11.6.4.69=h5888daf_0 | ||||
| - libcusparse=12.5.3.3=h5888daf_0 | ||||
| - libdeflate=1.21=h4bc722e_0 | ||||
| - libdrm=2.4.123=hb9d3cd8_0 | ||||
| - libedit=3.1.20191231=he28a2e2_2 | ||||
| - libegl=1.7.0=ha4b6fd6_0 | ||||
| - libev=4.33=hd590300_2 | ||||
| - libevent=2.1.12=hf998b51_1 | ||||
| - libexpat=2.6.3=h5888daf_0 | ||||
| - libffi=3.4.2=h7f98852_5 | ||||
| - libgcc=14.1.0=h77fa898_1 | ||||
| - libgcc-devel_linux-64=13.3.0=h84ea5a7_101 | ||||
| - libgcc-ng=14.1.0=h69a702a_1 | ||||
| - libgfortran=14.1.0=h69a702a_1 | ||||
| - libgfortran-ng=14.1.0=h69a702a_1 | ||||
| - libgfortran5=14.1.0=hc5f4f2c_1 | ||||
| - libgl=1.7.0=ha4b6fd6_0 | ||||
| - libglib=2.80.3=h315aac3_2 | ||||
| - libglvnd=1.7.0=ha4b6fd6_0 | ||||
| - libglx=1.7.0=ha4b6fd6_0 | ||||
| - libgomp=14.1.0=h77fa898_1 | ||||
| - libgoogle-cloud=2.28.0=h26d7fe4_0 | ||||
| - libgoogle-cloud-storage=2.28.0=ha262f82_0 | ||||
| - libgrpc=1.62.2=h15f2491_0 | ||||
| - libiconv=1.17=hd590300_2 | ||||
| - libjpeg-turbo=3.0.0=hd590300_1 | ||||
| - liblapack=3.9.0=23_linux64_openblas | ||||
| - libllvm18=18.1.8=h8b73ec9_2 | ||||
| - libnghttp2=1.58.0=h47da74e_1 | ||||
| - libnsl=2.0.1=hd590300_0 | ||||
| - libnvjitlink=12.6.68=h5888daf_0 | ||||
| - libopenblas=0.3.27=pthreads_hac2b453_1 | ||||
| - libparquet=17.0.0=h39682fd_13_cpu | ||||
| - libpciaccess=0.18=hd590300_0 | ||||
| - libpng=1.6.43=h2797004_0 | ||||
| - libpq=16.4=h2d7952a_1 | ||||
| - libprotobuf=4.25.3=h08a7969_0 | ||||
| - libre2-11=2023.09.01=h5a48ba9_2 | ||||
| - libsanitizer=13.3.0=heb74ff8_1 | ||||
| - libsodium=1.0.20=h4ab18f5_0 | ||||
| - libsqlite=3.46.1=hadc24fc_0 | ||||
| - libssh2=1.11.0=h0841786_0 | ||||
| - libstdcxx=14.1.0=hc0a3c3a_1 | ||||
| - libstdcxx-devel_linux-64=13.3.0=h84ea5a7_101 | ||||
| - libstdcxx-ng=14.1.0=h4852527_1 | ||||
| - libthrift=0.20.0=h0e7cc3e_1 | ||||
| - libtiff=4.6.0=h46a8edc_4 | ||||
| - libutf8proc=2.8.0=h166bdaf_0 | ||||
| - libuuid=2.38.1=h0b41bf4_0 | ||||
| - libwebp-base=1.4.0=hd590300_0 | ||||
| - libxcb=1.16=hb9d3cd8_1 | ||||
| - libxcrypt=4.4.36=hd590300_1 | ||||
| - libxkbcommon=1.7.0=h2c5496b_1 | ||||
| - libxml2=2.12.7=he7c6b58_4 | ||||
| - libxslt=1.1.39=h76b75d6_0 | ||||
| - libzlib=1.3.1=h4ab18f5_1 | ||||
| - lxml=5.3.0=py311hcfaa980_1 | ||||
| - lz4-c=1.9.4=hcb278e6_0 | ||||
| - markdown-it-py=3.0.0=pyhd8ed1ab_0 | ||||
| - markupsafe=2.1.5=py311h9ecbd09_1 | ||||
| - matplotlib=3.9.2=py311h38be061_0 | ||||
| - matplotlib-base=3.9.2=py311h74b4f7c_0 | ||||
| - matplotlib-inline=0.1.7=pyhd8ed1ab_0 | ||||
| - mdit-py-plugins=0.4.1=pyhd8ed1ab_0 | ||||
| - mdurl=0.1.2=pyhd8ed1ab_0 | ||||
| - ml_dtypes=0.4.0=py311h7db5c69_2 | ||||
| - msgpack-python=1.0.8=py311hd18a35c_1 | ||||
| - multidict=6.1.0=py311h9ecbd09_0 | ||||
| - multiprocess=0.70.16=py311h9ecbd09_1 | ||||
| - munkres=1.1.4=pyh9f0ad1d_0 | ||||
| - mysql-common=9.0.1=h70512c7_0 | ||||
| - mysql-libs=9.0.1=ha479ceb_0 | ||||
| - nbformat=5.10.4=pyhd8ed1ab_0 | ||||
| - nccl=2.22.3.1=hbc370b7_1 | ||||
| - ncurses=6.5=he02047a_1 | ||||
| - nest-asyncio=1.6.0=pyhd8ed1ab_0 | ||||
| - nltk=3.9.1=pyhd8ed1ab_0 | ||||
| - numpy=1.26.4=py311h64a7726_0 | ||||
| - openjpeg=2.5.2=h488ebb8_0 | ||||
| - openssl=3.3.2=hb9d3cd8_0 | ||||
| - opt-einsum=3.3.0=hd8ed1ab_2 | ||||
| - opt_einsum=3.3.0=pyhc1e730c_2 | ||||
| - optax=0.2.2=pyhd8ed1ab_1 | ||||
| - orbax-checkpoint=0.4.4=pyhd8ed1ab_0 | ||||
| - orc=2.0.2=h669347b_0 | ||||
| - packaging=24.1=pyhd8ed1ab_0 | ||||
| - pandas=2.2.2=py311h14de704_1 | ||||
| - parso=0.8.4=pyhd8ed1ab_0 | ||||
| - pcre2=10.44=hba22ea6_2 | ||||
| - pexpect=4.9.0=pyhd8ed1ab_0 | ||||
| - pickleshare=0.7.5=py_1003 | ||||
| - pillow=10.4.0=py311h82a398c_0 | ||||
| - pip=24.2=pyh8b19718_1 | ||||
| - pixman=0.43.2=h59595ed_0 | ||||
| - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_1 | ||||
| - platformdirs=4.3.2=pyhd8ed1ab_0 | ||||
| - portalocker=2.10.1=py311h38be061_0 | ||||
| - prompt-toolkit=3.0.47=pyha770c72_0 | ||||
| - protobuf=4.25.3=py311hbffca5d_1 | ||||
| - psutil=6.0.0=py311h9ecbd09_1 | ||||
| - pthread-stubs=0.4=h36c2ea0_1001 | ||||
| - ptyprocess=0.7.0=pyhd3deb0d_0 | ||||
| - pure_eval=0.2.3=pyhd8ed1ab_0 | ||||
| - pyarrow=17.0.0=py311hbd00459_1 | ||||
| - pyarrow-core=17.0.0=py311h4510849_1_cpu | ||||
| - pyarrow-hotfix=0.6=pyhd8ed1ab_0 | ||||
| - pybind11-abi=4=hd8ed1ab_3 | ||||
| - pycparser=2.22=pyhd8ed1ab_0 | ||||
| - pygments=2.18.0=pyhd8ed1ab_0 | ||||
| - pyparsing=3.1.4=pyhd8ed1ab_0 | ||||
| - pyside6=6.7.2=py311hba19f1e_2 | ||||
| - pysocks=1.7.1=pyha2e5f31_6 | ||||
| - python=3.11.9=hb806964_0_cpython | ||||
| - python-dateutil=2.9.0=pyhd8ed1ab_0 | ||||
| - python-fastjsonschema=2.20.0=pyhd8ed1ab_0 | ||||
| - python-slugify=8.0.4=pyhd8ed1ab_0 | ||||
| - python-tzdata=2024.1=pyhd8ed1ab_0 | ||||
| - python-xxhash=3.5.0=py311h9ecbd09_1 | ||||
| - python_abi=3.11=5_cp311 | ||||
| - pytz=2024.1=pyhd8ed1ab_0 | ||||
| - pyyaml=6.0.2=py311h9ecbd09_1 | ||||
| - pyzmq=26.2.0=py311h7deb3e3_2 | ||||
| - qhull=2020.2=h434a139_5 | ||||
| - qt6-main=6.7.2=hb12f9c5_5 | ||||
| - rav1e=0.6.6=he8a937b_2 | ||||
| - re2=2023.09.01=h7f4b329_2 | ||||
| - readline=8.2=h8228510_1 | ||||
| - referencing=0.35.1=pyhd8ed1ab_0 | ||||
| - regex=2024.7.24=py311h9ecbd09_1 | ||||
| - requests=2.32.3=pyhd8ed1ab_0 | ||||
| - responses=0.18.0=pyhd8ed1ab_0 | ||||
| - rich=13.7.1=pyhd8ed1ab_0 | ||||
| - rpds-py=0.20.0=py311h9e33e62_1 | ||||
| - s2n=1.5.2=h7b32b05_0 | ||||
| - sacrebleu=2.4.1=pyhd8ed1ab_1 | ||||
| - safetensors=0.4.5=py311h9e33e62_0 | ||||
| - scipy=1.14.1=py311he1f765f_0 | ||||
| - setuptools=73.0.1=pyhd8ed1ab_0 | ||||
| - six=1.16.0=pyh6c4a22f_0 | ||||
| - snappy=1.2.1=ha2e4443_0 | ||||
| - stack_data=0.6.2=pyhd8ed1ab_0 | ||||
| - svt-av1=2.2.1=h5888daf_0 | ||||
| - sysroot_linux-64=2.17=h4a8ded7_16 | ||||
| - tabulate=0.9.0=pyhd8ed1ab_1 | ||||
| - tensorstore=0.1.62=py311he109767_1 | ||||
| - text-unidecode=1.3=pyhd8ed1ab_1 | ||||
| - tk=8.6.13=noxft_h4845f30_101 | ||||
| - tokenizers=0.19.1=py311h6640629_0 | ||||
| - tomli=2.0.1=pyhd8ed1ab_0 | ||||
| - toolz=0.12.1=pyhd8ed1ab_0 | ||||
| - tornado=6.4.1=py311h9ecbd09_1 | ||||
| - tqdm=4.66.5=pyhd8ed1ab_0 | ||||
| - traitlets=5.14.3=pyhd8ed1ab_0 | ||||
| - transformers=4.41.2=pyhd8ed1ab_0 | ||||
| - types-python-dateutil=2.9.0.20240906=pyhd8ed1ab_0 | ||||
| - typing=3.10.0.0=pyhd8ed1ab_1 | ||||
| - typing-extensions=4.12.2=hd8ed1ab_0 | ||||
| - typing_extensions=4.12.2=pyha770c72_0 | ||||
| - tzdata=2024a=h8827d51_1 | ||||
| - urllib3=2.2.2=pyhd8ed1ab_1 | ||||
| - wayland=1.23.1=h3e06ad9_0 | ||||
| - wcwidth=0.2.13=pyhd8ed1ab_0 | ||||
| - wheel=0.44.0=pyhd8ed1ab_0 | ||||
| - xcb-util=0.4.1=hb711507_2 | ||||
| - xcb-util-cursor=0.1.4=h4ab18f5_2 | ||||
| - xcb-util-image=0.4.0=hb711507_2 | ||||
| - xcb-util-keysyms=0.4.1=hb711507_0 | ||||
| - xcb-util-renderutil=0.3.10=hb711507_0 | ||||
| - xcb-util-wm=0.4.2=hb711507_0 | ||||
| - xkeyboard-config=2.42=h4ab18f5_0 | ||||
| - xorg-fixesproto=5.0=h7f98852_1002 | ||||
| - xorg-inputproto=2.3.2=h7f98852_1002 | ||||
| - xorg-kbproto=1.0.7=h7f98852_1002 | ||||
| - xorg-libice=1.1.1=hd590300_0 | ||||
| - xorg-libsm=1.2.4=h7391055_0 | ||||
| - xorg-libx11=1.8.9=hb711507_1 | ||||
| - xorg-libxau=1.0.11=hd590300_0 | ||||
| - xorg-libxdmcp=1.1.3=h7f98852_0 | ||||
| - xorg-libxext=1.3.4=h0b41bf4_2 | ||||
| - xorg-libxfixes=5.0.3=h7f98852_1004 | ||||
| - xorg-libxi=1.7.10=h4bc722e_1 | ||||
| - xorg-libxrender=0.9.11=hd590300_0 | ||||
| - xorg-libxtst=1.2.5=h4bc722e_0 | ||||
| - xorg-libxxf86vm=1.1.5=h4bc722e_1 | ||||
| - xorg-recordproto=1.14.2=h7f98852_1002 | ||||
| - xorg-renderproto=0.11.1=h7f98852_1002 | ||||
| - xorg-xextproto=7.3.0=h0b41bf4_1003 | ||||
| - xorg-xproto=7.0.31=h7f98852_1007 | ||||
| - xxhash=0.8.2=hd590300_0 | ||||
| - xz=5.2.6=h166bdaf_0 | ||||
| - yaml=0.2.5=h7f98852_2 | ||||
| - yarl=1.11.1=py311h9ecbd09_0 | ||||
| - zeromq=4.3.5=ha4adb4c_5 | ||||
| - zipp=3.20.1=pyhd8ed1ab_0 | ||||
| - zlib=1.3.1=h4ab18f5_1 | ||||
| - zstandard=0.23.0=py311hbc35293_1 | ||||
| - zstd=1.5.6=ha6fb4c9_0 | ||||
| 
 | ||||
| @@ -0,0 +1,620 @@ | |||
| # --- | ||||
| # jupyter: | ||||
| #   jupytext: | ||||
| #     formats: ipynb,py:percent | ||||
| #     text_representation: | ||||
| #       extension: .py | ||||
| #       format_name: percent | ||||
| #       format_version: '1.3' | ||||
| #       jupytext_version: 1.16.4 | ||||
| #   kernelspec: | ||||
| #     display_name: jax | ||||
| #     language: python | ||||
| #     name: python3 | ||||
| # --- | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # # T5 implementation using jax | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # ## import | ||||
| 
 | ||||
| # %% [raw] | ||||
| # import json | ||||
| # import logging | ||||
| # import math | ||||
| # import os | ||||
| # import sys | ||||
| # import time | ||||
| # from dataclasses import asdict, dataclass, field | ||||
| # from enum import Enum | ||||
| # from functools import partial | ||||
| # from pathlib import Path | ||||
| # from typing import Callable, Optional | ||||
| # | ||||
| # import datasets | ||||
| # import evaluate | ||||
| # import jax | ||||
| # import jax.numpy as jnp | ||||
| # import nltk  # Here to have a nice missing dependency error message early on | ||||
| # import numpy as np | ||||
| # import optax | ||||
| # from datasets import Dataset, load_dataset | ||||
| # from filelock import FileLock | ||||
| # from flax import jax_utils, traverse_util | ||||
| # from flax.jax_utils import pad_shard_unpad, unreplicate | ||||
| # from flax.training import train_state | ||||
| # from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key | ||||
| # from tqdm import tqdm | ||||
| # | ||||
| # import transformers | ||||
| # from transformers import ( | ||||
| #     CONFIG_MAPPING, | ||||
| #     FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, | ||||
| #     AutoConfig, | ||||
| #     AutoTokenizer, | ||||
| #     FlaxAutoModelForSeq2SeqLM, | ||||
| #     HfArgumentParser, | ||||
| #     is_tensorboard_available, | ||||
| # ) | ||||
| # from transformers.utils import is_offline_mode, send_example_telemetry | ||||
| # | ||||
| # | ||||
| # logger = logging.getLogger(__name__) | ||||
| # | ||||
| # try: | ||||
| #     nltk.data.find("tokenizers/punkt") | ||||
| # except (LookupError, OSError): | ||||
| #     if is_offline_mode(): | ||||
| #         raise LookupError( | ||||
| #             "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" | ||||
| #         ) | ||||
| #     with FileLock(".lock") as lock: | ||||
| #         nltk.download("punkt", quiet=True) | ||||
| # | ||||
| # | ||||
| # MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys()) | ||||
| # MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| import jax | ||||
| import jax.numpy as jnp | ||||
| import optax | ||||
| import numpy as np | ||||
| from functools import partial | ||||
| from typing import Callable, Optional | ||||
| import math | ||||
| 
 | ||||
| # jax.config.update("jax_default_matmul_precision", "tensorfloat32") | ||||
| jax.config.update("jax_default_matmul_precision", "high") | ||||
| 
 | ||||
| jax.config.update("jax_enable_x64", False) | ||||
| 
 | ||||
| 
 | ||||
| from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig | ||||
| 
 | ||||
| 
 | ||||
| import datasets | ||||
| from datasets import Dataset, load_dataset | ||||
| import evaluate | ||||
| from tqdm import tqdm | ||||
| from datasets import load_from_disk | ||||
| 
 | ||||
| 
 | ||||
| import nltk  # Here to have a nice missing dependency error message early on | ||||
| 
 | ||||
| from flax import jax_utils, traverse_util | ||||
| from flax.jax_utils import pad_shard_unpad, unreplicate | ||||
| from flax.training import train_state | ||||
| from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key | ||||
| 
 | ||||
| 
 | ||||
| import time | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| import os | ||||
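| # note: these environment variables must be set before the first JAX computation | ||||
| # triggers backend initialisation, otherwise XLA_FLAGS has no effect | ||||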
| os.environ['XLA_FLAGS'] = ( | ||||
|     '--xla_gpu_enable_triton_softmax_fusion=True ' | ||||
|     '--xla_gpu_triton_gemm_any=True ' | ||||
| ) | ||||
| 
 | ||||
| os.environ.update({ | ||||
|     "NCCL_LL128_BUFFSIZE": "-2", | ||||
|     "NCCL_LL_BUFFSIZE": "-2", | ||||
|     "NCCL_PROTO": "SIMPLE,LL,LL128", | ||||
|  }) | ||||
| 
 | ||||
| # %% | ||||
| from jax.lib import xla_bridge | ||||
| print(xla_bridge.get_backend().platform) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # nltk.download('punkt') | ||||
| from transformers.utils import is_offline_mode  # used in the offline guard below | ||||
| from filelock import FileLock  # serialise the nltk download across processes | ||||
| try: | ||||
|     nltk.data.find("tokenizers/punkt") | ||||
| except (LookupError, OSError): | ||||
|     if is_offline_mode(): | ||||
|         raise LookupError( | ||||
|             "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" | ||||
|         ) | ||||
|     with FileLock(".lock") as lock: | ||||
|         nltk.download("punkt", quiet=True) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # ## Prepare datasets | ||||
| 
 | ||||
| # %% | ||||
| # load model | ||||
| model_name_or_path = "t5-small"  # Replace with your specific model name | ||||
| 
 | ||||
| # Load configuration | ||||
| config = AutoConfig.from_pretrained(model_name_or_path) | ||||
| 
 | ||||
| # Load model | ||||
| model = FlaxAutoModelForSeq2SeqLM.from_pretrained( | ||||
|     model_name_or_path | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| model_module = __import__(model.__module__, fromlist=["shift_tokens_right"]) | ||||
| shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # Path to saved combined_dataset | ||||
| file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval' | ||||
| save_path = 't5_80_1' | ||||
| # file_path = 'combined_data' | ||||
| split_datasets = load_from_disk(file_path) | ||||
| 
 | ||||
| # prepare tokenizer | ||||
| from transformers import T5TokenizerFast | ||||
| tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True) | ||||
| # Define additional special tokens | ||||
| additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"] | ||||
| # Add the additional special tokens to the tokenizer | ||||
| tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) | ||||
| 
 | ||||
| max_length = 86 | ||||
| 
 | ||||
| # In Flax, seq2seq models don't accept a `labels` argument, so we have to pass | ||||
| # `decoder_input_ids` explicitly. They are built from the labels with the | ||||
| # `shift_tokens_right` function that was dynamically imported from the model module above. | ||||
| 
 | ||||
| 
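| # %% [markdown] | ||||
| # Quick illustrative check (not part of the training flow) of what | ||||
| # `shift_tokens_right` does: the labels are shifted one position to the right and | ||||
| # the decoder start token is prepended, which is what the decoder consumes under | ||||
| # teacher forcing. The token ids below are made up; for T5 both `pad_token_id` | ||||
| # and `decoder_start_token_id` are 0 and 1 is `</s>`. | ||||
| 
| # %% | ||||
| toy_labels = np.array([[100, 200, 300, 1]])  # made-up ids ending in </s> | ||||
| print(shift_tokens_right_fn(toy_labels, config.pad_token_id, config.decoder_start_token_id)) | ||||
| # expected: [[  0 100 200 300]] | ||||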
 | ||||
| # given a dataset entry, run it through the tokenizer | ||||
| # Setting padding="max_length" as we need fixed length inputs for jitted functions | ||||
| def preprocess_function(example): | ||||
|     input = example['input'] | ||||
|     target = example['output'] | ||||
|     # tokenize the inputs; the targets are tokenized separately below so that | ||||
|     # we also get the target attention mask for masking the loss | ||||
|     model_inputs = tokenizer( | ||||
|         input, | ||||
|         max_length=max_length, | ||||
|         padding="max_length", | ||||
|         truncation=True, | ||||
|         return_tensors="np" | ||||
|     ) | ||||
|     # tokenize only the targets here; passing the input text as well would make | ||||
|     # labels["input_ids"] hold the input ids instead of the target ids | ||||
|     labels = tokenizer( | ||||
|         text_target=target, | ||||
|         max_length=max_length, | ||||
|         padding="max_length", | ||||
|         truncation=True, | ||||
|         return_tensors="np" | ||||
|     ) | ||||
| 
 | ||||
|     model_inputs["labels"] = labels["input_ids"] | ||||
|     decoder_input_ids = shift_tokens_right_fn( | ||||
|         labels["input_ids"], config.pad_token_id, config.decoder_start_token_id | ||||
|     ) | ||||
|     model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids) | ||||
| 
 | ||||
|     # We need decoder_attention_mask so we can ignore pad tokens from loss | ||||
|     model_inputs["decoder_attention_mask"] = labels["attention_mask"] | ||||
| 
 | ||||
|     return model_inputs | ||||
| 
 | ||||
| # Dataset.map applies preprocess_function to every example; with batched=True | ||||
| # it receives whole batches of examples rather than single rows | ||||
| tokenized_datasets = split_datasets.map( | ||||
|     preprocess_function, | ||||
|     batched=True, | ||||
|     num_proc=1, | ||||
|     remove_columns=split_datasets["train"].column_names, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| tokenized_datasets | ||||
| 
 | ||||
| # %% | ||||
| train_dataset = tokenized_datasets["train"] | ||||
| eval_dataset = tokenized_datasets["validation"] | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True): | ||||
|     """ | ||||
|     Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete, | ||||
|     and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`. | ||||
|     """ | ||||
|     if shuffle: | ||||
|         batch_idx = jax.random.permutation(rng, len(dataset)) | ||||
|         batch_idx = np.asarray(batch_idx) | ||||
|     else: | ||||
|         batch_idx = np.arange(len(dataset)) | ||||
| 
 | ||||
|     if drop_last: | ||||
|         steps_per_epoch = len(dataset) // batch_size | ||||
|         batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch. | ||||
|         batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) | ||||
|     else: | ||||
|         steps_per_epoch = math.ceil(len(dataset) / batch_size) | ||||
|         batch_idx = np.array_split(batch_idx, steps_per_epoch) | ||||
| 
 | ||||
|     for idx in batch_idx: | ||||
|         batch = dataset[idx] | ||||
|         batch = {k: np.array(v) for k, v in batch.items()} | ||||
| 
 | ||||
|         yield batch | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # The model inputs are now available in the `tokenized_datasets` variable | ||||
| 
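| # %% [markdown] | ||||
| # Quick sanity check (illustrative only): every tokenized example should carry | ||||
| # fixed-length `input_ids`, `attention_mask`, `labels`, `decoder_input_ids` and | ||||
| # `decoder_attention_mask`, all padded to `max_length`. | ||||
| 
| # %% | ||||
| sample = tokenized_datasets["train"][0] | ||||
| for key, value in sample.items(): | ||||
|     print(key, len(value))  # each field has length max_length (86) | ||||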
 | ||||
| # %% | ||||
| # metric | ||||
| metric = evaluate.load("sacrebleu") | ||||
| 
 | ||||
| def postprocess_text(preds, labels): | ||||
|     preds = [pred.strip() for pred in preds] | ||||
|     labels = [label.strip() for label in labels] | ||||
| 
 | ||||
|     # rougeLSum expects newline after each sentence | ||||
|     preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] | ||||
|     labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] | ||||
| 
 | ||||
|     return preds, labels | ||||
| 
 | ||||
| # def compute_metrics(preds, labels): | ||||
| #     decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) | ||||
| #     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) | ||||
| #  | ||||
| #     # Some simple post-processing | ||||
| #     decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) | ||||
| #  | ||||
| #     result = metric.compute(predictions=decoded_preds, references=decoded_labels) | ||||
| #     result = {k: round(v * 100, 4) for k, v in result.items()} | ||||
| #     prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] | ||||
| #     result["gen_len"] = np.mean(prediction_lens) | ||||
| #     return result | ||||
| 
 | ||||
| def compute_metrics(preds, labels): | ||||
|     # In case the model returns more than the prediction logits | ||||
|     if isinstance(preds, tuple): | ||||
|         preds = preds[0] | ||||
| 
 | ||||
|     decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) | ||||
| 
 | ||||
|     # Replace -100s in the labels as we can't decode them | ||||
|     labels = np.where(labels != -100, labels, tokenizer.pad_token_id) | ||||
|     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) | ||||
| 
 | ||||
|     # Some simple post-processing | ||||
|     decoded_preds = [pred.strip() for pred in decoded_preds] | ||||
|     decoded_labels = [[label.strip()] for label in decoded_labels] | ||||
| 
 | ||||
|     result = metric.compute(predictions=decoded_preds, references=decoded_labels) | ||||
|     return {"bleu": result["score"]} | ||||
| 
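| # %% [markdown] | ||||
| # Note on the reference format: sacrebleu expects one *list* of references per | ||||
| # prediction, hence the `[[label.strip()]]` wrapping above. A toy call with | ||||
| # made-up strings: | ||||
| 
| # %% | ||||
| toy_bleu = metric.compute( | ||||
|     predictions=["the pump discharge pressure is high"], | ||||
|     references=[["the pump discharge pressure is high"]], | ||||
| ) | ||||
| print(toy_bleu["score"])  # 100.0 for an exact match | ||||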
 | ||||
| 
 | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # # Model | ||||
| 
 | ||||
| # %% | ||||
| # Store some constant | ||||
| seed = 117 | ||||
| num_epochs = 80 | ||||
| batch_size = 96 | ||||
| num_train_epochs = num_epochs | ||||
| per_device_train_batch_size = batch_size | ||||
| train_batch_size = per_device_train_batch_size * jax.device_count() | ||||
| per_device_eval_batch_size = batch_size | ||||
| eval_batch_size = per_device_eval_batch_size * jax.device_count() | ||||
| steps_per_epoch = len(train_dataset) // train_batch_size | ||||
| total_train_steps = steps_per_epoch * num_epochs | ||||
| 
 | ||||
| warmup_steps = 0 | ||||
| learning_rate = 5e-5 | ||||
| 
 | ||||
| weight_decay = 0.0 | ||||
| adam_beta1 = 0.9 | ||||
| adam_beta2 = 0.999 | ||||
| adam_epsilon = 1e-8 | ||||
| label_smoothing_factor = 0.0 | ||||
| 
 | ||||
| num_beams = 1 | ||||
| val_max_target_length = None | ||||
| 
 | ||||
| predict_with_generate = True | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
| # Initialize our training | ||||
| rng = jax.random.PRNGKey(seed) | ||||
| rng, dropout_rng = jax.random.split(rng) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # optimization functions | ||||
| 
 | ||||
| def create_learning_rate_fn( | ||||
|     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float | ||||
| ) -> Callable[[int], jnp.ndarray]: | ||||
|     """Returns a linear warmup, linear_decay learning rate function.""" | ||||
|     steps_per_epoch = train_ds_size // train_batch_size | ||||
|     num_train_steps = steps_per_epoch * num_train_epochs | ||||
|     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) | ||||
|     decay_fn = optax.linear_schedule( | ||||
|         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps | ||||
|     ) | ||||
|     schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) | ||||
|     return schedule_fn | ||||
| 
 | ||||
| 
 | ||||
| # Create learning rate schedule | ||||
| linear_decay_lr_schedule_fn = create_learning_rate_fn( | ||||
|     len(train_dataset), | ||||
|     train_batch_size, | ||||
|     num_train_epochs, | ||||
|     warmup_steps, | ||||
|     learning_rate, | ||||
| ) | ||||
| 
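| # %% [markdown] | ||||
| # Illustrative check of the schedule: with `warmup_steps = 0` this is a pure | ||||
| # linear decay from `learning_rate` down to 0 at `total_train_steps`. | ||||
| 
| # %% | ||||
| for step in [0, total_train_steps // 2, total_train_steps]: | ||||
|     print(step, float(linear_decay_lr_schedule_fn(step))) | ||||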
 | ||||
| # We use Optax's "masking" functionality to not apply weight decay | ||||
| # to bias and LayerNorm scale parameters. decay_mask_fn returns a | ||||
| # mask boolean with the same structure as the parameters. | ||||
| # The mask is True for parameters that should be decayed. | ||||
| def decay_mask_fn(params): | ||||
|     flat_params = traverse_util.flatten_dict(params) | ||||
|     # find out all LayerNorm parameters | ||||
|     layer_norm_candidates = ["layernorm", "layer_norm", "ln"] | ||||
|     layer_norm_named_params = { | ||||
|         layer[-2:] | ||||
|         for layer_norm_name in layer_norm_candidates | ||||
|         for layer in flat_params.keys() | ||||
|         if layer_norm_name in "".join(layer).lower() | ||||
|     } | ||||
|     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} | ||||
|     return traverse_util.unflatten_dict(flat_mask) | ||||
| 
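| # %% [markdown] | ||||
| # Toy check of `decay_mask_fn` on a made-up parameter tree: biases and layer norm | ||||
| # scales should come back `False`, i.e. excluded from weight decay. | ||||
| 
| # %% | ||||
| toy_params = { | ||||
|     "encoder": { | ||||
|         "dense": {"kernel": 0, "bias": 0}, | ||||
|         "layer_norm": {"weight": 0}, | ||||
|     } | ||||
| } | ||||
| print(decay_mask_fn(toy_params)) | ||||
| # expected: kernel -> True, bias -> False, layer_norm/weight -> False | ||||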
 | ||||
| # create adam optimizer | ||||
| adamw = optax.adamw( | ||||
|     learning_rate=linear_decay_lr_schedule_fn, | ||||
|     b1=adam_beta1, | ||||
|     b2=adam_beta2, | ||||
|     eps=adam_epsilon, | ||||
|     weight_decay=weight_decay, | ||||
|     mask=decay_mask_fn, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # Training functions | ||||
| class TrainState(train_state.TrainState): | ||||
|     dropout_rng: jnp.ndarray | ||||
| 
 | ||||
|     def replicate(self): | ||||
|         return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) | ||||
| 
 | ||||
| # model.params is already initialized here by from_pretrained; if the model were | ||||
| # built from scratch you would get the params from an init call with dummy input instead | ||||
| params = model.params | ||||
| # Cast parameters to bfloat16 if desired | ||||
| params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) | ||||
| 
 | ||||
| 
 | ||||
| # Setup train state | ||||
| state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng) | ||||
| 
 | ||||
| # label smoothed cross entropy | ||||
| def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): | ||||
|     """ | ||||
|     The label smoothing implementation is adapted from Flax's official example: | ||||
|     https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 | ||||
|     """ | ||||
|     vocab_size = logits.shape[-1] | ||||
|     confidence = 1.0 - label_smoothing_factor | ||||
|     low_confidence = (1.0 - confidence) / (vocab_size - 1) | ||||
|     normalizing_constant = -( | ||||
|         confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) | ||||
|     ) | ||||
|     soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) | ||||
| 
 | ||||
|     loss = optax.softmax_cross_entropy(logits, soft_labels) | ||||
|     loss = loss - normalizing_constant | ||||
| 
 | ||||
|     # ignore padded tokens from loss | ||||
|     loss = loss * padding_mask | ||||
|     loss = loss.sum() | ||||
|     num_labels = padding_mask.sum() | ||||
|     return loss, num_labels | ||||
| 
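| # %% [markdown] | ||||
| # Tiny illustration of `loss_fn` on made-up logits: with | ||||
| # `label_smoothing_factor=0.0` it reduces to plain cross entropy summed over the | ||||
| # non-padded positions, and the second return value counts those positions. | ||||
| 
| # %% | ||||
| toy_logits = jnp.zeros((1, 2, 5))  # batch=1, seq_len=2, vocab=5, uniform logits | ||||
| toy_targets = jnp.array([[1, 2]]) | ||||
| toy_mask = jnp.array([[1, 0]])     # second position is padding | ||||
| toy_loss, toy_n = loss_fn(toy_logits, toy_targets, toy_mask, label_smoothing_factor=0.0) | ||||
| print(float(toy_loss), int(toy_n))  # ln(5) ~= 1.609 over a single real token | ||||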
 | ||||
| # Define gradient update step fn | ||||
| def train_step(state, batch, label_smoothing_factor=0.0): | ||||
|     dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) | ||||
| 
 | ||||
|     def compute_loss(params): | ||||
|         labels = batch.pop("labels") | ||||
|         logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] | ||||
|         loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) | ||||
|         return loss, num_labels | ||||
| 
 | ||||
|     # compute gradients through computational graph | ||||
|     grad_fn = jax.value_and_grad(compute_loss, has_aux=True) | ||||
|     (loss, num_labels), grad = grad_fn(state.params) | ||||
|     num_labels = jax.lax.psum(num_labels, "batch") | ||||
| 
 | ||||
|     # true loss = total loss / total samples | ||||
|     loss = jax.lax.psum(loss, "batch") | ||||
|     loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) | ||||
| 
 | ||||
|     # true grad = total grad / total samples | ||||
|     grad = jax.lax.psum(grad, "batch") | ||||
|     grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) | ||||
|     new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) | ||||
| 
 | ||||
|     metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} | ||||
|     return new_state, metrics | ||||
| 
 | ||||
| # Define eval fn | ||||
| def eval_step(params, batch, label_smoothing_factor=0.0): | ||||
|     labels = batch.pop("labels") | ||||
|     logits = model(**batch, params=params, train=False)[0] | ||||
| 
 | ||||
|     loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) | ||||
|     num_labels = jax.lax.psum(num_labels, "batch") | ||||
| 
 | ||||
|     # true loss = total loss / total samples | ||||
|     loss = jax.lax.psum(loss, "batch") | ||||
|     loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) | ||||
| 
 | ||||
|     metrics = {"loss": loss} | ||||
|     return metrics | ||||
| 
 | ||||
| # Define generation function | ||||
| max_length = ( | ||||
|     val_max_target_length if val_max_target_length is not None else model.config.max_length | ||||
| ) | ||||
| num_beams = num_beams if num_beams is not None else model.config.num_beams | ||||
| gen_kwargs = {"max_length": max_length, "num_beams": num_beams} | ||||
| 
 | ||||
| def generate_step(params, batch): | ||||
|     model.params = params | ||||
|     output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs) | ||||
|     return output_ids.sequences | ||||
| 
 | ||||
| # Create parallel version of the train and eval step | ||||
| p_train_step = jax.pmap( | ||||
|     partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,) | ||||
| ) | ||||
| p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch") | ||||
| p_generate_step = jax.pmap(generate_step, "batch") | ||||
| 
 | ||||
| # Replicate the train state on each device | ||||
| state = state.replicate() | ||||
| 
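| # %% [markdown] | ||||
| # `shard` (used in the loop below) reshapes the leading batch dimension into | ||||
| # (num_devices, per_device_batch) so that `pmap` can split the batch across | ||||
| # devices. A quick illustration with made-up data (assumes a single host): | ||||
| 
| # %% | ||||
| toy_batch = {"input_ids": np.zeros((train_batch_size, 8), dtype=np.int32)} | ||||
| print(shard(toy_batch)["input_ids"].shape)  # (jax.device_count(), per_device_train_batch_size, 8) | ||||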
 | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
| 
 | ||||
| print("***** Running training *****") | ||||
| print(f"  Num examples = {len(train_dataset)}") | ||||
| print(f"  Num Epochs = {num_epochs}") | ||||
| print(f"  Instantaneous batch size per device = {per_device_train_batch_size}") | ||||
| print(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}") | ||||
| print(f"  Total optimization steps = {total_train_steps}") | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
| train_time = 0 | ||||
| epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) | ||||
| # epochs = range(num_epochs) | ||||
| for epoch in epochs: | ||||
|     # ======================== Training ================================ | ||||
|     train_start = time.time() | ||||
| 
 | ||||
|     # Create sampling rng | ||||
|     rng, input_rng = jax.random.split(rng) | ||||
|     train_metrics = [] | ||||
| 
 | ||||
|     # Generate an epoch by shuffling sampling indices from the train dataset | ||||
|     train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True) | ||||
|     steps_per_epoch = len(train_dataset) // train_batch_size | ||||
|     # train | ||||
|     for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): | ||||
|         batch = next(train_loader) | ||||
|         batch = shard(batch) | ||||
|         state, train_metric = p_train_step(state, batch) | ||||
|         train_metrics.append(train_metric) | ||||
| 
 | ||||
|     train_time += time.time() - train_start | ||||
| 
 | ||||
|     train_metric = unreplicate(train_metric) | ||||
| 
 | ||||
|     epochs.write( | ||||
|         f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:" | ||||
|         f" {train_metric['learning_rate']})" | ||||
|     ) | ||||
| 
 | ||||
|     # ======================== Evaluating ============================== | ||||
|     eval_metrics = [] | ||||
|     eval_preds = [] | ||||
|     eval_labels = [] | ||||
| 
 | ||||
|     eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False) | ||||
|     eval_steps = math.ceil(len(eval_dataset) / eval_batch_size) | ||||
|     for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): | ||||
|         # Model forward | ||||
|         batch = next(eval_loader) | ||||
|         labels = batch["labels"] | ||||
| 
 | ||||
|         metrics = pad_shard_unpad(p_eval_step, static_return=True)( | ||||
|             state.params, batch, min_device_batch=per_device_eval_batch_size | ||||
|         ) | ||||
|         eval_metrics.append(metrics) | ||||
| 
 | ||||
|         # generation | ||||
|         if predict_with_generate: | ||||
|             generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch) | ||||
|             eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) | ||||
|             eval_labels.extend(labels) | ||||
| 
 | ||||
|     # normalize eval metrics | ||||
|     eval_metrics = get_metrics(eval_metrics) | ||||
|     eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) | ||||
| 
 | ||||
|     # compute metrics | ||||
|     rouge_desc = "" | ||||
|     if predict_with_generate: | ||||
|         rouge_metrics = compute_metrics(eval_preds, eval_labels) | ||||
|         eval_metrics.update(rouge_metrics) | ||||
|         rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()]) | ||||
| 
 | ||||
|     # Print metrics and update progress bar | ||||
|     desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})" | ||||
|     epochs.write(desc) | ||||
|     epochs.desc = desc | ||||
| 
 | ||||
|     # Save metrics | ||||
|     # if has_tensorboard and jax.process_index() == 0: | ||||
|     #     cur_step = epoch * (len(train_dataset) // train_batch_size) | ||||
|     #     write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) | ||||
| 
 | ||||
|     output_dir = save_path | ||||
|     # save a checkpoint after each epoch | ||||
|     if jax.process_index() == 0: | ||||
|         params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) | ||||
|         model.save_pretrained(output_dir, params=params) | ||||
|         tokenizer.save_pretrained(output_dir) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|  | @ -0,0 +1,386 @@ | |||
| # --- | ||||
| # jupyter: | ||||
| #   jupytext: | ||||
| #     formats: ipynb,py:percent | ||||
| #     text_representation: | ||||
| #       extension: .py | ||||
| #       format_name: percent | ||||
| #       format_version: '1.3' | ||||
| #       jupytext_version: 1.16.4 | ||||
| # --- | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # # prediction code | ||||
| # ## import and process test data | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # import libraries | ||||
| import pandas as pd | ||||
| import matplotlib.pyplot as plt | ||||
| 
 | ||||
| from datasets import Dataset, DatasetDict | ||||
| 
 | ||||
| import jax | ||||
| import jax.numpy as jnp | ||||
| import optax | ||||
| import numpy as np | ||||
| from functools import partial | ||||
| from typing import Callable, Optional | ||||
| import math | ||||
| 
 | ||||
| # jax.config.update("jax_default_matmul_precision", "tensorfloat32") | ||||
| jax.config.update("jax_default_matmul_precision", "high") | ||||
| 
 | ||||
| jax.config.update("jax_enable_x64", False) | ||||
| 
 | ||||
| 
 | ||||
| from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig | ||||
| 
 | ||||
| 
 | ||||
| import datasets | ||||
| from datasets import Dataset, load_dataset | ||||
| import evaluate | ||||
| from tqdm import tqdm | ||||
| from datasets import load_from_disk | ||||
| 
 | ||||
| 
 | ||||
| import nltk  # Here to have a nice missing dependency error message early on | ||||
| 
 | ||||
| from flax import jax_utils, traverse_util | ||||
| from flax.jax_utils import pad_shard_unpad, unreplicate | ||||
| from flax.training import train_state | ||||
| from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key | ||||
| 
 | ||||
| 
 | ||||
| import time | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
| # data_path = f"../make_data/select_db/data_mapping_filtered.csv" | ||||
| # data_path = f"../make_data_2/select_db/dataset/1/train_all.csv" | ||||
| data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv' | ||||
| # data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv' | ||||
| 
 | ||||
| # Ensure to include 'ships_idx' in the fields list | ||||
| fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit'] | ||||
| 
 | ||||
| # Load the dataset | ||||
| df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields) | ||||
| 
 | ||||
| def process_df(df): | ||||
|     output_list = [{ | ||||
|             'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>", | ||||
|             # 'input': f"<DESC>{row['tag_description']}<DESC>", | ||||
|             # 'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>", | ||||
|             # 'input': f"<DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>", | ||||
|             'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>", | ||||
|             'answer': f"{row['thing']} {row['property']}", | ||||
|             'answer_thing': row['thing'], | ||||
|             'answer_property': row['property'], | ||||
|     } for _, row in df.iterrows()] | ||||
| 
 | ||||
|     return output_list | ||||
| 
 | ||||
| 
 | ||||
| # takes 1 minute to run without batching | ||||
| test_dataset = Dataset.from_list(process_df(df)) | ||||
| 
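| # %% [markdown] | ||||
| # Example of the serialised form (values below are made up): a row with | ||||
| # tag_name "GEN1_P" and tag_description "generator 1 power" becomes | ||||
| # input  `<NAME>GEN1_P<NAME><DESC>generator 1 power<DESC>` and | ||||
| # output `<THING_START>generator<THING_END><PROPERTY_START>power<PROPERTY_END>`. | ||||
| 
| # %% | ||||
| print(test_dataset[0]['input']) | ||||
| print(test_dataset[0]['output']) | ||||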
 | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # ## Load model for attributes | ||||
| 
 | ||||
| # %% | ||||
| # load model | ||||
| model_name_or_path = "t5_80_1"  # Replace with your specific model name | ||||
| 
 | ||||
| # Load configuration | ||||
| config = AutoConfig.from_pretrained(model_name_or_path) | ||||
| 
 | ||||
| # Load model | ||||
| model = FlaxAutoModelForSeq2SeqLM.from_pretrained( | ||||
|     model_name_or_path | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # ## Tokenizer | ||||
| 
 | ||||
| # %% | ||||
| # prepare tokenizer | ||||
| from transformers import T5TokenizerFast | ||||
| tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True) | ||||
| # Define additional special tokens | ||||
| additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"] | ||||
| # Add the additional special tokens to the tokenizer | ||||
| tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) | ||||
| 
 | ||||
| max_length = 86 | ||||
| 
 | ||||
| model_module = __import__(model.__module__, fromlist=["shift_tokens_right"]) | ||||
| shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") | ||||
| 
 | ||||
| # given a dataset entry, run it through the tokenizer | ||||
| # Setting padding="max_length" as we need fixed length inputs for jitted functions | ||||
| def preprocess_function(example): | ||||
|     input = example['input'] | ||||
|     target = example['output'] | ||||
|     # tokenize the inputs; the targets are tokenized separately below so that | ||||
|     # we also get the target attention mask | ||||
|     model_inputs = tokenizer( | ||||
|         input, | ||||
|         max_length=max_length, | ||||
|         padding="max_length", | ||||
|         truncation=True, | ||||
|         return_tensors="np" | ||||
|     ) | ||||
|     # tokenize only the targets here; passing the input text as well would make | ||||
|     # labels["input_ids"] hold the input ids instead of the target ids | ||||
|     labels = tokenizer( | ||||
|         text_target=target, | ||||
|         max_length=max_length, | ||||
|         padding="max_length", | ||||
|         truncation=True, | ||||
|         return_tensors="np" | ||||
|     ) | ||||
| 
 | ||||
|     model_inputs["labels"] = labels["input_ids"] | ||||
|     decoder_input_ids = shift_tokens_right_fn( | ||||
|         labels["input_ids"], config.pad_token_id, config.decoder_start_token_id | ||||
|     ) | ||||
|     model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids) | ||||
| 
 | ||||
|     # We need decoder_attention_mask so we can ignore pad tokens from loss | ||||
|     model_inputs["decoder_attention_mask"] = labels["attention_mask"] | ||||
| 
 | ||||
|     return model_inputs | ||||
| 
 | ||||
| # Dataset.map applies preprocess_function to every example; with batched=True | ||||
| # it receives whole batches of examples rather than single rows | ||||
| test_dataset = test_dataset.map( | ||||
|     preprocess_function, | ||||
|     batched=True, | ||||
|     num_proc=1, | ||||
|     remove_columns=test_dataset.column_names, | ||||
| ) | ||||
| 
 | ||||
| def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True): | ||||
|     """ | ||||
|     Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete, | ||||
|     and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`. | ||||
|     """ | ||||
|     if shuffle: | ||||
|         batch_idx = jax.random.permutation(rng, len(dataset)) | ||||
|         batch_idx = np.asarray(batch_idx) | ||||
|     else: | ||||
|         batch_idx = np.arange(len(dataset)) | ||||
| 
 | ||||
|     if drop_last: | ||||
|         steps_per_epoch = len(dataset) // batch_size | ||||
|         batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch. | ||||
|         batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) | ||||
|     else: | ||||
|         steps_per_epoch = math.ceil(len(dataset) / batch_size) | ||||
|         batch_idx = np.array_split(batch_idx, steps_per_epoch) | ||||
| 
 | ||||
|     for idx in batch_idx: | ||||
|         batch = dataset[idx] | ||||
|         batch = {k: np.array(v) for k, v in batch.items()} | ||||
| 
 | ||||
|         yield batch | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # # Inference setup | ||||
| 
 | ||||
| # %% | ||||
| seed = 117 | ||||
| num_epochs = 80 | ||||
| batch_size = 96 | ||||
| num_train_epochs = num_epochs | ||||
| per_device_train_batch_size = batch_size | ||||
| train_batch_size = per_device_train_batch_size * jax.device_count() | ||||
| per_device_eval_batch_size = batch_size | ||||
| eval_batch_size = per_device_eval_batch_size * jax.device_count() | ||||
| steps_per_epoch = len(test_dataset) // train_batch_size | ||||
| total_train_steps = steps_per_epoch * num_epochs | ||||
| 
 | ||||
| warmup_steps = 0 | ||||
| learning_rate = 5e-5 | ||||
| 
 | ||||
| weight_decay = 0.0 | ||||
| adam_beta1 = 0.9 | ||||
| adam_beta2 = 0.999 | ||||
| adam_epsilon = 1e-8 | ||||
| label_smoothing_factor = 0.0 | ||||
| 
 | ||||
| num_beams = 1 | ||||
| val_max_target_length = None | ||||
| 
 | ||||
| predict_with_generate = True | ||||
| 
 | ||||
| 
 | ||||
| # Initialize our training | ||||
| rng = jax.random.PRNGKey(seed) | ||||
| rng, dropout_rng = jax.random.split(rng) | ||||
| 
 | ||||
| def create_learning_rate_fn( | ||||
|     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float | ||||
| ) -> Callable[[int], jnp.ndarray]: | ||||
|     """Returns a linear warmup, linear_decay learning rate function.""" | ||||
|     steps_per_epoch = train_ds_size // train_batch_size | ||||
|     num_train_steps = steps_per_epoch * num_train_epochs | ||||
|     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) | ||||
|     decay_fn = optax.linear_schedule( | ||||
|         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps | ||||
|     ) | ||||
|     schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) | ||||
|     return schedule_fn | ||||
| 
 | ||||
| 
 | ||||
| # Create learning rate schedule | ||||
| linear_decay_lr_schedule_fn = create_learning_rate_fn( | ||||
|     len(test_dataset), | ||||
|     train_batch_size, | ||||
|     num_train_epochs, | ||||
|     warmup_steps, | ||||
|     learning_rate, | ||||
| ) | ||||
| 
 | ||||
| # We use Optax's "masking" functionality to not apply weight decay | ||||
| # to bias and LayerNorm scale parameters. decay_mask_fn returns a | ||||
| # mask boolean with the same structure as the parameters. | ||||
| # The mask is True for parameters that should be decayed. | ||||
| def decay_mask_fn(params): | ||||
|     flat_params = traverse_util.flatten_dict(params) | ||||
|     # find out all LayerNorm parameters | ||||
|     layer_norm_candidates = ["layernorm", "layer_norm", "ln"] | ||||
|     layer_norm_named_params = { | ||||
|         layer[-2:] | ||||
|         for layer_norm_name in layer_norm_candidates | ||||
|         for layer in flat_params.keys() | ||||
|         if layer_norm_name in "".join(layer).lower() | ||||
|     } | ||||
|     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} | ||||
|     return traverse_util.unflatten_dict(flat_mask) | ||||
| 
 | ||||
| # create adam optimizer | ||||
| adamw = optax.adamw( | ||||
|     learning_rate=linear_decay_lr_schedule_fn, | ||||
|     b1=adam_beta1, | ||||
|     b2=adam_beta2, | ||||
|     eps=adam_epsilon, | ||||
|     weight_decay=weight_decay, | ||||
|     mask=decay_mask_fn, | ||||
| ) | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
| # reload model to prevent leakage of variables | ||||
| # load model | ||||
| model_name_or_path = "t5_80_1"  # Replace with your specific model name | ||||
| 
 | ||||
| # Load configuration | ||||
| config = AutoConfig.from_pretrained(model_name_or_path) | ||||
| 
 | ||||
| # Load model | ||||
| model = FlaxAutoModelForSeq2SeqLM.from_pretrained( | ||||
|     model_name_or_path | ||||
| ) | ||||
| 
 | ||||
| # Train state (only used here to replicate the parameters across devices for generation) | ||||
| class TrainState(train_state.TrainState): | ||||
|     dropout_rng: jnp.ndarray | ||||
| 
 | ||||
|     def replicate(self): | ||||
|         return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) | ||||
| 
 | ||||
| # model.params is already initialized here by from_pretrained; if the model were | ||||
| # built from scratch you would get the params from an init call with dummy input instead | ||||
| params = model.params | ||||
| # Cast parameters to bfloat16 if desired | ||||
| params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) | ||||
| 
 | ||||
| 
 | ||||
| # Setup train state | ||||
| state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # Define generation function | ||||
| max_length = ( | ||||
|     val_max_target_length if val_max_target_length is not None else model.config.max_length | ||||
| ) | ||||
| num_beams = num_beams if num_beams is not None else model.config.num_beams | ||||
| gen_kwargs = {"max_length": max_length, "num_beams": num_beams} | ||||
| 
 | ||||
| def generate_step(params, batch): | ||||
|     model.params = params | ||||
|     output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs) | ||||
|     return output_ids.sequences | ||||
| 
 | ||||
| # Create parallel version of the generate step | ||||
| p_generate_step = jax.pmap(generate_step, "batch") | ||||
| 
 | ||||
| # Replicate the train state on each device | ||||
| state = state.replicate() | ||||
| 
 | ||||
| 
 | ||||
| pred_metrics = [] | ||||
| pred_generations = [] | ||||
| pred_labels = [] | ||||
| 
 | ||||
| rng, input_rng = jax.random.split(rng) | ||||
| 
 | ||||
| pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False) | ||||
| pred_steps = math.ceil(len(test_dataset) / eval_batch_size) | ||||
| 
 | ||||
| print("***** Running training *****") | ||||
| print(f"  Num examples = {len(test_dataset)}") | ||||
| print(f"  Num steps = {num_epochs}") | ||||
| print(f"  Instantaneous batch size per device = {per_device_train_batch_size}") | ||||
| print(f"  Total test batch size (w. parallel & distributed) = {train_batch_size}") | ||||
| 
 | ||||
| 
 | ||||
| for _ in tqdm(range(pred_steps), desc="Predicting...", position=0, leave=False): | ||||
|     # Model forward | ||||
|     batch = next(pred_loader) | ||||
|     labels = batch["labels"] | ||||
| 
 | ||||
|     # generation | ||||
|     if predict_with_generate: | ||||
|         generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch) | ||||
|         pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) | ||||
|         pred_labels.extend(labels) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # Print metrics | ||||
| # desc = f"Predict Loss: {pred_metrics['loss']})" | ||||
| # print(desc) | ||||
| 
 | ||||
| # %% | ||||
| # save predictions to parquet | ||||
| 
 | ||||
| # decode generated token ids back to text | ||||
| def decode_preds(preds): | ||||
|     # In case the model returns more than the prediction logits | ||||
|     if isinstance(preds, tuple): | ||||
|         preds = preds[0] | ||||
| 
 | ||||
|     decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) | ||||
| 
 | ||||
|     decoded_preds = [pred for pred in decoded_preds] | ||||
| 
 | ||||
|     return decoded_preds | ||||
| 
 | ||||
| 
 | ||||
| # Convert the decoded generations (the model predictions) to a Pandas DataFrame | ||||
| df = pd.DataFrame(decode_preds(pred_generations)) | ||||
| 
 | ||||
| # Save the DataFrame as a Parquet file (using pyarrow or fastparquet) | ||||
| df.to_parquet("exports/output_file.parquet", engine="pyarrow")  # or use engine="fastparquet" | ||||
| 
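| # %% [markdown] | ||||
| # To inspect the export later (illustrative; assumes the `exports/` directory | ||||
| # exists and the file above was written): | ||||
| 
| # %% | ||||
| df_check = pd.read_parquet("exports/output_file.parquet", engine="pyarrow") | ||||
| print(df_check.head()) | ||||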
 | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
|  | @ -0,0 +1,624 @@ | |||
| # --- | ||||
| # jupyter: | ||||
| #   jupytext: | ||||
| #     formats: ipynb,py:percent | ||||
| #     text_representation: | ||||
| #       extension: .py | ||||
| #       format_name: percent | ||||
| #       format_version: '1.3' | ||||
| #       jupytext_version: 1.16.4 | ||||
| #   kernelspec: | ||||
| #     display_name: jax | ||||
| #     language: python | ||||
| #     name: python3 | ||||
| # --- | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # # T5 implementation using jax | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # ## import | ||||
| 
 | ||||
| # %% [raw] | ||||
| # import json | ||||
| # import logging | ||||
| # import math | ||||
| # import os | ||||
| # import sys | ||||
| # import time | ||||
| # from dataclasses import asdict, dataclass, field | ||||
| # from enum import Enum | ||||
| # from functools import partial | ||||
| # from pathlib import Path | ||||
| # from typing import Callable, Optional | ||||
| # | ||||
| # import datasets | ||||
| # import evaluate | ||||
| # import jax | ||||
| # import jax.numpy as jnp | ||||
| # import nltk  # Here to have a nice missing dependency error message early on | ||||
| # import numpy as np | ||||
| # import optax | ||||
| # from datasets import Dataset, load_dataset | ||||
| # from filelock import FileLock | ||||
| # from flax import jax_utils, traverse_util | ||||
| # from flax.jax_utils import pad_shard_unpad, unreplicate | ||||
| # from flax.training import train_state | ||||
| # from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key | ||||
| # from tqdm import tqdm | ||||
| # | ||||
| # import transformers | ||||
| # from transformers import ( | ||||
| #     CONFIG_MAPPING, | ||||
| #     FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, | ||||
| #     AutoConfig, | ||||
| #     AutoTokenizer, | ||||
| #     FlaxAutoModelForSeq2SeqLM, | ||||
| #     HfArgumentParser, | ||||
| #     is_tensorboard_available, | ||||
| # ) | ||||
| # from transformers.utils import is_offline_mode, send_example_telemetry | ||||
| # | ||||
| # | ||||
| # logger = logging.getLogger(__name__) | ||||
| # | ||||
| # try: | ||||
| #     nltk.data.find("tokenizers/punkt") | ||||
| # except (LookupError, OSError): | ||||
| #     if is_offline_mode(): | ||||
| #         raise LookupError( | ||||
| #             "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" | ||||
| #         ) | ||||
| #     with FileLock(".lock") as lock: | ||||
| #         nltk.download("punkt", quiet=True) | ||||
| # | ||||
| # | ||||
| # MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys()) | ||||
| # MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| import jax | ||||
| import jax.numpy as jnp | ||||
| import optax | ||||
| import numpy as np | ||||
| from functools import partial | ||||
| from typing import Callable, Optional | ||||
| import math | ||||
| 
 | ||||
| # jax.config.update("jax_default_matmul_precision", "tensorfloat32") | ||||
| jax.config.update("jax_default_matmul_precision", "high") | ||||
| 
 | ||||
| jax.config.update("jax_enable_x64", False) | ||||
| 
 | ||||
| 
 | ||||
| from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig | ||||
| 
 | ||||
| 
 | ||||
| import datasets | ||||
| from datasets import Dataset, load_dataset | ||||
| import evaluate | ||||
| 
 | ||||
| 
 | ||||
| import nltk  # Here to have a nice missing dependency error message early on | ||||
| from filelock import FileLock  # used below to serialize the punkt download | ||||
| from transformers.utils import is_offline_mode  # used below to guard the punkt download | ||||
| 
 | ||||
| from flax import jax_utils, traverse_util | ||||
| from flax.jax_utils import pad_shard_unpad, unreplicate | ||||
| from flax.training import train_state | ||||
| from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key | ||||
| 
 | ||||
| 
 | ||||
| import time | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| import os | ||||
| os.environ['XLA_FLAGS'] = ( | ||||
|     '--xla_gpu_enable_triton_softmax_fusion=True ' | ||||
|     '--xla_gpu_triton_gemm_any=True ' | ||||
| ) | ||||
| 
 | ||||
| os.environ.update({ | ||||
|     "CUDA_VISIBLE_DEVICES": "0,1,2,3",  # comma-separated device ids, conventionally without spaces | ||||
|     "NCCL_LL128_BUFFSIZE": "-2", | ||||
|     "NCCL_LL_BUFFSIZE": "-2", | ||||
|     "NCCL_PROTO": "SIMPLE,LL,LL128", | ||||
| }) | ||||
| 
 | ||||
| # %% | ||||
| from jax.lib import xla_bridge | ||||
| print(xla_bridge.get_backend().platform) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # nltk.download('punkt') | ||||
| try: | ||||
|     nltk.data.find("tokenizers/punkt") | ||||
| except (LookupError, OSError): | ||||
|     if is_offline_mode(): | ||||
|         raise LookupError( | ||||
|             "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" | ||||
|         ) | ||||
|     with FileLock(".lock") as lock: | ||||
|         nltk.download("punkt", quiet=True) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # ## Prepare datasets | ||||
| 
 | ||||
| # %% | ||||
| # load model | ||||
| model_name_or_path = "t5-small"  # Replace with your specific model name | ||||
| 
 | ||||
| # Load configuration | ||||
| config = AutoConfig.from_pretrained(model_name_or_path) | ||||
| 
 | ||||
| # Load model | ||||
| model = FlaxAutoModelForSeq2SeqLM.from_pretrained( | ||||
|     model_name_or_path | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| model_module = __import__(model.__module__, fromlist=["shift_tokens_right"]) | ||||
| shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| from tqdm import tqdm | ||||
| from datasets import load_from_disk | ||||
| # Path to saved combined_dataset | ||||
| file_path = '/home/richard/Projects/learn_t5/retrieval/combined_data_t5' | ||||
| save_path = 't5_80_1_retrieval' | ||||
| # file_path = 'combined_data' | ||||
| split_datasets = load_from_disk(file_path) | ||||
| 
 | ||||
| # prepare tokenizer | ||||
| from transformers import T5TokenizerFast | ||||
| tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True) | ||||
| # Define additional special tokens | ||||
| # additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"] | ||||
| additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", | ||||
|                              "<CONTEXT>", "<EXAMPLE>", "<INPUT>", "<OUTPUT>"] | ||||
| # Add the additional special tokens to the tokenizer | ||||
| tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) | ||||
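| 
 | ||||
| # Sanity check (illustrative): T5 checkpoints ship with a padded embedding matrix | ||||
| # (config.vocab_size is 32128 while the stock tokenizer holds 32100 entries), so the | ||||
| # handful of tokens added above should fit without resizing the model embeddings. | ||||
| print(f"tokenizer size: {len(tokenizer)}, model vocab size: {model.config.vocab_size}") | ||||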
| 
 | ||||
| max_length = 300 | ||||
| 
 | ||||
| # In Flax, seq2seq models need explicit `decoder_input_ids`: Flax models do not accept | ||||
| # `labels`, so we prepare the decoder inputs ourselves. For that, the `shift_tokens_right` | ||||
| # function was dynamically imported from the model file above. | ||||
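| 
 | ||||
| # Small worked example (illustrative, arbitrary token ids): shift_tokens_right prepends | ||||
| # the decoder start token and drops the last position. For T5 both pad_token_id and | ||||
| # decoder_start_token_id are 0, so labels [[71, 12, 5, 1]] become [[0, 71, 12, 5]]. | ||||
| print(shift_tokens_right_fn(np.array([[71, 12, 5, 1]]), config.pad_token_id, config.decoder_start_token_id)) | ||||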
| 
 | ||||
| 
 | ||||
| # given a dataset entry, run it through the tokenizer | ||||
| # Setting padding="max_length" as we need fixed length inputs for jitted functions | ||||
| def preprocess_function(example): | ||||
|     inputs = example['input'] | ||||
|     targets = example['output'] | ||||
|     # tokenize the encoder inputs | ||||
|     model_inputs = tokenizer( | ||||
|         inputs, | ||||
|         max_length=max_length, | ||||
|         padding="max_length", | ||||
|         truncation=True, | ||||
|         return_tensors="np" | ||||
|     ) | ||||
|     # tokenize the targets in target mode so that `input_ids`/`attention_mask` below | ||||
|     # refer to the labels rather than to the encoder inputs | ||||
|     labels = tokenizer( | ||||
|         text_target=targets, | ||||
|         max_length=max_length, | ||||
|         padding="max_length", | ||||
|         truncation=True, | ||||
|         return_tensors="np" | ||||
|     ) | ||||
| 
 | ||||
|     model_inputs["labels"] = labels["input_ids"] | ||||
|     decoder_input_ids = shift_tokens_right_fn( | ||||
|         labels["input_ids"], config.pad_token_id, config.decoder_start_token_id | ||||
|     ) | ||||
|     model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids) | ||||
| 
 | ||||
|     # We need decoder_attention_mask so we can ignore pad tokens in the loss | ||||
|     model_inputs["decoder_attention_mask"] = labels["attention_mask"] | ||||
| 
 | ||||
|     return model_inputs | ||||
| 
 | ||||
| # Dataset.map applies preprocess_function to the whole dataset; with batched=True | ||||
| # the function receives batches of rows rather than single examples | ||||
| tokenized_datasets = split_datasets.map( | ||||
|     preprocess_function, | ||||
|     batched=True, | ||||
|     num_proc=1, | ||||
|     remove_columns=split_datasets["train"].column_names, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| tokenized_datasets | ||||
| 
 | ||||
| # %% | ||||
| train_dataset = tokenized_datasets["train"] | ||||
| eval_dataset = tokenized_datasets["validation"] | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True): | ||||
|     """ | ||||
|     Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete, | ||||
|     and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`. | ||||
|     """ | ||||
|     if shuffle: | ||||
|         batch_idx = jax.random.permutation(rng, len(dataset)) | ||||
|         batch_idx = np.asarray(batch_idx) | ||||
|     else: | ||||
|         batch_idx = np.arange(len(dataset)) | ||||
| 
 | ||||
|     if drop_last: | ||||
|         steps_per_epoch = len(dataset) // batch_size | ||||
|         batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch. | ||||
|         batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) | ||||
|     else: | ||||
|         steps_per_epoch = math.ceil(len(dataset) / batch_size) | ||||
|         batch_idx = np.array_split(batch_idx, steps_per_epoch) | ||||
| 
 | ||||
|     for idx in batch_idx: | ||||
|         batch = dataset[idx] | ||||
|         batch = {k: np.array(v) for k, v in batch.items()} | ||||
| 
 | ||||
|         yield batch | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # The model inputs are now available as `tokenized_datasets`; the `data_loader` defined above turns them into fixed-shape batches. | ||||
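| 
 | ||||
| # %% | ||||
| # Quick sanity check (illustrative, arbitrary batch size of 8): draw one shuffled batch | ||||
| # and confirm every field has the fixed (batch, max_length) shape needed for jitted steps. | ||||
| _check_batch = next(data_loader(jax.random.PRNGKey(0), train_dataset, 8, shuffle=True)) | ||||
| print({k: v.shape for k, v in _check_batch.items()}) | ||||
| # expected: every value has shape (8, 300) given max_length = 300 | ||||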
| 
 | ||||
| # %% | ||||
| # metric | ||||
| metric = evaluate.load("sacrebleu") | ||||
| 
 | ||||
| def postprocess_text(preds, labels): | ||||
|     preds = [pred.strip() for pred in preds] | ||||
|     labels = [label.strip() for label in labels] | ||||
| 
 | ||||
|     # rougeLSum expects newline after each sentence | ||||
|     preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] | ||||
|     labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] | ||||
| 
 | ||||
|     return preds, labels | ||||
| 
 | ||||
| # def compute_metrics(preds, labels): | ||||
| #     decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) | ||||
| #     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) | ||||
| #  | ||||
| #     # Some simple post-processing | ||||
| #     decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) | ||||
| #  | ||||
| #     result = metric.compute(predictions=decoded_preds, references=decoded_labels) | ||||
| #     result = {k: round(v * 100, 4) for k, v in result.items()} | ||||
| #     prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] | ||||
| #     result["gen_len"] = np.mean(prediction_lens) | ||||
| #     return result | ||||
| 
 | ||||
| def compute_metrics(preds, labels): | ||||
|     # In case the model returns more than the prediction logits | ||||
|     if isinstance(preds, tuple): | ||||
|         preds = preds[0] | ||||
| 
 | ||||
|     decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) | ||||
| 
 | ||||
|     # Replace -100s in the labels as we can't decode them | ||||
|     labels = np.where(labels != -100, labels, tokenizer.pad_token_id) | ||||
|     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) | ||||
| 
 | ||||
|     # Some simple post-processing | ||||
|     decoded_preds = [pred.strip() for pred in decoded_preds] | ||||
|     decoded_labels = [[label.strip()] for label in decoded_labels] | ||||
| 
 | ||||
|     result = metric.compute(predictions=decoded_preds, references=decoded_labels) | ||||
|     return {"bleu": result["score"]} | ||||
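| 
 | ||||
| # %% | ||||
| # Illustrative check of compute_metrics on a toy pair: identical prediction and reference | ||||
| # strings should score BLEU ~100. The token ids come from the tokenizer itself, so the | ||||
| # batch_decode round trip is exercised as well. | ||||
| _toy_ids = tokenizer(["the quick brown fox jumps"], return_tensors="np").input_ids | ||||
| print(compute_metrics(_toy_ids, _toy_ids)) | ||||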
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # # Model | ||||
| 
 | ||||
| # %% | ||||
| # Store some constants | ||||
| seed = 117 | ||||
| num_epochs = 80 | ||||
| batch_size = 36 | ||||
| num_train_epochs = num_epochs | ||||
| per_device_train_batch_size = batch_size | ||||
| train_batch_size = per_device_train_batch_size * jax.device_count() | ||||
| per_device_eval_batch_size = batch_size | ||||
| eval_batch_size = per_device_eval_batch_size * jax.device_count() | ||||
| steps_per_epoch = len(train_dataset) // train_batch_size | ||||
| total_train_steps = steps_per_epoch * num_epochs | ||||
| 
 | ||||
| warmup_steps = 0 | ||||
| learning_rate = 5e-5 | ||||
| 
 | ||||
| weight_decay = 0.0 | ||||
| adam_beta1 = 0.9 | ||||
| adam_beta2 = 0.999 | ||||
| adam_epsilon = 1e-8 | ||||
| label_smoothing_factor = 0.0 | ||||
| 
 | ||||
| num_beams = 1 | ||||
| val_max_target_length = None | ||||
| 
 | ||||
| predict_with_generate = True | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
| # Initialize our training | ||||
| rng = jax.random.PRNGKey(seed) | ||||
| rng, dropout_rng = jax.random.split(rng) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # optimization functions | ||||
| 
 | ||||
| def create_learning_rate_fn( | ||||
|     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float | ||||
| ) -> Callable[[int], jnp.ndarray]: | ||||
|     """Returns a linear warmup, linear decay learning rate function.""" | ||||
|     steps_per_epoch = train_ds_size // train_batch_size | ||||
|     num_train_steps = steps_per_epoch * num_train_epochs | ||||
|     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) | ||||
|     decay_fn = optax.linear_schedule( | ||||
|         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps | ||||
|     ) | ||||
|     schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) | ||||
|     return schedule_fn | ||||
| 
 | ||||
| 
 | ||||
| # Create learning rate schedule | ||||
| linear_decay_lr_schedule_fn = create_learning_rate_fn( | ||||
|     len(train_dataset), | ||||
|     train_batch_size, | ||||
|     num_train_epochs, | ||||
|     warmup_steps, | ||||
|     learning_rate, | ||||
| ) | ||||
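| 
 | ||||
| # Illustrative check of the schedule: with warmup_steps == 0 the learning rate starts at | ||||
| # `learning_rate` and decays linearly towards 0 at `total_train_steps` | ||||
| for _step in (0, total_train_steps // 2, total_train_steps): | ||||
|     print(_step, linear_decay_lr_schedule_fn(_step)) | ||||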
| 
 | ||||
| # We use Optax's "masking" functionality to not apply weight decay | ||||
| # to bias and LayerNorm scale parameters. decay_mask_fn returns a | ||||
| # mask boolean with the same structure as the parameters. | ||||
| # The mask is True for parameters that should be decayed. | ||||
| def decay_mask_fn(params): | ||||
|     flat_params = traverse_util.flatten_dict(params) | ||||
|     # find out all LayerNorm parameters | ||||
|     layer_norm_candidates = ["layernorm", "layer_norm", "ln"] | ||||
|     layer_norm_named_params = { | ||||
|         layer[-2:] | ||||
|         for layer_norm_name in layer_norm_candidates | ||||
|         for layer in flat_params.keys() | ||||
|         if layer_norm_name in "".join(layer).lower() | ||||
|     } | ||||
|     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} | ||||
|     return traverse_util.unflatten_dict(flat_mask) | ||||
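| 
 | ||||
| # Illustrative check: biases and LayerNorm-style scales should map to False (no weight | ||||
| # decay) while ordinary kernels map to True | ||||
| _flat_decay_mask = traverse_util.flatten_dict(decay_mask_fn(model.params)) | ||||
| print(list(_flat_decay_mask.items())[:5]) | ||||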
| 
 | ||||
| # create adam optimizer | ||||
| adamw = optax.adamw( | ||||
|     learning_rate=linear_decay_lr_schedule_fn, | ||||
|     b1=adam_beta1, | ||||
|     b2=adam_beta2, | ||||
|     eps=adam_epsilon, | ||||
|     weight_decay=weight_decay, | ||||
|     mask=decay_mask_fn, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # Training functions | ||||
| class TrainState(train_state.TrainState): | ||||
|     dropout_rng: jnp.ndarray | ||||
| 
 | ||||
|     def replicate(self): | ||||
|         return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) | ||||
| 
 | ||||
| # model.params is already initialized by from_pretrained above, so no separate | ||||
| # init call with dummy inputs is needed here | ||||
| params = model.params | ||||
| # Cast parameters to bfloat16 if desired | ||||
| params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) | ||||
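| 
 | ||||
| # Illustrative check: after the cast every parameter leaf should report bfloat16 | ||||
| print({str(x.dtype) for x in jax.tree_util.tree_leaves(params_bf16)}) | ||||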
| 
 | ||||
| 
 | ||||
| # Setup train state | ||||
| state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng) | ||||
| 
 | ||||
| # label smoothed cross entropy | ||||
| def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): | ||||
|     """ | ||||
|     The label smoothing implementation is adapted from Flax's official example: | ||||
|     https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 | ||||
|     """ | ||||
|     vocab_size = logits.shape[-1] | ||||
|     confidence = 1.0 - label_smoothing_factor | ||||
|     low_confidence = (1.0 - confidence) / (vocab_size - 1) | ||||
|     normalizing_constant = -( | ||||
|         confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) | ||||
|     ) | ||||
|     soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) | ||||
| 
 | ||||
|     loss = optax.softmax_cross_entropy(logits, soft_labels) | ||||
|     loss = loss - normalizing_constant | ||||
| 
 | ||||
|     # ignore padded tokens from loss | ||||
|     loss = loss * padding_mask | ||||
|     loss = loss.sum() | ||||
|     num_labels = padding_mask.sum() | ||||
|     return loss, num_labels | ||||
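| 
 | ||||
| # Tiny numeric check (illustrative, made-up logits): with label_smoothing_factor == 0 the | ||||
| # result reduces to plain softmax cross entropy summed over the non-padded positions | ||||
| _logits = jnp.array([[[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]]])  # (batch=1, len=2, vocab=3) | ||||
| _labels = jnp.array([[0, 2]]) | ||||
| _mask = jnp.array([[1, 1]]) | ||||
| _loss, _num = loss_fn(_logits, _labels, _mask, label_smoothing_factor=0.0) | ||||
| print(_loss / _num, _num) | ||||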
| 
 | ||||
| # Define gradient update step fn | ||||
| def train_step(state, batch, label_smoothing_factor=0.0): | ||||
|     dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) | ||||
| 
 | ||||
|     def compute_loss(params): | ||||
|         labels = batch.pop("labels") | ||||
|         logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] | ||||
|         loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) | ||||
|         return loss, num_labels | ||||
| 
 | ||||
|     # compute gradients through computational graph | ||||
|     grad_fn = jax.value_and_grad(compute_loss, has_aux=True) | ||||
|     (loss, num_labels), grad = grad_fn(state.params) | ||||
|     num_labels = jax.lax.psum(num_labels, "batch") | ||||
| 
 | ||||
|     # true loss = total loss / total samples | ||||
|     loss = jax.lax.psum(loss, "batch") | ||||
|     loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) | ||||
| 
 | ||||
|     # true grad = total grad / total samples | ||||
|     grad = jax.lax.psum(grad, "batch") | ||||
|     grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) | ||||
|     new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) | ||||
| 
 | ||||
|     metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} | ||||
|     return new_state, metrics | ||||
| 
 | ||||
| # Define eval fn | ||||
| def eval_step(params, batch, label_smoothing_factor=0.0): | ||||
|     labels = batch.pop("labels") | ||||
|     logits = model(**batch, params=params, train=False)[0] | ||||
| 
 | ||||
|     loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) | ||||
|     num_labels = jax.lax.psum(num_labels, "batch") | ||||
| 
 | ||||
|     # true loss = total loss / total samples | ||||
|     loss = jax.lax.psum(loss, "batch") | ||||
|     loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) | ||||
| 
 | ||||
|     metrics = {"loss": loss} | ||||
|     return metrics | ||||
| 
 | ||||
| # Define generation function | ||||
| max_length = ( | ||||
|     val_max_target_length if val_max_target_length is not None else model.config.max_length | ||||
| ) | ||||
| num_beams = num_beams if num_beams is not None else model.config.num_beams | ||||
| gen_kwargs = {"max_length": max_length, "num_beams": num_beams} | ||||
| 
 | ||||
| def generate_step(params, batch): | ||||
|     model.params = params | ||||
|     output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs) | ||||
|     return output_ids.sequences | ||||
| 
 | ||||
| # Create parallel version of the train and eval step | ||||
| p_train_step = jax.pmap( | ||||
|     partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,) | ||||
| ) | ||||
| p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch") | ||||
| p_generate_step = jax.pmap(generate_step, "batch") | ||||
| 
 | ||||
| # Replicate the train state on each device | ||||
| state = state.replicate() | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
| 
 | ||||
| print("***** Running training *****") | ||||
| print(f"  Num examples = {len(train_dataset)}") | ||||
| print(f"  Num Epochs = {num_epochs}") | ||||
| print(f"  Instantaneous batch size per device = {per_device_train_batch_size}") | ||||
| print(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}") | ||||
| print(f"  Total optimization steps = {total_train_steps}") | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
| train_time = 0 | ||||
| epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) | ||||
| # epochs = range(num_epochs) | ||||
| for epoch in epochs: | ||||
|     # ======================== Training ================================ | ||||
|     train_start = time.time() | ||||
| 
 | ||||
|     # Create sampling rng | ||||
|     rng, input_rng = jax.random.split(rng) | ||||
|     train_metrics = [] | ||||
| 
 | ||||
|     # Generate an epoch by shuffling sampling indices from the train dataset | ||||
|     train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True) | ||||
|     steps_per_epoch = len(train_dataset) // train_batch_size | ||||
|     # train | ||||
|     for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): | ||||
|         batch = next(train_loader) | ||||
|         batch = shard(batch) | ||||
|         state, train_metric = p_train_step(state, batch) | ||||
|         train_metrics.append(train_metric) | ||||
| 
 | ||||
|     train_time += time.time() - train_start | ||||
| 
 | ||||
|     train_metric = unreplicate(train_metric) | ||||
| 
 | ||||
|     epochs.write( | ||||
|         f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:" | ||||
|         f" {train_metric['learning_rate']})" | ||||
|     ) | ||||
| 
 | ||||
|     # ======================== Evaluating ============================== | ||||
|     # eval_metrics = [] | ||||
|     # eval_preds = [] | ||||
|     # eval_labels = [] | ||||
| 
 | ||||
|     # eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False) | ||||
|     # eval_steps = math.ceil(len(eval_dataset) / eval_batch_size) | ||||
|     # for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): | ||||
|     #     # Model forward | ||||
|     #     batch = next(eval_loader) | ||||
|     #     labels = batch["labels"] | ||||
| 
 | ||||
|     #     metrics = pad_shard_unpad(p_eval_step, static_return=True)( | ||||
|     #         state.params, batch, min_device_batch=per_device_eval_batch_size | ||||
|     #     ) | ||||
|     #     eval_metrics.append(metrics) | ||||
| 
 | ||||
|     #     # generation | ||||
|     #     if predict_with_generate: | ||||
|     #         generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch) | ||||
|     #         eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) | ||||
|     #         eval_labels.extend(labels) | ||||
| 
 | ||||
|     # # normalize eval metrics | ||||
|     # eval_metrics = get_metrics(eval_metrics) | ||||
|     # eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) | ||||
| 
 | ||||
|     # compute metrics | ||||
|     # rouge_desc = "" | ||||
|     # if predict_with_generate: | ||||
|     #     rouge_metrics = compute_metrics(eval_preds, eval_labels) | ||||
|     #     eval_metrics.update(rouge_metrics) | ||||
|     #     rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()]) | ||||
| 
 | ||||
|     # # Print metrics and update progress bar | ||||
|     # desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})" | ||||
|     # epochs.write(desc) | ||||
|     # epochs.desc = desc | ||||
| 
 | ||||
|     # Save metrics | ||||
|     # if has_tensorboard and jax.process_index() == 0: | ||||
|     #     cur_step = epoch * (len(train_dataset) // train_batch_size) | ||||
|     #     write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) | ||||
| 
 | ||||
|     output_dir = save_path | ||||
|     # save checkpoint after each epoch and push checkpoint to the hub | ||||
|     if jax.process_index() == 0: | ||||
|         params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) | ||||
|         model.save_pretrained(output_dir, params=params) | ||||
|         tokenizer.save_pretrained(output_dir) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # #  | ||||
										
											
File diff suppressed because it is too large