learn_jax/parallel
Richard Wong 429e1742ab Feat: flax pjit example 2024-09-16 12:19:07 +09:00
.gitignore Feat: fsdp demo 2024-09-15 22:41:00 +09:00
dataload.py Feat: flax pjit example 2024-09-16 12:19:07 +09:00
flax_pjit_tutorial.py Feat: flax pjit example 2024-09-16 12:19:07 +09:00
fully_sharded_data_parallelism.py Feat: flax pjit example 2024-09-16 12:19:07 +09:00
intro_to_distributed.py Feat: fsdp demo 2024-09-15 22:41:00 +09:00
single_gpu_optimizations.py Feat: fsdp demo 2024-09-15 22:41:00 +09:00
t5_jax_train_2.py Feat: flax pjit example 2024-09-16 12:19:07 +09:00
t5_jax_train_fail.py Feat: flax pjit example 2024-09-16 12:19:07 +09:00