# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.16.4
# ---
# %% [markdown]
# # Flax Basics
# %% [markdown]
# # Linear regression with Flax
# In this example we perform linear regression with gradient descent. We first
# define a model, then generate data from it using known "ground truth"
# parameters (so the data matches the model's predictions up to a little
# noise), and finally recover those parameters by minimizing a loss with
# gradient descent.
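#
# Written out (this just restates the code below), each sample is generated as
# $y = x W + b + \epsilon$ with $x \in \mathbb{R}^{10}$, $y \in \mathbb{R}^{5}$,
# and small Gaussian noise $\epsilon$; we then fit $W, b$ by minimizing the
# mean squared error $\frac{1}{n} \sum_i \tfrac{1}{2} \lVert y_i - \hat{y}_i \rVert^2$.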
# %%
# import packages
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn
# %%
# define model
# notice there is no input size declaration
model = nn.Dense(features=5)
# model parameters & initialization
key1, key2 = random.split(random.key(0))
# dummy input to trigger shape inference
x = random.normal(key1, (10,))
params = model.init(key2, x) # initialization call
# check output shapes
print(jax.tree_util.tree_map(lambda x: x.shape, params))
model.apply(params, x)
# %%
# example of gradient descent
# set problem dimensions
n_samples: int = 20
x_dim: int = 10
y_dim: int = 5
# generate random ground truth W and b
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# store parameters in frozendict pytree
true_params = flax.core.freeze({"params": {"bias": b, "kernel": W}})
print(jax.tree_util.tree_map(lambda x: x.shape, true_params))
# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
# Wx + b + epsilon (noise)
y_samples = (
    jnp.dot(x_samples, W) + b
    + 0.1 * random.normal(key_noise, (n_samples, y_dim))
)
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
# %%
# define the loss: mean squared error over the batch
@jax.jit
def mse(params, x_batched, y_batched):
    # define the squared loss for a single (x, y) pair
    def squared_error(x, y):
        # use model.apply to get the model prediction model_f(x)
        pred = model.apply(params, x)
        return jnp.inner(y - pred, y - pred) / 2.0
    # vmap squared_error across the batch along axis 0, then average
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
# %%
# perform gradient descent
learning_rate = 0.3 # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
# get the gradient function for backprop
loss_grad_fn = jax.value_and_grad(mse)
@jax.jit
def update_params(params, learning_rate, grads):
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)
    return params

for i in range(101):
    # Perform one gradient update.
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, learning_rate, grads)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)
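# %%
# A small sanity check (not part of the original walkthrough): compare the
# learned parameters against the ground-truth W and b used to generate the
# data; the differences should be small, on the order of the injected noise.
diff = jax.tree_util.tree_map(
    lambda a, b: jnp.max(jnp.abs(a - b)),
    flax.core.unfreeze(params), flax.core.unfreeze(true_params))
print('max abs difference from true params:', diff)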
# %% [markdown]
# # Optimization with Optax
# %%
# init optax object
import optax
tx = optax.adam(learning_rate=learning_rate)
# we initialize the optimizer with model params
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)
# %%
# perform gradient descent with optax optimizer instead of update_params function
for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)
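# %%
# An optional minimal sketch (not from the original guide): the same loop with
# the gradient computation and the Optax update fused into one jitted step.
# `train_step` is just an illustrative name, not a Flax or Optax API.
@jax.jit
def train_step(params, opt_state, x_batched, y_batched):
    loss_val, grads = jax.value_and_grad(mse)(params, x_batched, y_batched)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val

for i in range(101):
    params, opt_state, loss_val = train_step(params, opt_state, x_samples, y_samples)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)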
# %% [markdown]
# # Serializing the result - aka saving the model
# %%
from flax import serialization
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)
# %%
# load back from dict output
# we first get the param structure
model = nn.Dense(features=5)
key1, key2 = random.split(random.key(0))
# dummy input to trigger shape inference
x = random.normal(key1, (10,))
params = model.init(key2, x) # initialization call
# then we load the saved weights back into that structure
# from state_dict (returns a new pytree; it does not modify `params` in place)
params = serialization.from_state_dict(params, dict_output)
print(params)
# from bytes
params = serialization.from_bytes(params, bytes_output)
print(params)
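# %%
# A small sketch (an assumption, not part of the original guide) of persisting
# the serialized bytes to disk and restoring them later; the file name
# 'params.msgpack' is just an illustrative choice.
import pathlib
ckpt_path = pathlib.Path('params.msgpack')
ckpt_path.write_bytes(bytes_output)
params = serialization.from_bytes(params, ckpt_path.read_bytes())
print(jax.tree_util.tree_map(jnp.shape, params))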
# %% [markdown]
# # Defining your own models
# An nn.Module subclass is composed of:
#
# 1. a collection of data fields (the module's constructor arguments, e.g. feature sizes)
# 2. a setup() method that builds the submodules used in the forward pass
# 3. a __call__() method that implements the forward pass
#
# The model's parameters live outside the module, in a pytree returned by
# init() and passed to apply().
# %%
# module basics
class ExplicitMLP(nn.Module):
    # model arguments
    # we have a list of feature sizes for the dense layers
    features: Sequence[int]

    def setup(self):
        self.layers = [nn.Dense(feat) for feat in self.features]

    def __call__(self, inputs):
        x = inputs
        # go through each layer
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            # apply an activation function after every layer except the last
            if i != len(self.layers) - 1:
                x = nn.relu(x)
        return x
# init the model
key1, key2 = random.split(random.key(117), 2)
# any shape is ok
x = random.uniform(key1, (4,4))
print(x)
# instantiate the model
model = ExplicitMLP(features=[3,4,5])
# init model with an rng key and a test data
params = model.init(key2, x)
y = model.apply(params, x)
print(
    'initialized parameter shapes:\n',
    jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)),
)
print('output:\n', y)
# %%
# using nn.Module with @nn.compact
class SimpleMLP(nn.Module):
    features: Sequence[int]

    # the nn.compact style declares submodules inline
    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, name=f'layers_{i}')(x)
            if i != len(self.features) - 1:
                x = nn.relu(x)
            # providing a name is optional though!
            # the default autonames would be "Dense_0", "Dense_1", ...
        return x
key1, key2 = random.split(random.key(117), 2)
x = random.uniform(key1, (4,4))
model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
# %%
# creating your own layers
# this is a guide to show you how to create your own dense layer as an example
class SimpleDense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, inputs):
        # use the nn.Module.param method to declare a parameter
        # * because params are declared in __call__, shape inference is possible *
        # the 'W'
        kernel = self.param(
            'kernel',                           # name
            self.kernel_init,                   # initialization function
            (inputs.shape[-1], self.features),  # shape
        )
        # 'Wx'
        y = jnp.dot(inputs, kernel)
        # 'Wx + b'
        bias = self.param('bias', self.bias_init, (self.features,))
        y = y + bias
        return y
key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))
model = SimpleDense(features=3)
# init_fn takes (prng key, *init args, **init kwargs)
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameters:\n', params)
print('output:\n', y)
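# %%
# A small sketch (an assumption, not part of the original guide): because
# kernel_init and bias_init are ordinary dataclass fields, they can be
# overridden when instantiating the module.
model_alt = SimpleDense(features=3, kernel_init=nn.initializers.normal(stddev=0.01))
params_alt = model_alt.init(key2, x)
print('alt-init parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, params_alt))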
# %%
# handling state in your module
# jax does not work well with side effects and hidden state:
# jax transformations expect pure functions whose outputs depend only on
# their inputs, so we use "mutable" variable collections to separate the
# function from its state
# note: this example uses two collections:
# 'params' for trainable parameters, 'batch_stats' for non-trainable state
class BiasAdderWithRunningMean(nn.Module):
    decay: float = 0.99

    @nn.compact
    def __call__(self, x):
        # easy pattern to detect whether we're initializing: during init the
        # variable tree is still empty, so 'mean' is absent from 'batch_stats'
        is_initialized = self.has_variable('batch_stats', 'mean')
        # declare a variable outside the model parameters:
        # 'mean' lives in the 'batch_stats' collection,
        # so ra_mean refers to batch_stats -> mean
        ra_mean = self.variable(
            'batch_stats',
            'mean',
            lambda s: jnp.zeros(s),
            x.shape[1:],  # shape of the data after the batch dim, e.g. (10, 5) -> (5,)
        )
        bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
        if is_initialized:
            ra_mean.value = (
                self.decay * ra_mean.value
                + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)
            )
        return x - ra_mean.value + bias
key1, key2 = random.split(random.key(0), 2)
x = jnp.ones((10,5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
# passing mutable means apply() also returns the updated state of those
# collections, so there is no true hidden state inside the function;
# this serves as an escape hatch for implementing stateful computation
# with pure functions: the state is kept outside the module
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)
# %%
for val in [1.0, 2.0, 3.0]:
    x = val * jnp.ones((10, 5))
    y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
    old_state, params = flax.core.pop(variables, 'params')
    variables = flax.core.freeze({'params': params, **updated_state})
    print('updated state:\n', updated_state)  # shows only the mutable part
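# %%
# A minimal sketch (an assumption, loosely following the pattern in the Flax
# docs) of a training step that threads both the trainable params and the
# mutable 'batch_stats' collection through an update; the loss below is a
# dummy (sum of squared outputs) purely for illustration.
from functools import partial

# split the current variables into trainable params and non-trainable state
state, params = flax.core.pop(variables, 'params')
sgd_tx = optax.sgd(learning_rate=0.02)  # a fresh optimizer just for this sketch
opt_state = sgd_tx.init(params)

@partial(jax.jit, static_argnums=(0,))
def update_step(apply_fn, x, opt_state, params, state):
    def loss_fn(params):
        y, updated_state = apply_fn(
            {'params': params, **state}, x, mutable=list(state.keys()))
        return jnp.sum(y ** 2), updated_state  # dummy loss for illustration
    (loss, updated_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
    updates, opt_state = sgd_tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return opt_state, params, updated_state

for _ in range(3):
    opt_state, params, state = update_step(model.apply, x, opt_state, params, state)
    print('updated state:', state)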
# %%