learn_jax/parallel
Latest commit: a817fe16cc by Richard Wong, 2024-09-22 22:28:41 +09:00, "Feat: increased learning rate for effective large batch size learning"
| File | Last commit | Date |
| --- | --- | --- |
| t5_model | Feat: increased learning rate for effective large batch size learning | 2024-09-22 22:28:41 +09:00 |
| .gitignore | Feat: t5_jax_simple_parallel implements a working example of fsdp | 2024-09-20 23:42:51 +09:00 |
| dataload.py | Feat: t5_jax_simple_parallel implements a working example of fsdp | 2024-09-20 23:42:51 +09:00 |
| flax_pjit_tutorial.py | Feat: t5_jax_simple_parallel implements a working example of fsdp | 2024-09-20 23:42:51 +09:00 |
| fully_sharded_data_parallelism.py | Feat: flax pjit example | 2024-09-16 12:19:07 +09:00 |
| gpt-neo-125m.json | Feat: t5_jax_simple_parallel implements a working example of fsdp | 2024-09-20 23:42:51 +09:00 |
| gptneo_partition_test.py | Feat: t5_jax_simple_parallel implements a working example of fsdp | 2024-09-20 23:42:51 +09:00 |
| intro_to_distributed.py | Feat: fsdp demo | 2024-09-15 22:41:00 +09:00 |
| partitions.py | Feat: t5_jax_simple_parallel implements a working example of fsdp | 2024-09-20 23:42:51 +09:00 |
| single_gpu_optimizations.py | Feat: fsdp demo | 2024-09-15 22:41:00 +09:00 |
| t5.json | Feat: t5_jax_simple_parallel implements a working example of fsdp | 2024-09-20 23:42:51 +09:00 |
| t5_jax_train_2.py | Feat: flax pjit example | 2024-09-16 12:19:07 +09:00 |
| t5_jax_train_fail.py | Feat: flax pjit example | 2024-09-16 12:19:07 +09:00 |
| t5_pjit.py | Feat: t5_jax_simple_parallel implements a working example of fsdp | 2024-09-20 23:42:51 +09:00 |
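
The files above (e.g. `fully_sharded_data_parallelism.py`, `t5_pjit.py`) revolve around fully sharded data parallelism in JAX. Below is a minimal, illustrative sketch of that pattern, not the repository's actual code: the mesh axis name `fsdp`, the toy two-layer MLP, the array sizes, and the learning rate are assumptions chosen only to show how a single 1-D device mesh can shard both the parameters and the batch.

```python
# Minimal FSDP-style sketch (illustrative assumption, not this repo's code).
# A 1-D device mesh shards parameter dim 0 and the batch dim across all local
# devices; XLA's GSPMD partitioner inserts the required all-gathers at runtime.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

LR = 1e-3  # assumed toy learning rate

# One mesh axis, reused for both data parallelism and parameter sharding.
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)),
            axis_names=("fsdp",))
shard = NamedSharding(mesh, P("fsdp", None))  # split dim 0, replicate dim 1

def init_params(key, d_in=512, d_hidden=2048, d_out=512):
    # Dimensions are chosen to be divisible by common device counts.
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (d_in, d_hidden)) * 0.02,
        "w2": jax.random.normal(k2, (d_hidden, d_out)) * 0.02,
    }

def loss_fn(params, x, y):
    h = jax.nn.relu(x @ params["w1"])
    return jnp.mean((h @ params["w2"] - y) ** 2)

@jax.jit
def train_step(params, x, y):
    # Gradients and updated params inherit the input shardings.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return loss, new_params

# Place every weight and the batch with dim 0 split across the mesh axis.
params = jax.tree_util.tree_map(
    lambda p: jax.device_put(p, shard), init_params(jax.random.PRNGKey(0)))
x = jax.device_put(jnp.ones((64, 512)), shard)
y = jax.device_put(jnp.ones((64, 512)), shard)

loss, params = train_step(params, x, y)
print("loss:", loss)
```

Reusing one mesh axis for both the batch and the parameter shards is the usual FSDP layout; the sharded dimensions must be divisible by the device count. In the repository itself, model-specific partition rules for T5 and GPT-Neo presumably live in `partitions.py`, but that is an inference from the file names, not something shown here.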