Feat: jax implementation of t5 training and prediction
commit f523560141
@@ -0,0 +1,3 @@
*.ipynb
t5_*/
exports/
@@ -0,0 +1,357 @@
|
|||
name: jax
|
||||
channels:
|
||||
- conda-forge
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=conda_forge
|
||||
- _openmp_mutex=4.5=2_gnu
|
||||
- _sysroot_linux-64_curr_repodata_hack=3=h69a702a_16
|
||||
- absl-py=2.1.0=pyhd8ed1ab_0
|
||||
- aiohappyeyeballs=2.4.0=pyhd8ed1ab_0
|
||||
- aiohttp=3.10.5=py311h61187de_0
|
||||
- aiosignal=1.3.1=pyhd8ed1ab_0
|
||||
- alsa-lib=1.2.12=h4ab18f5_0
|
||||
- aom=3.9.1=hac33072_0
|
||||
- arrow=1.3.0=pyhd8ed1ab_0
|
||||
- asttokens=2.4.1=pyhd8ed1ab_0
|
||||
- attrs=24.2.0=pyh71513ae_0
|
||||
- aws-c-auth=0.7.29=h03582ad_1
|
||||
- aws-c-cal=0.7.4=hfd43aa1_1
|
||||
- aws-c-common=0.9.28=hb9d3cd8_0
|
||||
- aws-c-compression=0.2.19=h756ea98_1
|
||||
- aws-c-event-stream=0.4.3=h235a6dd_1
|
||||
- aws-c-http=0.8.8=h5e77a74_2
|
||||
- aws-c-io=0.14.18=hc2627b9_9
|
||||
- aws-c-mqtt=0.10.4=h01636a3_19
|
||||
- aws-c-s3=0.6.5=h191b246_2
|
||||
- aws-c-sdkutils=0.1.19=h756ea98_3
|
||||
- aws-checksums=0.1.18=h756ea98_11
|
||||
- aws-crt-cpp=0.28.2=h29c84ef_4
|
||||
- aws-sdk-cpp=1.11.379=h5a9005d_9
|
||||
- azure-core-cpp=1.13.0=h935415a_0
|
||||
- azure-identity-cpp=1.8.0=hd126650_2
|
||||
- azure-storage-blobs-cpp=12.12.0=hd2e3451_0
|
||||
- azure-storage-common-cpp=12.7.0=h10ac4d7_1
|
||||
- azure-storage-files-datalake-cpp=12.11.0=h325d260_1
|
||||
- binaryornot=0.4.4=py_1
|
||||
- binutils_impl_linux-64=2.40=ha1999f0_7
|
||||
- binutils_linux-64=2.40=hb3c18ed_2
|
||||
- blosc=1.21.6=hef167b5_0
|
||||
- brotli=1.1.0=hb9d3cd8_2
|
||||
- brotli-bin=1.1.0=hb9d3cd8_2
|
||||
- brotli-python=1.1.0=py311hfdbb021_2
|
||||
- bzip2=1.0.8=h4bc722e_7
|
||||
- c-ares=1.33.1=heb4867d_0
|
||||
- ca-certificates=2024.8.30=hbcca054_0
|
||||
- cairo=1.18.0=hebfffa5_3
|
||||
- certifi=2024.8.30=pyhd8ed1ab_0
|
||||
- cffi=1.17.1=py311hf29c0ef_0
|
||||
- chardet=5.2.0=py311h38be061_2
|
||||
- charset-normalizer=3.3.2=pyhd8ed1ab_0
|
||||
- chex=0.1.86=pyhd8ed1ab_0
|
||||
- click=8.1.7=unix_pyh707e725_0
|
||||
- colorama=0.4.6=pyhd8ed1ab_0
|
||||
- comm=0.2.2=pyhd8ed1ab_0
|
||||
- contourpy=1.3.0=py311hd18a35c_1
|
||||
- cookiecutter=2.6.0=pyhca7485f_0
|
||||
- cuda-cccl_linux-64=12.6.37=ha770c72_0
|
||||
- cuda-crt-dev_linux-64=12.6.68=ha770c72_0
|
||||
- cuda-crt-tools=12.6.68=ha770c72_0
|
||||
- cuda-cudart=12.6.68=h5888daf_0
|
||||
- cuda-cudart-dev=12.6.68=h5888daf_0
|
||||
- cuda-cudart-dev_linux-64=12.6.68=h3f2d84a_0
|
||||
- cuda-cudart-static=12.6.68=h5888daf_0
|
||||
- cuda-cudart-static_linux-64=12.6.68=h3f2d84a_0
|
||||
- cuda-cudart_linux-64=12.6.68=h3f2d84a_0
|
||||
- cuda-cupti=12.6.68=h5888daf_0
|
||||
- cuda-driver-dev_linux-64=12.6.68=h3f2d84a_0
|
||||
- cuda-nvcc=12.6.68=hcdd1206_0
|
||||
- cuda-nvcc-dev_linux-64=12.6.68=he91c749_0
|
||||
- cuda-nvcc-impl=12.6.68=h85509e4_0
|
||||
- cuda-nvcc-tools=12.6.68=he02047a_0
|
||||
- cuda-nvcc_linux-64=12.6.68=h8a487aa_0
|
||||
- cuda-nvrtc=12.6.68=h5888daf_0
|
||||
- cuda-nvtx=12.6.68=h5888daf_0
|
||||
- cuda-nvvm-dev_linux-64=12.6.68=ha770c72_0
|
||||
- cuda-nvvm-impl=12.6.68=he02047a_0
|
||||
- cuda-nvvm-tools=12.6.68=he02047a_0
|
||||
- cuda-version=12.6=h7480c83_3
|
||||
- cudnn=9.2.1.18=hbc370b7_0
|
||||
- cycler=0.12.1=pyhd8ed1ab_0
|
||||
- datasets=2.21.0=pyhd8ed1ab_0
|
||||
- dav1d=1.2.1=hd590300_0
|
||||
- dbus=1.13.6=h5008d03_3
|
||||
- debugpy=1.8.5=py311hfdbb021_1
|
||||
- decorator=5.1.1=pyhd8ed1ab_0
|
||||
- dill=0.3.8=pyhd8ed1ab_0
|
||||
- double-conversion=3.3.0=h59595ed_0
|
||||
- etils=1.9.4=pyhd8ed1ab_0
|
||||
- evaluate=0.4.1=pyhd8ed1ab_0
|
||||
- exceptiongroup=1.2.2=pyhd8ed1ab_0
|
||||
- executing=2.1.0=pyhd8ed1ab_0
|
||||
- expat=2.6.3=h5888daf_0
|
||||
- filelock=3.16.0=pyhd8ed1ab_0
|
||||
- flax=0.9.0=pyhd8ed1ab_0
|
||||
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
|
||||
- font-ttf-inconsolata=3.000=h77eed37_0
|
||||
- font-ttf-source-code-pro=2.038=h77eed37_0
|
||||
- font-ttf-ubuntu=0.83=h77eed37_2
|
||||
- fontconfig=2.14.2=h14ed4e7_0
|
||||
- fonts-conda-ecosystem=1=0
|
||||
- fonts-conda-forge=1=0
|
||||
- fonttools=4.53.1=py311h9ecbd09_1
|
||||
- freetype=2.12.1=h267a509_2
|
||||
- frozenlist=1.4.1=py311h9ecbd09_1
|
||||
- fsspec=2024.5.0=pyhff2d567_0
|
||||
- gcc_impl_linux-64=13.3.0=hfea6d02_1
|
||||
- gcc_linux-64=13.3.0=hc28eda2_2
|
||||
- gflags=2.2.2=he1b5a44_1004
|
||||
- glog=0.7.1=hbabe93e_0
|
||||
- graphite2=1.3.13=h59595ed_1003
|
||||
- gxx_impl_linux-64=13.3.0=hdbfa832_1
|
||||
- gxx_linux-64=13.3.0=h6834431_2
|
||||
- h2=4.1.0=pyhd8ed1ab_0
|
||||
- harfbuzz=9.0.0=hda332d3_1
|
||||
- hpack=4.0.0=pyh9f0ad1d_0
|
||||
- huggingface_hub=0.24.6=pyhd8ed1ab_0
|
||||
- hyperframe=6.0.1=pyhd8ed1ab_0
|
||||
- icu=75.1=he02047a_0
|
||||
- idna=3.8=pyhd8ed1ab_0
|
||||
- importlib-metadata=8.4.0=pyha770c72_0
|
||||
- importlib_metadata=8.4.0=hd8ed1ab_0
|
||||
- importlib_resources=6.4.5=pyhd8ed1ab_0
|
||||
- ipykernel=6.29.5=pyh3099207_0
|
||||
- ipython=8.27.0=pyh707e725_0
|
||||
- jax=0.4.31=pyhd8ed1ab_1
|
||||
- jaxlib=0.4.31=cuda120py311hd88f13b_201
|
||||
- jedi=0.19.1=pyhd8ed1ab_0
|
||||
- jinja2=3.1.4=pyhd8ed1ab_0
|
||||
- joblib=1.4.2=pyhd8ed1ab_0
|
||||
- jsonschema=4.23.0=pyhd8ed1ab_0
|
||||
- jsonschema-specifications=2023.12.1=pyhd8ed1ab_0
|
||||
- jupyter_client=8.6.2=pyhd8ed1ab_0
|
||||
- jupyter_core=5.7.2=py311h38be061_0
|
||||
- jupytext=1.16.4=pyh80e38bb_0
|
||||
- kernel-headers_linux-64=3.10.0=h4a8ded7_16
|
||||
- keyutils=1.6.1=h166bdaf_0
|
||||
- kiwisolver=1.4.7=py311hd18a35c_0
|
||||
- krb5=1.21.3=h659f571_0
|
||||
- lcms2=2.16=hb7c19ff_0
|
||||
- ld_impl_linux-64=2.40=hf3520f5_7
|
||||
- lerc=4.0.0=h27087fc_0
|
||||
- libabseil=20240116.2=cxx17_he02047a_1
|
||||
- libarrow=17.0.0=h8d2e343_13_cpu
|
||||
- libarrow-acero=17.0.0=h5888daf_13_cpu
|
||||
- libarrow-dataset=17.0.0=h5888daf_13_cpu
|
||||
- libarrow-substrait=17.0.0=hf54134d_13_cpu
|
||||
- libavif16=1.1.1=h104a339_1
|
||||
- libblas=3.9.0=23_linux64_openblas
|
||||
- libbrotlicommon=1.1.0=hb9d3cd8_2
|
||||
- libbrotlidec=1.1.0=hb9d3cd8_2
|
||||
- libbrotlienc=1.1.0=hb9d3cd8_2
|
||||
- libcblas=3.9.0=23_linux64_openblas
|
||||
- libclang-cpp18.1=18.1.8=default_hf981a13_4
|
||||
- libclang13=18.1.8=default_h9def88c_4
|
||||
- libcrc32c=1.1.2=h9c3ff4c_0
|
||||
- libcublas=12.6.1.4=h5888daf_0
|
||||
- libcufft=11.2.6.59=h5888daf_0
|
||||
- libcups=2.3.3=h4637d8d_4
|
||||
- libcurand=10.3.7.68=h5888daf_0
|
||||
- libcurl=8.9.1=hdb1bdb2_0
|
||||
- libcusolver=11.6.4.69=h5888daf_0
|
||||
- libcusparse=12.5.3.3=h5888daf_0
|
||||
- libdeflate=1.21=h4bc722e_0
|
||||
- libdrm=2.4.123=hb9d3cd8_0
|
||||
- libedit=3.1.20191231=he28a2e2_2
|
||||
- libegl=1.7.0=ha4b6fd6_0
|
||||
- libev=4.33=hd590300_2
|
||||
- libevent=2.1.12=hf998b51_1
|
||||
- libexpat=2.6.3=h5888daf_0
|
||||
- libffi=3.4.2=h7f98852_5
|
||||
- libgcc=14.1.0=h77fa898_1
|
||||
- libgcc-devel_linux-64=13.3.0=h84ea5a7_101
|
||||
- libgcc-ng=14.1.0=h69a702a_1
|
||||
- libgfortran=14.1.0=h69a702a_1
|
||||
- libgfortran-ng=14.1.0=h69a702a_1
|
||||
- libgfortran5=14.1.0=hc5f4f2c_1
|
||||
- libgl=1.7.0=ha4b6fd6_0
|
||||
- libglib=2.80.3=h315aac3_2
|
||||
- libglvnd=1.7.0=ha4b6fd6_0
|
||||
- libglx=1.7.0=ha4b6fd6_0
|
||||
- libgomp=14.1.0=h77fa898_1
|
||||
- libgoogle-cloud=2.28.0=h26d7fe4_0
|
||||
- libgoogle-cloud-storage=2.28.0=ha262f82_0
|
||||
- libgrpc=1.62.2=h15f2491_0
|
||||
- libiconv=1.17=hd590300_2
|
||||
- libjpeg-turbo=3.0.0=hd590300_1
|
||||
- liblapack=3.9.0=23_linux64_openblas
|
||||
- libllvm18=18.1.8=h8b73ec9_2
|
||||
- libnghttp2=1.58.0=h47da74e_1
|
||||
- libnsl=2.0.1=hd590300_0
|
||||
- libnvjitlink=12.6.68=h5888daf_0
|
||||
- libopenblas=0.3.27=pthreads_hac2b453_1
|
||||
- libparquet=17.0.0=h39682fd_13_cpu
|
||||
- libpciaccess=0.18=hd590300_0
|
||||
- libpng=1.6.43=h2797004_0
|
||||
- libpq=16.4=h2d7952a_1
|
||||
- libprotobuf=4.25.3=h08a7969_0
|
||||
- libre2-11=2023.09.01=h5a48ba9_2
|
||||
- libsanitizer=13.3.0=heb74ff8_1
|
||||
- libsodium=1.0.20=h4ab18f5_0
|
||||
- libsqlite=3.46.1=hadc24fc_0
|
||||
- libssh2=1.11.0=h0841786_0
|
||||
- libstdcxx=14.1.0=hc0a3c3a_1
|
||||
- libstdcxx-devel_linux-64=13.3.0=h84ea5a7_101
|
||||
- libstdcxx-ng=14.1.0=h4852527_1
|
||||
- libthrift=0.20.0=h0e7cc3e_1
|
||||
- libtiff=4.6.0=h46a8edc_4
|
||||
- libutf8proc=2.8.0=h166bdaf_0
|
||||
- libuuid=2.38.1=h0b41bf4_0
|
||||
- libwebp-base=1.4.0=hd590300_0
|
||||
- libxcb=1.16=hb9d3cd8_1
|
||||
- libxcrypt=4.4.36=hd590300_1
|
||||
- libxkbcommon=1.7.0=h2c5496b_1
|
||||
- libxml2=2.12.7=he7c6b58_4
|
||||
- libxslt=1.1.39=h76b75d6_0
|
||||
- libzlib=1.3.1=h4ab18f5_1
|
||||
- lxml=5.3.0=py311hcfaa980_1
|
||||
- lz4-c=1.9.4=hcb278e6_0
|
||||
- markdown-it-py=3.0.0=pyhd8ed1ab_0
|
||||
- markupsafe=2.1.5=py311h9ecbd09_1
|
||||
- matplotlib=3.9.2=py311h38be061_0
|
||||
- matplotlib-base=3.9.2=py311h74b4f7c_0
|
||||
- matplotlib-inline=0.1.7=pyhd8ed1ab_0
|
||||
- mdit-py-plugins=0.4.1=pyhd8ed1ab_0
|
||||
- mdurl=0.1.2=pyhd8ed1ab_0
|
||||
- ml_dtypes=0.4.0=py311h7db5c69_2
|
||||
- msgpack-python=1.0.8=py311hd18a35c_1
|
||||
- multidict=6.1.0=py311h9ecbd09_0
|
||||
- multiprocess=0.70.16=py311h9ecbd09_1
|
||||
- munkres=1.1.4=pyh9f0ad1d_0
|
||||
- mysql-common=9.0.1=h70512c7_0
|
||||
- mysql-libs=9.0.1=ha479ceb_0
|
||||
- nbformat=5.10.4=pyhd8ed1ab_0
|
||||
- nccl=2.22.3.1=hbc370b7_1
|
||||
- ncurses=6.5=he02047a_1
|
||||
- nest-asyncio=1.6.0=pyhd8ed1ab_0
|
||||
- nltk=3.9.1=pyhd8ed1ab_0
|
||||
- numpy=1.26.4=py311h64a7726_0
|
||||
- openjpeg=2.5.2=h488ebb8_0
|
||||
- openssl=3.3.2=hb9d3cd8_0
|
||||
- opt-einsum=3.3.0=hd8ed1ab_2
|
||||
- opt_einsum=3.3.0=pyhc1e730c_2
|
||||
- optax=0.2.2=pyhd8ed1ab_1
|
||||
- orbax-checkpoint=0.4.4=pyhd8ed1ab_0
|
||||
- orc=2.0.2=h669347b_0
|
||||
- packaging=24.1=pyhd8ed1ab_0
|
||||
- pandas=2.2.2=py311h14de704_1
|
||||
- parso=0.8.4=pyhd8ed1ab_0
|
||||
- pcre2=10.44=hba22ea6_2
|
||||
- pexpect=4.9.0=pyhd8ed1ab_0
|
||||
- pickleshare=0.7.5=py_1003
|
||||
- pillow=10.4.0=py311h82a398c_0
|
||||
- pip=24.2=pyh8b19718_1
|
||||
- pixman=0.43.2=h59595ed_0
|
||||
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_1
|
||||
- platformdirs=4.3.2=pyhd8ed1ab_0
|
||||
- portalocker=2.10.1=py311h38be061_0
|
||||
- prompt-toolkit=3.0.47=pyha770c72_0
|
||||
- protobuf=4.25.3=py311hbffca5d_1
|
||||
- psutil=6.0.0=py311h9ecbd09_1
|
||||
- pthread-stubs=0.4=h36c2ea0_1001
|
||||
- ptyprocess=0.7.0=pyhd3deb0d_0
|
||||
- pure_eval=0.2.3=pyhd8ed1ab_0
|
||||
- pyarrow=17.0.0=py311hbd00459_1
|
||||
- pyarrow-core=17.0.0=py311h4510849_1_cpu
|
||||
- pyarrow-hotfix=0.6=pyhd8ed1ab_0
|
||||
- pybind11-abi=4=hd8ed1ab_3
|
||||
- pycparser=2.22=pyhd8ed1ab_0
|
||||
- pygments=2.18.0=pyhd8ed1ab_0
|
||||
- pyparsing=3.1.4=pyhd8ed1ab_0
|
||||
- pyside6=6.7.2=py311hba19f1e_2
|
||||
- pysocks=1.7.1=pyha2e5f31_6
|
||||
- python=3.11.9=hb806964_0_cpython
|
||||
- python-dateutil=2.9.0=pyhd8ed1ab_0
|
||||
- python-fastjsonschema=2.20.0=pyhd8ed1ab_0
|
||||
- python-slugify=8.0.4=pyhd8ed1ab_0
|
||||
- python-tzdata=2024.1=pyhd8ed1ab_0
|
||||
- python-xxhash=3.5.0=py311h9ecbd09_1
|
||||
- python_abi=3.11=5_cp311
|
||||
- pytz=2024.1=pyhd8ed1ab_0
|
||||
- pyyaml=6.0.2=py311h9ecbd09_1
|
||||
- pyzmq=26.2.0=py311h7deb3e3_2
|
||||
- qhull=2020.2=h434a139_5
|
||||
- qt6-main=6.7.2=hb12f9c5_5
|
||||
- rav1e=0.6.6=he8a937b_2
|
||||
- re2=2023.09.01=h7f4b329_2
|
||||
- readline=8.2=h8228510_1
|
||||
- referencing=0.35.1=pyhd8ed1ab_0
|
||||
- regex=2024.7.24=py311h9ecbd09_1
|
||||
- requests=2.32.3=pyhd8ed1ab_0
|
||||
- responses=0.18.0=pyhd8ed1ab_0
|
||||
- rich=13.7.1=pyhd8ed1ab_0
|
||||
- rpds-py=0.20.0=py311h9e33e62_1
|
||||
- s2n=1.5.2=h7b32b05_0
|
||||
- sacrebleu=2.4.1=pyhd8ed1ab_1
|
||||
- safetensors=0.4.5=py311h9e33e62_0
|
||||
- scipy=1.14.1=py311he1f765f_0
|
||||
- setuptools=73.0.1=pyhd8ed1ab_0
|
||||
- six=1.16.0=pyh6c4a22f_0
|
||||
- snappy=1.2.1=ha2e4443_0
|
||||
- stack_data=0.6.2=pyhd8ed1ab_0
|
||||
- svt-av1=2.2.1=h5888daf_0
|
||||
- sysroot_linux-64=2.17=h4a8ded7_16
|
||||
- tabulate=0.9.0=pyhd8ed1ab_1
|
||||
- tensorstore=0.1.62=py311he109767_1
|
||||
- text-unidecode=1.3=pyhd8ed1ab_1
|
||||
- tk=8.6.13=noxft_h4845f30_101
|
||||
- tokenizers=0.19.1=py311h6640629_0
|
||||
- tomli=2.0.1=pyhd8ed1ab_0
|
||||
- toolz=0.12.1=pyhd8ed1ab_0
|
||||
- tornado=6.4.1=py311h9ecbd09_1
|
||||
- tqdm=4.66.5=pyhd8ed1ab_0
|
||||
- traitlets=5.14.3=pyhd8ed1ab_0
|
||||
- transformers=4.41.2=pyhd8ed1ab_0
|
||||
- types-python-dateutil=2.9.0.20240906=pyhd8ed1ab_0
|
||||
- typing=3.10.0.0=pyhd8ed1ab_1
|
||||
- typing-extensions=4.12.2=hd8ed1ab_0
|
||||
- typing_extensions=4.12.2=pyha770c72_0
|
||||
- tzdata=2024a=h8827d51_1
|
||||
- urllib3=2.2.2=pyhd8ed1ab_1
|
||||
- wayland=1.23.1=h3e06ad9_0
|
||||
- wcwidth=0.2.13=pyhd8ed1ab_0
|
||||
- wheel=0.44.0=pyhd8ed1ab_0
|
||||
- xcb-util=0.4.1=hb711507_2
|
||||
- xcb-util-cursor=0.1.4=h4ab18f5_2
|
||||
- xcb-util-image=0.4.0=hb711507_2
|
||||
- xcb-util-keysyms=0.4.1=hb711507_0
|
||||
- xcb-util-renderutil=0.3.10=hb711507_0
|
||||
- xcb-util-wm=0.4.2=hb711507_0
|
||||
- xkeyboard-config=2.42=h4ab18f5_0
|
||||
- xorg-fixesproto=5.0=h7f98852_1002
|
||||
- xorg-inputproto=2.3.2=h7f98852_1002
|
||||
- xorg-kbproto=1.0.7=h7f98852_1002
|
||||
- xorg-libice=1.1.1=hd590300_0
|
||||
- xorg-libsm=1.2.4=h7391055_0
|
||||
- xorg-libx11=1.8.9=hb711507_1
|
||||
- xorg-libxau=1.0.11=hd590300_0
|
||||
- xorg-libxdmcp=1.1.3=h7f98852_0
|
||||
- xorg-libxext=1.3.4=h0b41bf4_2
|
||||
- xorg-libxfixes=5.0.3=h7f98852_1004
|
||||
- xorg-libxi=1.7.10=h4bc722e_1
|
||||
- xorg-libxrender=0.9.11=hd590300_0
|
||||
- xorg-libxtst=1.2.5=h4bc722e_0
|
||||
- xorg-libxxf86vm=1.1.5=h4bc722e_1
|
||||
- xorg-recordproto=1.14.2=h7f98852_1002
|
||||
- xorg-renderproto=0.11.1=h7f98852_1002
|
||||
- xorg-xextproto=7.3.0=h0b41bf4_1003
|
||||
- xorg-xproto=7.0.31=h7f98852_1007
|
||||
- xxhash=0.8.2=hd590300_0
|
||||
- xz=5.2.6=h166bdaf_0
|
||||
- yaml=0.2.5=h7f98852_2
|
||||
- yarl=1.11.1=py311h9ecbd09_0
|
||||
- zeromq=4.3.5=ha4adb4c_5
|
||||
- zipp=3.20.1=pyhd8ed1ab_0
|
||||
- zlib=1.3.1=h4ab18f5_1
|
||||
- zstandard=0.23.0=py311hbc35293_1
|
||||
- zstd=1.5.6=ha6fb4c9_0
|
||||
|
|
@@ -0,0 +1,620 @@
|
|||
# ---
|
||||
# jupyter:
|
||||
# jupytext:
|
||||
# formats: ipynb,py:percent
|
||||
# text_representation:
|
||||
# extension: .py
|
||||
# format_name: percent
|
||||
# format_version: '1.3'
|
||||
# jupytext_version: 1.16.4
|
||||
# kernelspec:
|
||||
# display_name: jax
|
||||
# language: python
|
||||
# name: python3
|
||||
# ---
|
||||
|
||||
# %% [markdown]
|
||||
# # T5 implementation using jax
|
||||
|
||||
# %% [markdown]
|
||||
# ## import
|
||||
|
||||
# %% [raw]
|
||||
# import json
|
||||
# import logging
|
||||
# import math
|
||||
# import os
|
||||
# import sys
|
||||
# import time
|
||||
# from dataclasses import asdict, dataclass, field
|
||||
# from enum import Enum
|
||||
# from functools import partial
|
||||
# from pathlib import Path
|
||||
# from typing import Callable, Optional
|
||||
#
|
||||
# import datasets
|
||||
# import evaluate
|
||||
# import jax
|
||||
# import jax.numpy as jnp
|
||||
# import nltk # Here to have a nice missing dependency error message early on
|
||||
# import numpy as np
|
||||
# import optax
|
||||
# from datasets import Dataset, load_dataset
|
||||
# from filelock import FileLock
|
||||
# from flax import jax_utils, traverse_util
|
||||
# from flax.jax_utils import pad_shard_unpad, unreplicate
|
||||
# from flax.training import train_state
|
||||
# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||
# from tqdm import tqdm
|
||||
#
|
||||
# import transformers
|
||||
# from transformers import (
|
||||
# CONFIG_MAPPING,
|
||||
# FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
# AutoConfig,
|
||||
# AutoTokenizer,
|
||||
# FlaxAutoModelForSeq2SeqLM,
|
||||
# HfArgumentParser,
|
||||
# is_tensorboard_available,
|
||||
# )
|
||||
# from transformers.utils import is_offline_mode, send_example_telemetry
|
||||
#
|
||||
#
|
||||
# logger = logging.getLogger(__name__)
|
||||
#
|
||||
# try:
|
||||
# nltk.data.find("tokenizers/punkt")
|
||||
# except (LookupError, OSError):
|
||||
# if is_offline_mode():
|
||||
# raise LookupError(
|
||||
# "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
|
||||
# )
|
||||
# with FileLock(".lock") as lock:
|
||||
# nltk.download("punkt", quiet=True)
|
||||
#
|
||||
#
|
||||
# MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
|
||||
# MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
|
||||
# %%
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from typing import Callable, Optional
|
||||
import math
|
||||
|
||||
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
|
||||
jax.config.update("jax_default_matmul_precision", "high")
|
||||
|
||||
jax.config.update("jax_enable_x64", False)
|
||||
|
||||
|
||||
from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
|
||||
|
||||
|
||||
import datasets
|
||||
from datasets import Dataset, load_dataset
|
||||
import evaluate
|
||||
from tqdm import tqdm
|
||||
from datasets import load_from_disk
|
||||
|
||||
|
||||
import nltk # Here to have a nice missing dependency error message early on
|
||||
|
||||
from flax import jax_utils, traverse_util
|
||||
from flax.jax_utils import pad_shard_unpad, unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||
|
||||
|
||||
import time
|
||||
|
||||
|
||||
# %%
|
||||
import os
|
||||
os.environ['XLA_FLAGS'] = (
|
||||
'--xla_gpu_enable_triton_softmax_fusion=True '
|
||||
'--xla_gpu_triton_gemm_any=True '
|
||||
)
|
||||
|
||||
os.environ.update({
|
||||
"NCCL_LL128_BUFFSIZE": "-2",
|
||||
"NCCL_LL_BUFFSIZE": "-2",
|
||||
"NCCL_PROTO": "SIMPLE,LL,LL128",
|
||||
})
|
||||
|
||||
# %%
|
||||
from jax.lib import xla_bridge
|
||||
print(xla_bridge.get_backend().platform)
|
||||
|
||||
|
||||
# %%
|
||||
# nltk.download('punkt')
|
||||
from filelock import FileLock
from transformers.utils import is_offline_mode

try:
|
||||
nltk.data.find("tokenizers/punkt")
|
||||
except (LookupError, OSError):
|
||||
if is_offline_mode():
|
||||
raise LookupError(
|
||||
"Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
|
||||
)
|
||||
with FileLock(".lock") as lock:
|
||||
nltk.download("punkt", quiet=True)
|
||||
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# ## Prepare datasets
|
||||
|
||||
# %%
|
||||
# load model
|
||||
model_name_or_path = "t5-small" # Replace with your specific model name
|
||||
|
||||
# Load configuration
|
||||
config = AutoConfig.from_pretrained(model_name_or_path)
|
||||
|
||||
# Load model
|
||||
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
|
||||
model_name_or_path
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
|
||||
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
|
||||
|
||||
|
||||
# %%
|
||||
# Path to saved combined_dataset
|
||||
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
|
||||
save_path = 't5_80_1'
|
||||
# file_path = 'combined_data'
|
||||
split_datasets = load_from_disk(file_path)
|
||||
|
||||
# prepare tokenizer
|
||||
from transformers import T5TokenizerFast
|
||||
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
|
||||
# Define additional special tokens
|
||||
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
|
||||
# Add the additional special tokens to the tokenizer
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
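
# Optional sanity check, not part of the original script: the T5 embedding matrix is
# slightly larger than the tokenizer's base vocabulary, so a handful of added special
# tokens should still map to valid embedding rows without resizing. If this assertion
# fails, the token embeddings would need to be resized before training.
assert len(tokenizer) <= model.config.vocab_size, (len(tokenizer), model.config.vocab_size)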
|
||||
|
||||
max_length = 86
|
||||
|
||||
# In Flax, for seq2seq models we need to pass `decoder_input_ids`
|
||||
# as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
|
||||
# for that dynamically import the `shift_tokens_right` function from the model file
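
# Tiny illustration (hypothetical token ids, not taken from the dataset): shift_tokens_right
# prepends config.decoder_start_token_id and drops the final position, so the decoder sees
# the label sequence shifted one step to the right.
_demo_labels = np.array([[37, 8, 1, config.pad_token_id]])
print(shift_tokens_right_fn(_demo_labels, config.pad_token_id, config.decoder_start_token_id))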
|
||||
|
||||
|
||||
# given a dataset entry, run it through the tokenizer
|
||||
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
||||
def preprocess_function(example):
|
||||
input = example['input']
|
||||
target = example['output']
|
||||
# text_target sets the corresponding label to inputs
|
||||
# there is no need to create a separate 'labels'
|
||||
model_inputs = tokenizer(
|
||||
input,
|
||||
text_target=target,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="np"
|
||||
)
|
||||
labels = tokenizer(
|
||||
input,
|
||||
text_target=target,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="np"
|
||||
)
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
decoder_input_ids = shift_tokens_right_fn(
|
||||
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
|
||||
)
|
||||
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
||||
|
||||
# We need decoder_attention_mask so we can ignore pad tokens from loss
|
||||
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
|
||||
|
||||
return model_inputs
|
||||
|
||||
# map maps function to each "row" in the dataset
|
||||
# aka the data in the immediate nesting
|
||||
tokenized_datasets = split_datasets.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=1,
|
||||
remove_columns=split_datasets["train"].column_names,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
tokenized_datasets
|
||||
|
||||
# %%
|
||||
train_dataset = tokenized_datasets["train"]
|
||||
eval_dataset = tokenized_datasets["validation"]
|
||||
|
||||
|
||||
# %%
|
||||
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
|
||||
"""
|
||||
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
|
||||
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
|
||||
"""
|
||||
if shuffle:
|
||||
batch_idx = jax.random.permutation(rng, len(dataset))
|
||||
batch_idx = np.asarray(batch_idx)
|
||||
else:
|
||||
batch_idx = np.arange(len(dataset))
|
||||
|
||||
if drop_last:
|
||||
steps_per_epoch = len(dataset) // batch_size
|
||||
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
||||
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
||||
else:
|
||||
steps_per_epoch = math.ceil(len(dataset) / batch_size)
|
||||
batch_idx = np.array_split(batch_idx, steps_per_epoch)
|
||||
|
||||
for idx in batch_idx:
|
||||
batch = dataset[idx]
|
||||
batch = {k: np.array(v) for k, v in batch.items()}
|
||||
|
||||
yield batch
|
||||
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# Now we have model inputs in terms of the variable tokenized_datasets
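#
# As a quick sanity check, the sketch below (an addition, not in the original notebook) pulls
# one small batch from `data_loader` and prints the array shapes the train step will receive.

# %%
_probe_rng = jax.random.PRNGKey(0)
_probe_batch = next(data_loader(_probe_rng, train_dataset, batch_size=4))
for _name, _arr in _probe_batch.items():
    print(_name, _arr.shape, _arr.dtype)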
|
||||
|
||||
# %%
|
||||
# metric
|
||||
metric = evaluate.load("sacrebleu")
|
||||
|
||||
def postprocess_text(preds, labels):
|
||||
preds = [pred.strip() for pred in preds]
|
||||
labels = [label.strip() for label in labels]
|
||||
|
||||
# rougeLSum expects newline after each sentence
|
||||
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
|
||||
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
|
||||
|
||||
return preds, labels
|
||||
|
||||
# def compute_metrics(preds, labels):
|
||||
# decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
# decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
#
|
||||
# # Some simple post-processing
|
||||
# decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
||||
#
|
||||
# result = metric.compute(predictions=decoded_preds, references=decoded_labels)
|
||||
# result = {k: round(v * 100, 4) for k, v in result.items()}
|
||||
# prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
|
||||
# result["gen_len"] = np.mean(prediction_lens)
|
||||
# return result
|
||||
|
||||
def compute_metrics(preds, labels):
|
||||
# In case the model returns more than the prediction logits
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
|
||||
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
|
||||
# Replace -100s in the labels as we can't decode them
|
||||
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
||||
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
|
||||
# Some simple post-processing
|
||||
decoded_preds = [pred.strip() for pred in decoded_preds]
|
||||
decoded_labels = [[label.strip()] for label in decoded_labels]
|
||||
|
||||
result = metric.compute(predictions=decoded_preds, references=decoded_labels)
|
||||
return {"bleu": result["score"]}
|
||||
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# # Model
|
||||
|
||||
# %%
|
||||
# Store some constants
|
||||
seed = 117
|
||||
num_epochs = 80
|
||||
batch_size = 96
|
||||
num_train_epochs = num_epochs
|
||||
per_device_train_batch_size = batch_size
|
||||
train_batch_size = per_device_train_batch_size * jax.device_count()
|
||||
per_device_eval_batch_size = batch_size
|
||||
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
||||
steps_per_epoch = len(train_dataset) // train_batch_size
|
||||
total_train_steps = steps_per_epoch * num_epochs
|
||||
|
||||
warmup_steps = 0
|
||||
learning_rate = 5e-5
|
||||
|
||||
weight_decay = 0.0
|
||||
adam_beta1 = 0.9
|
||||
adam_beta2 = 0.999
|
||||
adam_epsilon = 1e-8
|
||||
label_smoothing_factor = 0.0
|
||||
|
||||
num_beams = 1
|
||||
val_max_target_length = None
|
||||
|
||||
predict_with_generate = True
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
# Initialize our training
|
||||
rng = jax.random.PRNGKey(seed)
|
||||
rng, dropout_rng = jax.random.split(rng)
|
||||
|
||||
|
||||
# %%
|
||||
# optimization functions
|
||||
|
||||
def create_learning_rate_fn(
|
||||
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
|
||||
) -> Callable[[int], jnp.ndarray]:
|
||||
"""Returns a linear warmup, linear_decay learning rate function."""
|
||||
steps_per_epoch = train_ds_size // train_batch_size
|
||||
num_train_steps = steps_per_epoch * num_train_epochs
|
||||
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
|
||||
decay_fn = optax.linear_schedule(
|
||||
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
|
||||
)
|
||||
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
|
||||
return schedule_fn
|
||||
|
||||
|
||||
# Create learning rate schedule
|
||||
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
||||
len(train_dataset),
|
||||
train_batch_size,
|
||||
num_train_epochs,
|
||||
warmup_steps,
|
||||
learning_rate,
|
||||
)
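
# Optional check (not in the original script): sample the schedule at a few steps to confirm
# it decays linearly from `learning_rate` toward 0 over `total_train_steps`
# (warmup_steps is 0 here, so there is effectively no warmup phase).
for _step in (0, total_train_steps // 2, total_train_steps - 1):
    print(_step, float(linear_decay_lr_schedule_fn(_step)))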
|
||||
|
||||
# We use Optax's "masking" functionality to not apply weight decay
|
||||
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
||||
# mask boolean with the same structure as the parameters.
|
||||
# The mask is True for parameters that should be decayed.
|
||||
def decay_mask_fn(params):
|
||||
flat_params = traverse_util.flatten_dict(params)
|
||||
# find out all LayerNorm parameters
|
||||
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
|
||||
layer_norm_named_params = {
|
||||
layer[-2:]
|
||||
for layer_norm_name in layer_norm_candidates
|
||||
for layer in flat_params.keys()
|
||||
if layer_norm_name in "".join(layer).lower()
|
||||
}
|
||||
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
|
||||
return traverse_util.unflatten_dict(flat_mask)
|
||||
|
||||
# create adam optimizer
|
||||
adamw = optax.adamw(
|
||||
learning_rate=linear_decay_lr_schedule_fn,
|
||||
b1=adam_beta1,
|
||||
b2=adam_beta2,
|
||||
eps=adam_epsilon,
|
||||
weight_decay=weight_decay,
|
||||
mask=decay_mask_fn,
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
# Training functions
|
||||
class TrainState(train_state.TrainState):
|
||||
dropout_rng: jnp.ndarray
|
||||
|
||||
def replicate(self):
|
||||
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
||||
|
||||
# model.params is already initialized from the pretrained checkpoint loaded above
params = model.params
|
||||
# Cast parameters to bfloat16 if desired
|
||||
params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
|
||||
|
||||
|
||||
# Setup train state
|
||||
state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng)
|
||||
|
||||
# label smoothed cross entropy
|
||||
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
|
||||
"""
|
||||
The label smoothing implementation is adapted from Flax's official example:
|
||||
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
|
||||
"""
|
||||
vocab_size = logits.shape[-1]
|
||||
confidence = 1.0 - label_smoothing_factor
|
||||
low_confidence = (1.0 - confidence) / (vocab_size - 1)
|
||||
normalizing_constant = -(
|
||||
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
|
||||
)
|
||||
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
|
||||
|
||||
loss = optax.softmax_cross_entropy(logits, soft_labels)
|
||||
loss = loss - normalizing_constant
|
||||
|
||||
# ignore padded tokens from loss
|
||||
loss = loss * padding_mask
|
||||
loss = loss.sum()
|
||||
num_labels = padding_mask.sum()
|
||||
return loss, num_labels
|
||||
|
||||
# Define gradient update step fn
|
||||
def train_step(state, batch, label_smoothing_factor=0.0):
|
||||
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
||||
|
||||
def compute_loss(params):
|
||||
labels = batch.pop("labels")
|
||||
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||
return loss, num_labels
|
||||
|
||||
# compute gradients through computational graph
|
||||
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
|
||||
(loss, num_labels), grad = grad_fn(state.params)
|
||||
num_labels = jax.lax.psum(num_labels, "batch")
|
||||
|
||||
# true loss = total loss / total samples
|
||||
loss = jax.lax.psum(loss, "batch")
|
||||
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||
|
||||
# true grad = total grad / total samples
|
||||
grad = jax.lax.psum(grad, "batch")
|
||||
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
||||
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
||||
|
||||
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
||||
return new_state, metrics
|
||||
|
||||
# Define eval fn
|
||||
def eval_step(params, batch, label_smoothing_factor=0.0):
|
||||
labels = batch.pop("labels")
|
||||
logits = model(**batch, params=params, train=False)[0]
|
||||
|
||||
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||
num_labels = jax.lax.psum(num_labels, "batch")
|
||||
|
||||
# true loss = total loss / total samples
|
||||
loss = jax.lax.psum(loss, "batch")
|
||||
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||
|
||||
metrics = {"loss": loss}
|
||||
return metrics
|
||||
|
||||
# Define generation function
|
||||
max_length = (
|
||||
val_max_target_length if val_max_target_length is not None else model.config.max_length
|
||||
)
|
||||
num_beams = num_beams if num_beams is not None else model.config.num_beams
|
||||
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
||||
|
||||
def generate_step(params, batch):
|
||||
model.params = params
|
||||
output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
|
||||
return output_ids.sequences
|
||||
|
||||
# Create parallel version of the train and eval step
|
||||
p_train_step = jax.pmap(
|
||||
partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
|
||||
)
|
||||
p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
|
||||
p_generate_step = jax.pmap(generate_step, "batch")
|
||||
|
||||
# Replicate the train state on each device
|
||||
state = state.replicate()
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
|
||||
print("***** Running training *****")
|
||||
print(f" Num examples = {len(train_dataset)}")
|
||||
print(f" Num Epochs = {num_epochs}")
|
||||
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
|
||||
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
|
||||
print(f" Total optimization steps = {total_train_steps}")
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
train_time = 0
|
||||
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
||||
# epochs = range(num_epochs)
|
||||
for epoch in epochs:
|
||||
# ======================== Training ================================
|
||||
train_start = time.time()
|
||||
|
||||
# Create sampling rng
|
||||
rng, input_rng = jax.random.split(rng)
|
||||
train_metrics = []
|
||||
|
||||
# Generate an epoch by shuffling sampling indices from the train dataset
|
||||
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
|
||||
steps_per_epoch = len(train_dataset) // train_batch_size
|
||||
# train
|
||||
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
||||
batch = next(train_loader)
|
||||
batch = shard(batch)
|
||||
state, train_metric = p_train_step(state, batch)
|
||||
train_metrics.append(train_metric)
|
||||
|
||||
train_time += time.time() - train_start
|
||||
|
||||
train_metric = unreplicate(train_metric)
|
||||
|
||||
epochs.write(
|
||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
|
||||
f" {train_metric['learning_rate']})"
|
||||
)
|
||||
|
||||
# ======================== Evaluating ==============================
|
||||
eval_metrics = []
|
||||
eval_preds = []
|
||||
eval_labels = []
|
||||
|
||||
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
|
||||
eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
|
||||
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
||||
# Model forward
|
||||
batch = next(eval_loader)
|
||||
labels = batch["labels"]
|
||||
|
||||
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
|
||||
state.params, batch, min_device_batch=per_device_eval_batch_size
|
||||
)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
# generation
|
||||
if predict_with_generate:
|
||||
generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
|
||||
eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
||||
eval_labels.extend(labels)
|
||||
|
||||
# normalize eval metrics
|
||||
eval_metrics = get_metrics(eval_metrics)
|
||||
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
|
||||
|
||||
# compute metrics
|
||||
rouge_desc = ""
|
||||
if predict_with_generate:
|
||||
rouge_metrics = compute_metrics(eval_preds, eval_labels)
|
||||
eval_metrics.update(rouge_metrics)
|
||||
rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
|
||||
|
||||
# Print metrics and update progress bar
|
||||
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
|
||||
epochs.write(desc)
|
||||
epochs.desc = desc
|
||||
|
||||
# Save metrics
|
||||
# if has_tensorboard and jax.process_index() == 0:
|
||||
# cur_step = epoch * (len(train_dataset) // train_batch_size)
|
||||
# write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
|
||||
|
||||
output_dir = save_path
|
||||
# save checkpoint after each epoch and push checkpoint to the hub
|
||||
if jax.process_index() == 0:
|
||||
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
|
||||
model.save_pretrained(output_dir, params=params)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# #
|
|
@@ -0,0 +1,386 @@
|
|||
# ---
|
||||
# jupyter:
|
||||
# jupytext:
|
||||
# formats: ipynb,py:percent
|
||||
# text_representation:
|
||||
# extension: .py
|
||||
# format_name: percent
|
||||
# format_version: '1.3'
|
||||
# jupytext_version: 1.16.4
|
||||
# ---
|
||||
|
||||
# %% [markdown]
|
||||
# # prediction code
|
||||
# ## import and process test data
|
||||
|
||||
|
||||
# %%
|
||||
# import libraries
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from datasets import Dataset, DatasetDict
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from typing import Callable, Optional
|
||||
import math
|
||||
|
||||
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
|
||||
jax.config.update("jax_default_matmul_precision", "high")
|
||||
|
||||
jax.config.update("jax_enable_x64", False)
|
||||
|
||||
|
||||
from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
|
||||
|
||||
|
||||
import datasets
|
||||
from datasets import Dataset, load_dataset
|
||||
import evaluate
|
||||
from tqdm import tqdm
|
||||
from datasets import load_from_disk
|
||||
|
||||
|
||||
import nltk # Here to have a nice missing dependency error message early on
|
||||
|
||||
from flax import jax_utils, traverse_util
|
||||
from flax.jax_utils import pad_shard_unpad, unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||
|
||||
|
||||
import time
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
# data_path = f"../make_data/select_db/data_mapping_filtered.csv"
|
||||
# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv"
|
||||
data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
|
||||
# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'
|
||||
|
||||
# Ensure to include 'ships_idx' in the fields list
|
||||
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
|
||||
|
||||
# Load the dataset
|
||||
df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)
|
||||
|
||||
def process_df(df):
|
||||
output_list = [{
|
||||
'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>",
|
||||
# 'input': f"<DESC>{row['tag_description']}<DESC>",
|
||||
# 'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
|
||||
# 'input': f"<DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
|
||||
'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
|
||||
'answer': f"{row['thing']} {row['property']}",
|
||||
'answer_thing': row['thing'],
|
||||
'answer_property': row['property'],
|
||||
} for _, row in df.iterrows()]
|
||||
|
||||
return output_list
|
||||
|
||||
|
||||
# takes 1 minute to run without batching
|
||||
test_dataset = Dataset.from_list(process_df(df))
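
# Quick look at one serialized example (an addition, not in the original script) to confirm
# the <NAME>/<DESC> input format and the <THING_*>/<PROPERTY_*> output format built above.
print(test_dataset[0]['input'])
print(test_dataset[0]['output'])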
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# ## Load model for attributes
|
||||
|
||||
# %%
|
||||
# load model
|
||||
model_name_or_path = "t5_80_1" # Replace with your specific model name
|
||||
|
||||
# Load configuration
|
||||
config = AutoConfig.from_pretrained(model_name_or_path)
|
||||
|
||||
# Load model
|
||||
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
|
||||
model_name_or_path
|
||||
)
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# ## Tokenizer
|
||||
|
||||
# %%
|
||||
# prepare tokenizer
|
||||
from transformers import T5TokenizerFast
|
||||
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
|
||||
# Define additional special tokens
|
||||
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
|
||||
# Add the additional special tokens to the tokenizer
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
|
||||
|
||||
max_length = 86
|
||||
|
||||
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
|
||||
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
|
||||
|
||||
# given a dataset entry, run it through the tokenizer
|
||||
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
||||
def preprocess_function(example):
|
||||
input = example['input']
|
||||
target = example['output']
|
||||
# text_target sets the corresponding label to inputs
|
||||
# there is no need to create a separate 'labels'
|
||||
model_inputs = tokenizer(
|
||||
input,
|
||||
text_target=target,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="np"
|
||||
)
|
||||
labels = tokenizer(
|
||||
input,
|
||||
text_target=target,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="np"
|
||||
)
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
decoder_input_ids = shift_tokens_right_fn(
|
||||
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
|
||||
)
|
||||
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
||||
|
||||
# We need decoder_attention_mask so we can ignore pad tokens from loss
|
||||
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
|
||||
|
||||
return model_inputs
|
||||
|
||||
# map maps function to each "row" in the dataset
|
||||
# aka the data in the immediate nesting
|
||||
test_dataset = test_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=1,
|
||||
remove_columns=test_dataset.column_names,
|
||||
)
|
||||
|
||||
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
|
||||
"""
|
||||
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
|
||||
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
|
||||
"""
|
||||
if shuffle:
|
||||
batch_idx = jax.random.permutation(rng, len(dataset))
|
||||
batch_idx = np.asarray(batch_idx)
|
||||
else:
|
||||
batch_idx = np.arange(len(dataset))
|
||||
|
||||
if drop_last:
|
||||
steps_per_epoch = len(dataset) // batch_size
|
||||
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
||||
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
||||
else:
|
||||
steps_per_epoch = math.ceil(len(dataset) / batch_size)
|
||||
batch_idx = np.array_split(batch_idx, steps_per_epoch)
|
||||
|
||||
for idx in batch_idx:
|
||||
batch = dataset[idx]
|
||||
batch = {k: np.array(v) for k, v in batch.items()}
|
||||
|
||||
yield batch
|
||||
|
||||
# %% [markdown]
|
||||
# # Model inference setup
|
||||
|
||||
# %%
|
||||
seed = 117
|
||||
num_epochs = 80
|
||||
batch_size = 96
|
||||
num_train_epochs = num_epochs
|
||||
per_device_train_batch_size = batch_size
|
||||
train_batch_size = per_device_train_batch_size * jax.device_count()
|
||||
per_device_eval_batch_size = batch_size
|
||||
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
||||
steps_per_epoch = len(test_dataset) // train_batch_size
|
||||
total_train_steps = steps_per_epoch * num_epochs
|
||||
|
||||
warmup_steps = 0
|
||||
learning_rate = 5e-5
|
||||
|
||||
weight_decay = 0.0
|
||||
adam_beta1 = 0.9
|
||||
adam_beta2 = 0.999
|
||||
adam_epsilon = 1e-8
|
||||
label_smoothing_factor = 0.0
|
||||
|
||||
num_beams = 1
|
||||
val_max_target_length = None
|
||||
|
||||
predict_with_generate = True
|
||||
|
||||
|
||||
# Initialize our training
|
||||
rng = jax.random.PRNGKey(seed)
|
||||
rng, dropout_rng = jax.random.split(rng)
|
||||
|
||||
def create_learning_rate_fn(
|
||||
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
|
||||
) -> Callable[[int], jnp.ndarray]:
|
||||
"""Returns a linear warmup, linear_decay learning rate function."""
|
||||
steps_per_epoch = train_ds_size // train_batch_size
|
||||
num_train_steps = steps_per_epoch * num_train_epochs
|
||||
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
|
||||
decay_fn = optax.linear_schedule(
|
||||
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
|
||||
)
|
||||
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
|
||||
return schedule_fn
|
||||
|
||||
|
||||
# Create learning rate schedule
|
||||
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
||||
len(test_dataset),
|
||||
train_batch_size,
|
||||
num_train_epochs,
|
||||
warmup_steps,
|
||||
learning_rate,
|
||||
)
|
||||
|
||||
# We use Optax's "masking" functionality to not apply weight decay
|
||||
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
||||
# mask boolean with the same structure as the parameters.
|
||||
# The mask is True for parameters that should be decayed.
|
||||
def decay_mask_fn(params):
|
||||
flat_params = traverse_util.flatten_dict(params)
|
||||
# find out all LayerNorm parameters
|
||||
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
|
||||
layer_norm_named_params = {
|
||||
layer[-2:]
|
||||
for layer_norm_name in layer_norm_candidates
|
||||
for layer in flat_params.keys()
|
||||
if layer_norm_name in "".join(layer).lower()
|
||||
}
|
||||
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
|
||||
return traverse_util.unflatten_dict(flat_mask)
|
||||
|
||||
# create adam optimizer
|
||||
adamw = optax.adamw(
|
||||
learning_rate=linear_decay_lr_schedule_fn,
|
||||
b1=adam_beta1,
|
||||
b2=adam_beta2,
|
||||
eps=adam_epsilon,
|
||||
weight_decay=weight_decay,
|
||||
mask=decay_mask_fn,
|
||||
)
|
||||
|
||||
# %%
|
||||
|
||||
# reload model to prevent leakage of variables
|
||||
# load model
|
||||
model_name_or_path = "t5_80_1" # Replace with your specific model name
|
||||
|
||||
# Load configuration
|
||||
config = AutoConfig.from_pretrained(model_name_or_path)
|
||||
|
||||
# Load model
|
||||
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
|
||||
model_name_or_path
|
||||
)
|
||||
|
||||
# Training functions
|
||||
class TrainState(train_state.TrainState):
|
||||
dropout_rng: jnp.ndarray
|
||||
|
||||
def replicate(self):
|
||||
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
||||
|
||||
# model.params is already initialized from the fine-tuned checkpoint loaded above
params = model.params
|
||||
# Cast parameters to bfloat16 if desired
|
||||
params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
|
||||
|
||||
|
||||
# Setup train state
|
||||
state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng)
|
||||
|
||||
|
||||
|
||||
# Define generation function
|
||||
max_length = (
|
||||
val_max_target_length if val_max_target_length is not None else model.config.max_length
|
||||
)
|
||||
num_beams = num_beams if num_beams is not None else model.config.num_beams
|
||||
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
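
# Optional smoke test (an addition, not part of the original script): run generation on a
# single tokenized test example with the non-pmapped model to confirm the fine-tuned
# checkpoint loads and produces decodable output before launching the full prediction loop.
_probe = {k: np.array(test_dataset[0][k])[None, :] for k in ("input_ids", "attention_mask")}
_probe_out = model.generate(_probe["input_ids"], attention_mask=_probe["attention_mask"], **gen_kwargs)
print(tokenizer.batch_decode(np.asarray(_probe_out.sequences), skip_special_tokens=True))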
|
||||
|
||||
def generate_step(params, batch):
|
||||
model.params = params
|
||||
output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
|
||||
return output_ids.sequences
|
||||
|
||||
# Create parallel version of the train and eval step
|
||||
p_generate_step = jax.pmap(generate_step, "batch")
|
||||
|
||||
# Replicate the train state on each device
|
||||
state = state.replicate()
|
||||
|
||||
|
||||
pred_metrics = []
|
||||
pred_generations = []
|
||||
pred_labels = []
|
||||
|
||||
rng, input_rng = jax.random.split(rng)
|
||||
|
||||
pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False)
|
||||
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)
|
||||
|
||||
print("***** Running training *****")
|
||||
print(f" Num examples = {len(test_dataset)}")
|
||||
print(f" Num steps = {num_epochs}")
|
||||
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
|
||||
print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")
|
||||
|
||||
|
||||
for _ in tqdm(range(pred_steps), desc="Predicting...", position=0, leave=False):
|
||||
# Model forward
|
||||
batch = next(pred_loader)
|
||||
labels = batch["labels"]
|
||||
|
||||
# generation
|
||||
if predict_with_generate:
|
||||
generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
|
||||
pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
||||
pred_labels.extend(labels)
|
||||
|
||||
|
||||
|
||||
# Print metrics
|
||||
# desc = f"Predict Loss: {pred_metrics['loss']})"
|
||||
# print(desc)
|
||||
|
||||
# %%
|
||||
# save predictions to parquet
|
||||
|
||||
# decode prediction labels
|
||||
def decode_preds(preds):
|
||||
# In case the model returns more than the prediction logits
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
|
||||
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
|
||||
decoded_preds = [pred for pred in decoded_preds]
|
||||
|
||||
return decoded_preds
|
||||
|
||||
|
||||
# Convert the list to a Pandas DataFrame
|
||||
df = pd.DataFrame(decode_preds(pred_labels))
|
||||
|
||||
# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
|
||||
df.to_parquet("exports/output_file.parquet", engine="pyarrow") # or use engine="fastparquet"
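
# The same helper can decode the generated ids collected in `pred_generations`; a minimal
# sketch (the output file name is an assumption, not from the original script):
df_preds = pd.DataFrame(decode_preds(pred_generations))
df_preds.to_parquet("exports/predictions.parquet", engine="pyarrow")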
|
||||
|
||||
|
||||
|
||||
# %%
|
|
@@ -0,0 +1,624 @@
|
|||
# ---
|
||||
# jupyter:
|
||||
# jupytext:
|
||||
# formats: ipynb,py:percent
|
||||
# text_representation:
|
||||
# extension: .py
|
||||
# format_name: percent
|
||||
# format_version: '1.3'
|
||||
# jupytext_version: 1.16.4
|
||||
# kernelspec:
|
||||
# display_name: jax
|
||||
# language: python
|
||||
# name: python3
|
||||
# ---
|
||||
|
||||
# %% [markdown]
|
||||
# # T5 implementation using jax
|
||||
|
||||
# %% [markdown]
|
||||
# ## import
|
||||
|
||||
# %% [raw]
|
||||
# import json
|
||||
# import logging
|
||||
# import math
|
||||
# import os
|
||||
# import sys
|
||||
# import time
|
||||
# from dataclasses import asdict, dataclass, field
|
||||
# from enum import Enum
|
||||
# from functools import partial
|
||||
# from pathlib import Path
|
||||
# from typing import Callable, Optional
|
||||
#
|
||||
# import datasets
|
||||
# import evaluate
|
||||
# import jax
|
||||
# import jax.numpy as jnp
|
||||
# import nltk # Here to have a nice missing dependency error message early on
|
||||
# import numpy as np
|
||||
# import optax
|
||||
# from datasets import Dataset, load_dataset
|
||||
# from filelock import FileLock
|
||||
# from flax import jax_utils, traverse_util
|
||||
# from flax.jax_utils import pad_shard_unpad, unreplicate
|
||||
# from flax.training import train_state
|
||||
# from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||
# from tqdm import tqdm
|
||||
#
|
||||
# import transformers
|
||||
# from transformers import (
|
||||
# CONFIG_MAPPING,
|
||||
# FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
# AutoConfig,
|
||||
# AutoTokenizer,
|
||||
# FlaxAutoModelForSeq2SeqLM,
|
||||
# HfArgumentParser,
|
||||
# is_tensorboard_available,
|
||||
# )
|
||||
# from transformers.utils import is_offline_mode, send_example_telemetry
|
||||
#
|
||||
#
|
||||
# logger = logging.getLogger(__name__)
|
||||
#
|
||||
# try:
|
||||
# nltk.data.find("tokenizers/punkt")
|
||||
# except (LookupError, OSError):
|
||||
# if is_offline_mode():
|
||||
# raise LookupError(
|
||||
# "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
|
||||
# )
|
||||
# with FileLock(".lock") as lock:
|
||||
# nltk.download("punkt", quiet=True)
|
||||
#
|
||||
#
|
||||
# MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
|
||||
# MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
|
||||
# %%
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from typing import Callable, Optional
|
||||
import math
|
||||
|
||||
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
|
||||
jax.config.update("jax_default_matmul_precision", "high")
|
||||
|
||||
jax.config.update("jax_enable_x64", False)
|
||||
|
||||
|
||||
from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
|
||||
|
||||
|
||||
import datasets
|
||||
from datasets import Dataset, load_dataset
|
||||
import evaluate
|
||||
|
||||
|
||||
import nltk # Here to have a nice missing dependency error message early on
|
||||
|
||||
from flax import jax_utils, traverse_util
|
||||
from flax.jax_utils import pad_shard_unpad, unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||
|
||||
|
||||
import time
|
||||
|
||||
|
||||
# %%
|
||||
import os
|
||||
os.environ['XLA_FLAGS'] = (
|
||||
'--xla_gpu_enable_triton_softmax_fusion=True '
|
||||
'--xla_gpu_triton_gemm_any=True '
|
||||
)
|
||||
|
||||
os.environ.update({
|
||||
"CUDA_VISIBLE_DEVICES": "0, 1, 2, 3",
|
||||
"NCCL_LL128_BUFFSIZE": "-2",
|
||||
"NCCL_LL_BUFFSIZE": "-2",
|
||||
"NCCL_PROTO": "SIMPLE,LL,LL128",
|
||||
})
|
||||
|
||||
# %%
|
||||
from jax.lib import xla_bridge
|
||||
print(xla_bridge.get_backend().platform)
|
||||
|
||||
|
||||
# %%
|
||||
# nltk.download('punkt')
|
||||
from filelock import FileLock
from transformers.utils import is_offline_mode

try:
|
||||
nltk.data.find("tokenizers/punkt")
|
||||
except (LookupError, OSError):
|
||||
if is_offline_mode():
|
||||
raise LookupError(
|
||||
"Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
|
||||
)
|
||||
with FileLock(".lock") as lock:
|
||||
nltk.download("punkt", quiet=True)
|
||||
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# ## Prepare datasets
|
||||
|
||||
# %%
|
||||
# load model
|
||||
model_name_or_path = "t5-small" # Replace with your specific model name
|
||||
|
||||
# Load configuration
|
||||
config = AutoConfig.from_pretrained(model_name_or_path)
|
||||
|
||||
# Load model
|
||||
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
|
||||
model_name_or_path
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
|
||||
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
|
||||
|
||||
|
||||
# %%
from tqdm import tqdm
from datasets import load_from_disk

# Path to saved combined_dataset
file_path = '/home/richard/Projects/learn_t5/retrieval/combined_data_t5'
save_path = 't5_80_1_retrieval'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)

# prepare tokenizer
from transformers import T5TokenizerFast
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)

# Define additional special tokens
# additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>",
                             "<CONTEXT>", "<EXAMPLE>", "<INPUT>", "<OUTPUT>"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

max_length = 300

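
# %%
# Quick check (illustrative): each added marker should map to a single token id rather
# than being split by the SentencePiece model.
print(tokenizer.convert_tokens_to_ids(additional_special_tokens))
print(len(tokenizer))  # vocabulary size after adding the special tokens
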
# In Flax, seq2seq models need explicit `decoder_input_ids`: the Flax models do not
# accept `labels`, so we build the decoder inputs ourselves with the `shift_tokens_right`
# function imported from the model file above.


# Given a dataset entry, run it through the tokenizer.
# padding="max_length" gives fixed-length inputs, which jitted functions require.
def preprocess_function(example):
    inputs = example['input']
    targets = example['output']

    # tokenize the inputs; this provides `input_ids` and `attention_mask`
    model_inputs = tokenizer(
        inputs,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )
    # tokenize the targets separately: passing only `text_target` makes the tokenizer
    # encode the target sequence, so `labels["attention_mask"]` refers to the targets
    labels = tokenizer(
        text_target=targets,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )

    model_inputs["labels"] = labels["input_ids"]
    decoder_input_ids = shift_tokens_right_fn(
        labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
    )
    model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

    # We need decoder_attention_mask so we can ignore pad tokens in the loss
    model_inputs["decoder_attention_mask"] = labels["attention_mask"]

    return model_inputs

# map() applies the preprocessing function to every example in each split
tokenized_datasets = split_datasets.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=split_datasets["train"].column_names,
)


# %%
tokenized_datasets

# %%
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"]

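
# %%
# Illustrative inspection (not required for training): confirm the tokenized columns
# and that every sequence was padded/truncated to `max_length`.
print(train_dataset)
print({k: np.asarray(v).shape for k, v in train_dataset[:2].items()})
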
# %%
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
    """
    Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be
    incomplete, and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
    """
    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
        batch_idx = np.asarray(batch_idx)
    else:
        batch_idx = np.arange(len(dataset))

    if drop_last:
        steps_per_epoch = len(dataset) // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        steps_per_epoch = math.ceil(len(dataset) / batch_size)
        batch_idx = np.array_split(batch_idx, steps_per_epoch)

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: np.array(v) for k, v in batch.items()}

        yield batch

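
# %%
# Usage sketch (illustrative only): pull a single shuffled batch and look at its shapes.
# `train_batch_size` is only defined further down, so a small literal batch size is used here.
_demo_rng = jax.random.PRNGKey(0)
_demo_batch = next(data_loader(_demo_rng, train_dataset, batch_size=4, shuffle=True))
print({k: v.shape for k, v in _demo_batch.items()})
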
# %% [markdown]
# The tokenized model inputs are now available in `tokenized_datasets`.

# %%
# metric
metric = evaluate.load("sacrebleu")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels

# def compute_metrics(preds, labels):
#     decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
#     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
#
#     # Some simple post-processing
#     decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
#
#     result = metric.compute(predictions=decoded_preds, references=decoded_labels)
#     result = {k: round(v * 100, 4) for k, v in result.items()}
#     prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
#     result["gen_len"] = np.mean(prediction_lens)
#     return result

def compute_metrics(preds, labels):
    # In case the model returns more than the prediction logits
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100s in the labels as we can't decode them
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": result["score"]}

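
# %%
# Smoke test for compute_metrics (illustrative; the sentence is a toy assumption).
# Identical prediction and reference strings should score BLEU = 100.
_toy_ids = tokenizer(["a pump with a temperature and flow sensor"], return_tensors="np").input_ids
print(compute_metrics(_toy_ids, _toy_ids))
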
# %% [markdown]
# # Model

# %%
# Store some constants
seed = 117
num_epochs = 80
batch_size = 36
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

warmup_steps = 0
learning_rate = 5e-5

weight_decay = 0.0
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0

num_beams = 1
val_max_target_length = None

predict_with_generate = True


# %%
# Initialize our training
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)

# %%
# optimization functions

def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear_decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn


# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    len(train_dataset),
    train_batch_size,
    num_train_epochs,
    warmup_steps,
    learning_rate,
)

# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)

# create adam optimizer
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)

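
# %%
# Spot-check the schedule (illustrative): with warmup_steps = 0 it starts at the peak
# learning rate and decays linearly to zero over total_train_steps.
for _step in (0, total_train_steps // 2, total_train_steps - 1):
    print(_step, float(linear_decay_lr_schedule_fn(_step)))
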
# %%
# Training functions
class TrainState(train_state.TrainState):
    dropout_rng: jnp.ndarray

    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))

# Ensure model.params is properly initialized (this is just an example)
# Normally you would get this from a model initialization call with dummy input
params = model.params
# Cast parameters to bfloat16 if desired
params_bf16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)


# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=params_bf16, tx=adamw, dropout_rng=dropout_rng)

# label smoothed cross entropy
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """
    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = loss - normalizing_constant

    # ignore padded tokens from loss
    loss = loss * padding_mask
    loss = loss.sum()
    num_labels = padding_mask.sum()
    return loss, num_labels

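
# %%
# Numerical sanity check for loss_fn (illustrative, not part of training). With
# label_smoothing_factor = 0 the normalizing constant is 0 and this is plain masked
# cross entropy; confidently correct toy logits should give a mean loss near zero.
_toy_logits = jnp.array([[[10.0, 0.0, 0.0, 0.0, 0.0],
                          [0.0, 10.0, 0.0, 0.0, 0.0]]])  # (batch=1, len=2, vocab=5)
_toy_labels = jnp.array([[0, 1]])
_toy_mask = jnp.array([[1, 1]])
_toy_loss, _toy_num = loss_fn(_toy_logits, _toy_labels, _toy_mask, label_smoothing_factor=0.0)
print(float(_toy_loss / _toy_num))  # close to 0
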
# Define gradient update step fn
def train_step(state, batch, label_smoothing_factor=0.0):
    dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    def compute_loss(params):
        labels = batch.pop("labels")
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
        return loss, num_labels

    # compute gradients through computational graph
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, num_labels), grad = grad_fn(state.params)
    num_labels = jax.lax.psum(num_labels, "batch")

    # true loss = total loss / total samples
    loss = jax.lax.psum(loss, "batch")
    loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)

    # true grad = total grad / total samples
    grad = jax.lax.psum(grad, "batch")
    grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
    new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

    metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    return new_state, metrics

# Define eval fn
def eval_step(params, batch, label_smoothing_factor=0.0):
    labels = batch.pop("labels")
    logits = model(**batch, params=params, train=False)[0]

    loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
    num_labels = jax.lax.psum(num_labels, "batch")

    # true loss = total loss / total samples
    loss = jax.lax.psum(loss, "batch")
    loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)

    metrics = {"loss": loss}
    return metrics

# Define generation function
# Note: this rebinds `max_length`; with val_max_target_length=None the generation
# length falls back to model.config.max_length.
max_length = (
    val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def generate_step(params, batch):
    model.params = params
    output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
    return output_ids.sequences

# Create parallel version of the train and eval step
p_train_step = jax.pmap(
    partial(train_step, label_smoothing_factor=label_smoothing_factor), "batch", donate_argnums=(0,)
)
p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=label_smoothing_factor), "batch")
p_generate_step = jax.pmap(generate_step, "batch")

# Replicate the train state on each device
state = state.replicate()

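
# %%
# Shape check (illustrative): the pmapped steps expect a leading device axis, which
# `shard` adds by reshaping (train_batch_size, ...) into (num_devices, per_device_batch, ...).
_toy_batch = np.zeros((train_batch_size, 300), dtype=np.int32)
print(shard(_toy_batch).shape)  # (jax.local_device_count(), per_device_train_batch_size, 300)
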
# %%
print("***** Running training *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")

# %%
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
# epochs = range(num_epochs)
for epoch in epochs:
    # ======================== Training ================================
    train_start = time.time()

    # Create sampling rng
    rng, input_rng = jax.random.split(rng)
    train_metrics = []

    # Generate an epoch by shuffling sampling indices from the train dataset
    train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
    steps_per_epoch = len(train_dataset) // train_batch_size
    # train
    for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
        batch = next(train_loader)
        batch = shard(batch)
        state, train_metric = p_train_step(state, batch)
        train_metrics.append(train_metric)

    train_time += time.time() - train_start

    train_metric = unreplicate(train_metric)

    epochs.write(
        f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
        f" {train_metric['learning_rate']})"
    )

    # ======================== Evaluating ==============================
    # eval_metrics = []
    # eval_preds = []
    # eval_labels = []

    # eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
    # eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
    # for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
    #     # Model forward
    #     batch = next(eval_loader)
    #     labels = batch["labels"]

    #     metrics = pad_shard_unpad(p_eval_step, static_return=True)(
    #         state.params, batch, min_device_batch=per_device_eval_batch_size
    #     )
    #     eval_metrics.append(metrics)

    #     # generation
    #     if predict_with_generate:
    #         generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
    #         eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
    #         eval_labels.extend(labels)

    # # normalize eval metrics
    # eval_metrics = get_metrics(eval_metrics)
    # eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

    # compute metrics
    # rouge_desc = ""
    # if predict_with_generate:
    #     rouge_metrics = compute_metrics(eval_preds, eval_labels)
    #     eval_metrics.update(rouge_metrics)
    #     rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])

    # # Print metrics and update progress bar
    # desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
    # epochs.write(desc)
    # epochs.desc = desc

    # Save metrics
    # if has_tensorboard and jax.process_index() == 0:
    #     cur_step = epoch * (len(train_dataset) // train_batch_size)
    #     write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

    output_dir = save_path
    # save checkpoint after each epoch and push checkpoint to the hub
    if jax.process_index() == 0:
        params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
        model.save_pretrained(output_dir, params=params)
        tokenizer.save_pretrained(output_dir)


# %% [markdown]
# #
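
# %%
# Prediction sketch (illustrative, not part of the original notebook): reload the
# checkpoint written by `save_pretrained` above and generate for one validation example.
loaded_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(save_path)
loaded_tokenizer = T5TokenizerFast.from_pretrained(save_path)

_sample_text = split_datasets["validation"][0]["input"]
_sample = loaded_tokenizer([_sample_text], return_tensors="np",
                           padding="max_length", max_length=300, truncation=True)
_generated = loaded_model.generate(_sample["input_ids"],
                                   attention_mask=_sample["attention_mask"],
                                   max_length=300, num_beams=1)
print(loaded_tokenizer.batch_decode(_generated.sequences, skip_special_tokens=True))
print(split_datasets["validation"][0]["output"])  # reference output for comparison
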