Feat: learn flax
parent d2dd72227f
commit 005a1a5735
@@ -3,3 +3,5 @@ t5_*/
exports/
modified_t5_model/
traces/
ruff.toml
settings.json

@@ -0,0 +1,332 @@
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.16.4
# ---

# %% [markdown]
# # Flax Basics

# %% [markdown]
# # Linear regression with Flax
# In this example we perform linear regression with gradient descent. We first
# define a model, then generate data from known parameters (so the data matches
# the model's predictions up to a little noise), and finally recover those
# parameters by fitting the model to the generated data.

# %%
# import packages
from typing import Any, Callable, Sequence

import jax
from jax import random, numpy as jnp
import flax
from flax import linen as nn

# %%
# define the model
# notice there is no input size declaration
model = nn.Dense(features=5)

# model parameters & initialization
key1, key2 = random.split(random.key(0))
# dummy input to trigger shape inference
x = random.normal(key1, (10,))
params = model.init(key2, x)  # initialization call

# check the parameter shapes
print(jax.tree_util.tree_map(lambda x: x.shape, params))

model.apply(params, x)

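# %%
# A quick check that the input size really is inferred lazily at init time:
# the same nn.Dense(features=5) module should get a different kernel shape when
# initialized with a 3-dimensional dummy input (params_3 is just an illustrative name).
params_3 = model.init(key2, jnp.ones((3,)))
print(jax.tree_util.tree_map(lambda p: p.shape, params_3))  # kernel (3, 5), bias (5,)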
# %%
# example of gradient descent

# set problem dimensions
n_samples: int = 20
x_dim: int = 10
y_dim: int = 5

# generate random ground truth W and b
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# store the parameters in a FrozenDict pytree
true_params = flax.core.freeze({"params": {"bias": b, "kernel": W}})
print(jax.tree_util.tree_map(lambda x: x.shape, true_params))

# generate samples with additional noise
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
# y = Wx + b + epsilon (noise)
y_samples = (
    jnp.dot(x_samples, W) + b
    + 0.1 * random.normal(key_noise, (n_samples, y_dim))
)
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

# %%
# define the loss between model output and targets
@jax.jit
def mse(params, x_batched, y_batched):
    # define the squared loss for a single (x, y) pair
    def squared_error(x, y):
        # use model.apply to get the model prediction f(x)
        pred = model.apply(params, x)
        return jnp.inner(y - pred, y - pred) / 2.0
    # vmap squared_error across the batch along axis 0, then take the mean
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

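# %%
# A minimal illustration of what jax.vmap is doing in mse above (the toy values are
# chosen just for this sketch): it maps the per-example function over the leading
# batch axis, equivalent to looping over the pairs and stacking the results.
xs_demo = jnp.arange(6.0).reshape(3, 2)   # three 2-dimensional "predictions"
ys_demo = jnp.ones((3, 2))                # three 2-dimensional targets
per_pair = lambda x, y: jnp.inner(y - x, y - x) / 2.0
vmapped = jax.vmap(per_pair)(xs_demo, ys_demo)                       # shape (3,)
looped = jnp.stack([per_pair(x, y) for x, y in zip(xs_demo, ys_demo)])
print(jnp.allclose(vmapped, looped))  # expected: True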
# %%
# perform gradient descent
learning_rate = 0.3  # gradient step size
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
# get the value-and-gradient function for backprop
loss_grad_fn = jax.value_and_grad(mse)


@jax.jit
def update_params(params, learning_rate, grads):
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)
    return params


for i in range(101):
    # perform one gradient update
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, learning_rate, grads)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)

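# %%
# A small follow-up check (a rough sketch, not a guarantee): after training, the
# learned kernel and bias should be close to the ground-truth W and b that generated
# the data, up to the 0.1 noise level and the finite sample size.
print(jax.tree_util.tree_map(
    lambda learned, true: jnp.max(jnp.abs(learned - true)),
    flax.core.unfreeze(params), flax.core.unfreeze(true_params),
))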
# %% [markdown]
# # Optimization with Optax

# %%
# initialize the Optax optimizer
import optax

tx = optax.adam(learning_rate=learning_rate)
# initialize the optimizer state from the model params
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

# %%
# gradient descent with the Optax optimizer instead of the update_params function
for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)

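# %%
# A short sketch connecting the two approaches (names like sgd_state are just for
# this illustration): optax.sgd applies the same p - learning_rate * g rule that
# update_params implements by hand, so one step of each should agree.
sgd = optax.sgd(learning_rate)
sgd_state = sgd.init(params)
_, g = loss_grad_fn(params, x_samples, y_samples)
sgd_updates, sgd_state = sgd.update(g, sgd_state)
by_hand = update_params(params, learning_rate, g)
via_optax = optax.apply_updates(params, sgd_updates)
print(jax.tree_util.tree_all(jax.tree_util.tree_map(
    lambda a, b: jnp.allclose(a, b), by_hand, via_optax)))  # expected: True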
# %% [markdown]
# # Serializing the result - aka saving the model

# %%
from flax import serialization

bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)

# %%
# load back from the saved outputs
# first rebuild the param structure as a template
model = nn.Dense(features=5)
key1, key2 = random.split(random.key(0))
# dummy input to trigger shape inference
x = random.normal(key1, (10,))
params = model.init(key2, x)  # initialization call

# then restore the saved weights
# note: the restore functions return the restored pytree, they do not modify in place
# from the state dict
params = serialization.from_state_dict(params, dict_output)
print(params)
# from bytes
params = serialization.from_bytes(params, bytes_output)
print(params)

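# %%
# A minimal sketch of persisting to disk (the file name here is arbitrary): the
# bytes output is msgpack data, so it can be written to a file and later restored
# with serialization.from_bytes against a freshly initialized template.
with open('dense_params.msgpack', 'wb') as f:
    f.write(bytes_output)
with open('dense_params.msgpack', 'rb') as f:
    params = serialization.from_bytes(params, f.read())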
# %% [markdown]
# # Defining your own models
# An nn.Module subclass is composed of:
#
# 1. a collection of data fields (dataclass attributes that act as hyperparameters)
# 2. a setup() method that creates the structures called in the forward pass
# 3. a __call__() method that implements the forward pass
#
# the model's parameters are kept in a separate pytree that is passed to init/apply

# %%
# module basics
class ExplicitMLP(nn.Module):
    # model arguments:
    # a list of feature sizes, one per dense layer
    features: Sequence[int]

    def setup(self):
        self.layers = [nn.Dense(feat) for feat in self.features]

    def __call__(self, inputs):
        x = inputs
        # go through each layer
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            # apply an activation after every layer except the last
            if i != len(self.layers) - 1:
                x = nn.relu(x)
        return x


# init the model
key1, key2 = random.split(random.key(117), 2)
# any batch shape is ok
x = random.uniform(key1, (4, 4))
print(x)

# instantiate the model
model = ExplicitMLP(features=[3, 4, 5])
# init the model with an rng key and some test data
params = model.init(key2, x)
y = model.apply(params, x)

print(
    'initialized parameter shapes:\n',
    jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)),
)
print('output:\n', y)

# %%
# using nn.Module with @nn.compact

class SimpleMLP(nn.Module):
    features: Sequence[int]

    # the nn.compact style declares submodules inline in __call__
    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, name=f'layers_{i}')(x)
            if i != len(self.features) - 1:
                x = nn.relu(x)
            # providing a name is optional though!
            # the default autonames would be "Dense_0", "Dense_1", ...
        return x


key1, key2 = random.split(random.key(117), 2)
x = random.uniform(key1, (4, 4))

model = SimpleMLP(features=[3, 4, 5])
params = model.init(key2, x)
y = model.apply(params, x)

print(
    'initialized parameter shapes:\n',
    jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)),
)
print('output:\n', y)

# %%
# creating your own layers
# this shows how to create your own dense layer as an example
class SimpleDense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, inputs):
        # use the nn.Module.param method to declare a trainable parameter
        # * because params are declared in __call__, shape inference is possible *
        # the 'W'
        kernel = self.param(
            'kernel',                           # name
            self.kernel_init,                   # initialization function
            (inputs.shape[-1], self.features),  # shape
        )
        # 'Wx'
        y = jnp.dot(inputs, kernel)
        # 'Wx + b'
        bias = self.param('bias', self.bias_init, (self.features,))
        y = y + bias
        return y


key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4, 4))

model = SimpleDense(features=3)
# init_fn takes (prng key, *init args, **init kwargs)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)

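# %%
# A small sanity-check sketch: SimpleDense uses the same parameter names ('kernel',
# 'bias') and the same y = x @ kernel + bias computation as the built-in nn.Dense,
# so applying nn.Dense with SimpleDense's params should reproduce the output above.
y_ref = nn.Dense(features=3).apply(params, x)
print('matches nn.Dense:', jnp.allclose(y, y_ref))  # expected: True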
# %%
# handling state in your module
# jax does not work well with side effects and hidden state:
# jax functions are supposed to be pure, with outputs fully determined by inputs
# flax introduces "mutable" variable collections to separate the function from the state

# note: there are 2 collections here:
# 'params' for trainable variables, 'batch_stats' for non-trainable state
class BiasAdderWithRunningMean(nn.Module):
    decay: float = 0.99

    @nn.compact
    def __call__(self, x):
        # easy pattern to detect whether we're initializing (the variable tree is empty):
        # check for the variable 'mean' in the 'batch_stats' collection
        is_initialized = self.has_variable('batch_stats', 'mean')
        # declare a variable outside the model parameters:
        # the variable 'mean' in the 'batch_stats' collection
        # ra_mean now refers to batch_stats -> mean
        ra_mean = self.variable(
            'batch_stats',
            'mean',
            lambda s: jnp.zeros(s),
            x.shape[1:],  # shape of the data after the 0th axis, e.g. (10, 5) -> (5,)
        )
        bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
        if is_initialized:
            ra_mean.value = (
                self.decay * ra_mean.value
                + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)
            )

        return x - ra_mean.value + bias


key1, key2 = random.split(random.key(0), 2)
x = jnp.ones((10, 5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
# passing mutable=... means apply also returns the updated state of those collections,
# so there is no true hidden state inside the function;
# this serves as an escape hatch for implementing stateful layers with pure functions,
# and the state is kept outside the module
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)

# %%
# run a few steps, threading the updated state back into the variables dict
for val in [1.0, 2.0, 3.0]:
    x = val * jnp.ones((10, 5))
    y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
    old_state, params = flax.core.pop(variables, 'params')
    variables = flax.core.freeze({'params': params, **updated_state})
    print('updated state:\n', updated_state)  # shows only the mutable part

# %%
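# A rough sketch of where this pattern goes next (the loss and learning rate are
# placeholders chosen only for this illustration): when a module has both trainable
# params and mutable state, a train step takes the gradient w.r.t. the params only,
# applies an Optax update, and threads the mutable collections through as aux output.
from functools import partial


@partial(jax.jit, static_argnums=(0, 1))
def update_step(tx, apply_fn, x, opt_state, params, state):
    def loss_fn(p):
        y, updated_state = apply_fn(
            {'params': p, **state}, x, mutable=list(state.keys()))
        loss_value = ((x - y) ** 2).sum()  # dummy loss for the sketch
        return loss_value, updated_state

    (loss_value, new_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return opt_state, params, new_state


tx = optax.sgd(learning_rate=0.02)
state, params = flax.core.pop(variables, 'params')
opt_state = tx.init(params)
for _ in range(3):
    opt_state, params, state = update_step(
        tx, model.apply, jnp.ones((10, 5)), opt_state, params, state)
print('trained params:\n', params)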
t5_jax.py
@@ -96,7 +96,7 @@ split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constant
seed = 117
num_epochs = 80
num_epochs = 5
batch_size = 384  # 384 is the best
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size

@@ -314,16 +314,16 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: np.array(v) for k, v in batch.items()}
        batch = {k: jnp.array(v) for k, v in batch.items()}

        yield batch


# %% [markdown]
# # Model


#
#
#

# %%

@@ -442,16 +442,17 @@ def train_step(state, batch, label_smoothing_factor=0.0):
    num_labels = jax.lax.psum(num_labels, "batch")

    # true loss = total loss / total samples
    loss = jax.lax.psum(loss, "batch")
    loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
    # loss = jax.lax.psum(loss, "batch")
    # loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)

    # true grad = total grad / total samples
    grad = jax.lax.psum(grad, "batch")
    grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
    new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

    metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    return new_state, metrics
    # metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    # return new_state, metrics
    return new_state

# Define generation function
max_length = (

@@ -505,19 +506,20 @@ for epoch in epochs:
    for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
        batch = next(train_loader)
        batch = shard(batch)
        state, train_metric = p_train_step(state, batch)
        train_metrics.append(train_metric)
        state = p_train_step(state, batch)
        # train_metrics.append(train_metric)

    train_time = time.time() - train_start

    train_metric = unreplicate(train_metric)
    train_metric['loss'].block_until_ready()
    # train_metric = unreplicate(train_metric)
    # train_metric['loss'].block_until_ready()


    epochs.write(
        f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
        f"Learning Rate:{train_metric['learning_rate']}, "
        # f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, "
        f"Epoch... ({epoch + 1}/{num_epochs} | "
        # f"Learning Rate:{train_metric['learning_rate']}, "
        f"Last train time: {train_time})"
    )
    # jax.profiler.stop_trace()

@@ -30,7 +30,7 @@ from typing import Callable, Optional
import math

# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "high")
jax.config.update("jax_default_matmul_precision", "bfloat16")

jax.config.update("jax_enable_x64", False)