Feat: learn flax

This commit is contained in:
Richard Wong 2024-09-14 14:13:38 +09:00
parent d2dd72227f
commit 005a1a5735
4 changed files with 352 additions and 16 deletions

2
.gitignore vendored
View File

@ -3,3 +3,5 @@ t5_*/
exports/ exports/
modified_t5_model/ modified_t5_model/
traces/ traces/
ruff.toml
settings.json

332
learn_flax/flax_basics.py Normal file
View File

@ -0,0 +1,332 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.4
# ---
# %% [markdown]
# # Flax Basics
# %% [markdown]
# # Linear regression with Flax
# In this example we perform linear regression with gradient descent. It is
# done by first coming up with a model, generating data from the model (because
# data generated from known parameters will "match" the model predictions
# perfectly albeit with a little noise).
# %%
# import packages
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn
# %%
# define model
# notice there is no input size declaration
model = nn.Dense(features=5)
# model paramaters & initialization
key1, key2 = random.split(random.key(0))
# dummy input to trigger shape inference
x = random.normal(key1, (10,))
params = model.init(key2, x) # initialization call
# check output shapes
print(jax.tree_util.tree_map(lambda x: x.shape, params))
model.apply(params, x)
# %%
# example of gradient descent
# set problem dimensions
n_samples: int = 20
x_dim: int = 10
y_dim: int = 5
# generate random ground truth W and b
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# store parameters in frozendict pytree
true_params = flax.core.freeze({"params": {"bias": b, "kernel": W}})
print(jax.tree_util.tree_map(lambda x: x.shape, true_params))
# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
# Wx + b + epsilon (noise)
y_samples = (
jnp.dot(x_samples, W) + b + 0.1 *
random.normal(key_noise, (n_samples, y_dim))
)
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
# %%
# get model output
@jax.jit
def mse(params, x_batched, y_batched):
# define squared loss for a single pair
def squared_error(x, y):
# use model.apply to get model_f(x)
pred = model.apply(params, x)
return jnp.inner(y - pred, y - pred) / 2.0
# vmap squared_error function across larger array along axis 0
return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
# %%
# perform gradient descent
learning_rate = 0.3 # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
# get the gradient function for backprop
loss_grad_fn = jax.value_and_grad(mse)
@jax.jit
def update_params(params, learning_rate, grads):
params = jax.tree_util.tree_map(
lambda p, g: p - learning_rate * g, params, grads)
return params
for i in range(101):
# Perform one gradient update.
loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
params = update_params(params, learning_rate, grads)
if i % 10 == 0:
print(f'Loss step {i}: ', loss_val)
# %% [markdown]
# # Optimization with Optax
# %%
# init optax object
import optax
tx = optax.adam(learning_rate=learning_rate)
# we initialize the optimizer with model params
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)
# %%
# perform gradient descent with optax optimizer instead of update_params function
for i in range(101):
loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)
if i % 10 == 0:
print('Loss step {}: '.format(i), loss_val)
# %% [markdown]
# Serializing the result - aka saving the model
from flax import serialization
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)
# %%
# load back from dict output
# we first get the param structure
model = nn.Dense(features=5)
key1, key2 = random.split(random.key(0))
# dummy input to trigger shape inference
x = random.normal(key1, (10,))
params = model.init(key2, x) # initialization call
# then we get the saved weights
# from state_dict
serialization.from_state_dict(params, dict_output)
print(params)
# from bytes
serialization.from_bytes(params, bytes_output)
print(params)
# %% [markdown]
# Defining your own models
# A nn.Module subclass is composed of:
#
# 1. collection of data fields (argument parameters)
# 2. setup() method to setup structures to be called in the forward pass
# 3. __call__() function that implements the forward pass
#
# there is a params structure for the model in the form of a pytree of
# parameters
# %%
# module basics
class ExplicitMLP(nn.Module):
# model arguments
# we have a list of feature sizes for the dense layer
features: Sequence[int]
def setup(self):
self.layers = [nn.Dense(feat) for feat in self.features]
def __call__(self, inputs):
x = inputs
# go through each layer
for i, lyr in enumerate(self.layers):
x = lyr(x)
# add an activating function at the end of each layer
if i != len(self.layers) - 1:
x = nn.relu(x)
return x
# init the model
key1, key2 = random.split(random.key(117), 2)
# any shape is ok
x = random.uniform(key1, (4,4))
print(x)
# instantiate the model
model = ExplicitMLP(features=[3,4,5])
# init model with an rng key and a test data
params = model.init(key2, x)
y = model.apply(params, x)
print(
'initialized parameter shapes:\n',
jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)),
)
print('output:\n', y)
# %%
# using nn.Module with @nn.compact
class SimpleMLP(nn.Module):
features: Sequence[int]
# nn.compact style is declaring submodules inline
@nn.compact
def __call__(self, inputs):
x = inputs
for i, feat in enumerate(self.features):
x = nn.Dense(feat, name=f'layers_{i}')(x)
if i != len(self.features) - 1:
x = nn.relu(x)
# providing a name is optional though!
# the default autonames would be "Dense_0", "Dense_1", ...
return x
key1, key2 = random.split(random.key(117), 2)
x = random.uniform(key1, (4,4))
model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
# %%
# creating your own layers
# this is a guide to show you how to create your own dense layer as an example
class SimpleDense(nn.Module):
features: int
kernel_init: Callable = nn.initializers.lecun_normal()
bias_init: Callable = nn.initializers.zeros_init()
@nn.compact
def __call__(self, inputs):
# utilize the nn.Module.param function to declare a module
# * because params are in __call__, shape inference is possible *
# the 'W'
kernel = self.param(
'kernel', # name
self.kernel_init, # Initialization function
(inputs.shape[-1], self.features), # shape
)
# 'Wx'
y = jnp.dot(inputs, kernel)
# 'Wx + b'
bias = self.param('bias', self.bias_init, (self.features,))
y = y + bias
return y
key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))
model = SimpleDense(features=3)
# init_fn takes (prng key, *init args, **init kwargs)
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameters:\n', params)
print('output:\n', y)
# %%
# handling state in your module
# jax does not work well with side-effects and hidden state
# jax functions are supposed to be pure functions with determinate input outputs
# we introduce "mutable" to separate the function and the state
# note: there are 2 collections:
# 'params' for trainable, 'batch_stats' for non-trainable
class BiasAdderWithRunningMean(nn.Module):
decay: float = 0.99
@nn.compact
def __call__(self, x):
# easy pattern to detect if we're initializing via empty variable tree
# check for variable 'mean' in the 'batch_stats' collection
is_initialized = self.has_variable('batch_stats', 'mean')
# declare variable outside model parameters
# declare variable mean in the 'batch_stats' collection
# ra_mean now means batch_stats->mean
ra_mean = self.variable(
'batch_stats',
'mean',
lambda s: jnp.zeros(s),
x.shape[1:] # dim of data after 0th element e.g. (10,5) -> 5
)
bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
if is_initialized:
ra_mean.value = (
self.decay * ra_mean.value
+ (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)
)
return x - ra_mean.value + bias
key1, key2 = random.split(random.key(0), 2)
x = jnp.ones((10,5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
# setting mutable means you also get the updated_state of mutable
# hence there is no true hidden state in the function
# this serves as a escape hatch for implementing pure functions with state
# state is therefore kept away from the module
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)
# %%
for val in [1.0, 2.0, 3.0]:
x = val * jnp.ones((10, 5))
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
old_state, params = flax.core.pop(variables, 'params')
variables = flax.core.freeze({'params': params, **updated_state})
print('updated state:\n', updated_state) # Shows only the mutable part
# %%

View File

@ -96,7 +96,7 @@ split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train']) training_size = len(split_datasets['train'])
# Store some constant # Store some constant
seed = 117 seed = 117
num_epochs = 80 num_epochs = 5
batch_size = 384 # 384 is the best batch_size = 384 # 384 is the best
num_train_epochs = num_epochs num_train_epochs = num_epochs
per_device_train_batch_size = batch_size per_device_train_batch_size = batch_size
@ -314,16 +314,16 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
for idx in batch_idx: for idx in batch_idx:
batch = dataset[idx] batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()} batch = {k: jnp.array(v) for k, v in batch.items()}
yield batch yield batch
# %% [markdown] # %% [markdown]
# # Model # # Model
#
#
#
# %% # %%
@ -442,16 +442,17 @@ def train_step(state, batch, label_smoothing_factor=0.0):
num_labels = jax.lax.psum(num_labels, "batch") num_labels = jax.lax.psum(num_labels, "batch")
# true loss = total loss / total samples # true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch") # loss = jax.lax.psum(loss, "batch")
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) # loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
# true grad = total grad / total samples # true grad = total grad / total samples
grad = jax.lax.psum(grad, "batch") grad = jax.lax.psum(grad, "batch")
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} # metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
return new_state, metrics # return new_state, metrics
return new_state
# Define generation function # Define generation function
max_length = ( max_length = (
@ -505,19 +506,20 @@ for epoch in epochs:
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader) batch = next(train_loader)
batch = shard(batch) batch = shard(batch)
state, train_metric = p_train_step(state, batch) state = p_train_step(state, batch)
train_metrics.append(train_metric) # train_metrics.append(train_metric)
train_time = time.time() - train_start train_time = time.time() - train_start
train_metric = unreplicate(train_metric) # train_metric = unreplicate(train_metric)
train_metric['loss'].block_until_ready() # train_metric['loss'].block_until_ready()
epochs.write( epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, " # f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
f"Learning Rate:{train_metric['learning_rate']}, " f"Epoch... ({epoch + 1}/{num_epochs} | "
# f"Learning Rate:{train_metric['learning_rate']}, "
f"Last train time: {train_time})" f"Last train time: {train_time})"
) )
# jax.profiler.stop_trace() # jax.profiler.stop_trace()

View File

@ -30,7 +30,7 @@ from typing import Callable, Optional
import math import math
# jax.config.update("jax_default_matmul_precision", "tensorfloat32") # jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "high") jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_enable_x64", False) jax.config.update("jax_enable_x64", False)