333 lines
9.6 KiB
Python
333 lines
9.6 KiB
Python
# ---
|
|
# jupyter:
|
|
# jupytext:
|
|
# formats: ipynb,py:percent
|
|
# text_representation:
|
|
# extension: .py
|
|
# format_name: percent
|
|
# format_version: '1.3'
|
|
# jupytext_version: 1.16.4
|
|
# ---
|
|
|
|
# %% [markdown]
|
|
# # Flax Basics
|
|
|
|
# %% [markdown]
|
|
# # Linear regression with Flax
|
|
# In this example we perform linear regression with gradient descent. We first
# define a model, then generate data from that model using known ("true")
# parameters plus a little noise, so the model can match the data almost
# perfectly. Finally, we fit the model's parameters to that data by minimizing
# the mean squared error.
|
|
|
|
|
|
# %%
|
|
# import packages
|
|
import jax
|
|
from typing import Any, Callable, Sequence
|
|
from jax import random, numpy as jnp
|
|
import flax
|
|
from flax import linen as nn
|
|
|
|
# %%
# Define the model. Note that there is no input-size declaration: the kernel's
# input dimension is inferred from the dummy input passed to init() below.
model = nn.Dense(features=5)

# Model parameters & initialization: one key for the dummy input, one for init.
key1, key2 = random.split(random.key(0))
# Dummy input whose only job is to trigger shape inference.
x = random.normal(key1, (10,))
params = model.init(key2, x)  # initialization call

# Check the parameter shapes.
print(jax.tree_util.tree_map(jnp.shape, params))

# Forward pass (displayed as the cell's output in the notebook).
model.apply(params, x)
|
|
|
|
|
|
# %%
# Example of gradient descent on synthetic data.

# Problem dimensions.
n_samples: int = 20
x_dim: int = 10
y_dim: int = 5

# Generate a random ground-truth W and b, plus the sample and noise keys.
# All four keys come from ONE split of the root key so that no key is ever
# reused: a JAX key that has already been consumed (e.g. to draw W) must not be
# split again to produce further randomness.
key = random.key(0)
k1, k2, key_sample, key_noise = random.split(key, 4)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree laid out like nn.Dense's params.
true_params = flax.core.freeze({"params": {"bias": b, "kernel": W}})
print(jax.tree_util.tree_map(jnp.shape, true_params))

# Generate samples with additional noise: y = x @ W + b + epsilon.
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = (
    jnp.dot(x_samples, W)
    + b
    + 0.1 * random.normal(key_noise, (n_samples, y_dim))
)
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
|
|
|
|
|
|
# %%
# Loss function: mean squared error of the model's predictions.
@jax.jit
def mse(params, x_batched, y_batched):
    """Return the batch mean of half the squared error of ``model.apply``.

    Args:
        params: variables for the module-level ``model`` (as produced by
            ``model.init``).
        x_batched: model inputs, one example per entry along axis 0.
        y_batched: targets aligned with ``x_batched`` along axis 0.

    Returns:
        ``mean_i 0.5 * <y_i - model(x_i), y_i - model(x_i)>``.

    NOTE: ``model`` is not a parameter; it is read from module scope when this
    jitted function is traced, so rebinding the global ``model`` afterwards may
    not be seen by already-compiled calls.
    """

    def half_squared_error(x, y):
        # Loss for a single (x, y) pair; vmapped over axis 0 below.
        residual = y - model.apply(params, x)
        return jnp.inner(residual, residual) / 2.0

    per_example_loss = jax.vmap(half_squared_error)(x_batched, y_batched)
    return jnp.mean(per_example_loss, axis=0)
|
|
|
|
|
|
# %%
# Perform gradient descent.
learning_rate = 0.3  # Gradient step size.

# Sanity check: loss at the ground-truth parameters. It is expected to be
# non-zero only because of the noise added to y_samples.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))

# value_and_grad(mse) returns (loss, gradient w.r.t. the first argument,
# i.e. params) from a single call.
loss_grad_fn = jax.value_and_grad(mse)
|
|
|
|
|
|
@jax.jit
def update_params(params, learning_rate, grads):
    """Take one plain gradient-descent step on every leaf of ``params``.

    Args:
        params: pytree of parameter arrays.
        learning_rate: step size applied to every leaf.
        grads: pytree with the same structure as ``params``.

    Returns:
        A new pytree whose leaves are ``param - learning_rate * grad``.
    """

    def sgd_step(param, grad):
        return param - learning_rate * grad

    return jax.tree_util.tree_map(sgd_step, params, grads)
|
|
|
|
|
|
for i in range(101):
    # Evaluate loss and gradient at the current params, then take one step.
    # loss_val is therefore the loss *before* this iteration's update.
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, learning_rate, grads)
    # Report on steps 0, 10, ..., 100.
    if not i % 10:
        print(f'Loss step {i}: ', loss_val)
|
|
|
|
# %% [markdown]
|
|
# # Optimization with Optax
|
|
|
|
# %%
# Set up an Optax optimizer.
import optax

# Start from the same initial parameters as the plain gradient-descent loop.
# At this point ``params`` already holds the weights fitted above; starting Adam
# at that solution with a step size of 0.3 can push the loss *up* and would not
# demonstrate optimization. Re-running init with the first cell's model, key2
# and dummy x deterministically reproduces that loop's starting point.
params = model.init(key2, x)

tx = optax.adam(learning_rate=learning_rate)
# The optimizer state is initialized from the model params.
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)
|
|
|
|
# %%
# Gradient descent again, but the update rule now comes from the Optax
# optimizer instead of the hand-written update_params function.
for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    # tx.update turns the raw gradients into parameter updates and returns the
    # advanced optimizer state; apply_updates adds them to the params.
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if not i % 10:
        print('Loss step {}: '.format(i), loss_val)
|
|
|
|
# %% [markdown]
# # Serializing the result - aka saving the model

# %%
# This code must live in a code cell: lines inside a `# %% [markdown]` cell
# are rendered as text in the notebook and never executed.
from flax import serialization

# Two serialized forms of the trained parameters: a bytes blob and a nested
# state dict.
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)
|
|
|
|
|
|
# %%
# Load the parameters back from the serialized outputs.
# First build a params pytree with the right structure to restore into.
model = nn.Dense(features=5)
key1, key2 = random.split(random.key(0))
# dummy input to trigger shape inference
x = random.normal(key1, (10,))
params = model.init(key2, x)  # initialization call (structure template)

# Then restore the saved weights into that structure.
# from_state_dict / from_bytes return a NEW pytree and do not modify the target
# in place, so the result must be assigned; otherwise ``params`` would still
# hold the fresh initialization rather than the loaded weights.

# from state_dict
params = serialization.from_state_dict(params, dict_output)
print(params)

# from bytes
params = serialization.from_bytes(params, bytes_output)
print(params)
|
|
|
|
|
|
# %% [markdown]
|
|
# # Defining your own models
#
# An `nn.Module` subclass is composed of:
#
# 1. a collection of data fields (the module's configuration arguments),
# 2. an optional `setup()` method that creates the submodules used in the
#    forward pass, and
# 3. a `__call__()` method that implements the forward pass.
#
# The parameters are not stored on the module itself: `model.init` returns
# them as a separate pytree, which is then passed to `model.apply`.
|
|
|
|
# %%
|
|
# module basics
|
|
class ExplicitMLP(nn.Module):
    """Multi-layer perceptron whose Dense submodules are created in ``setup``.

    Attributes:
        features: output width of each Dense layer, in order; the last entry is
            the width of the network's output.
    """

    # Model arguments: one Dense layer per entry.
    features: Sequence[int]

    def setup(self):
        # Each Dense created here becomes a submodule with its own parameters.
        self.layers = [nn.Dense(feat) for feat in self.features]

    def __call__(self, inputs):
        """Run ``inputs`` through every layer, with ReLU between consecutive layers."""
        last_index = len(self.layers) - 1
        h = inputs
        for idx, layer in enumerate(self.layers):
            h = layer(h)
            # No activation after the final layer, so the output stays linear.
            if idx < last_index:
                h = nn.relu(h)
        return h
|
|
|
|
# Instantiate the model: three Dense layers of widths 3, 4 and 5.
model = ExplicitMLP(features=[3, 4, 5])

# One key for the dummy input, one for parameter initialization.
key1, key2 = random.split(random.key(117), num=2)
# Small random (4, 4) input used for init and a forward pass.
x = random.uniform(key1, (4, 4))
print(x)

# init() needs an RNG key and example input; apply() then runs the forward pass.
params = model.init(key2, x)
y = model.apply(params, x)

print(
    'initialized parameter shapes:\n',
    jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)),
)
print('output:\n', y)
|
|
|
|
# %%
|
|
# using nn.Module with @nn.compact
|
|
|
|
class SimpleMLP(nn.Module):
    """MLP whose Dense layers are declared inline with ``@nn.compact``.

    Attributes:
        features: output width of each Dense layer, in order.
    """

    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        """Apply Dense layers ``layers_0``, ``layers_1``, ... with ReLU between them."""
        # With @nn.compact, submodules are declared inline in __call__ rather
        # than in setup(). Providing a name is optional; the default autonames
        # would be "Dense_0", "Dense_1", ...
        final_index = len(self.features) - 1
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, name=f'layers_{i}')(x)
            if i < final_index:
                x = nn.relu(x)
        return x
|
|
|
|
# Same three-layer configuration, now built with @nn.compact.
model = SimpleMLP(features=[3, 4, 5])

key1, key2 = random.split(random.key(117), num=2)
x = random.uniform(key1, (4, 4))

params = model.init(key2, x)
y = model.apply(params, x)

print(
    'initialized parameter shapes:\n',
    jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)),
)
print('output:\n', y)
|
|
|
|
|
|
# %%
|
|
# creating your own layers
|
|
# this is a guide to show you how to create your own dense layer as an example
|
|
class SimpleDense(nn.Module):
    """Hand-written dense layer computing ``inputs @ kernel + bias``.

    Attributes:
        features: number of output features.
        kernel_init: initializer for the ``kernel`` parameter.
        bias_init: initializer for the ``bias`` parameter.
    """

    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, inputs):
        # self.param declares a parameter: (name, initializer, shape).
        # Because it runs inside __call__, the kernel's input dimension can be
        # read straight off ``inputs`` (shape inference).
        kernel = self.param(
            'kernel',
            self.kernel_init,
            (inputs.shape[-1], self.features),
        )
        bias = self.param('bias', self.bias_init, (self.features,))
        # 'Wx + b'
        return jnp.dot(inputs, kernel) + bias
|
|
|
|
# A 3-output dense layer applied to a small random (4, 4) input.
model = SimpleDense(features=3)

key1, key2 = random.split(random.key(0), num=2)
x = random.uniform(key1, (4, 4))

# init takes (PRNG key, *init args, **init kwargs).
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)
|
|
|
|
|
|
# %%
|
|
# handling state in your module
|
|
# JAX does not work well with side effects and hidden state: JAX functions are
# expected to be pure, i.e. their outputs depend only on their inputs.
# Flax's `mutable` argument to `apply` keeps the function pure by returning the
# updated state explicitly instead of modifying it in place.

# Note: this module uses two variable collections:
# 'params' for trainable parameters and 'batch_stats' for non-trainable state.
|
|
class BiasAdderWithRunningMean(nn.Module):
    """Subtract a running mean of the inputs, then add a learned bias.

    Variables:
        batch_stats/mean: non-trainable running mean of the per-feature batch
            mean, shape ``x.shape[1:]``, zero-initialized.
        params/bias: trainable bias, shape ``x.shape[1:]``, zero-initialized.

    Attributes:
        decay: weight kept on the previous running mean at each update.
    """

    decay: float = 0.99

    @nn.compact
    def __call__(self, x):
        # Easy pattern to detect initialization: during init() the
        # 'batch_stats' collection is still empty, so 'mean' does not exist yet
        # and the running-mean update below is skipped.
        is_initialized = self.has_variable('batch_stats', 'mean')

        # Declare a variable outside the model parameters, in the 'batch_stats'
        # collection; ra_mean now refers to batch_stats -> mean.
        ra_mean = self.variable(
            'batch_stats',
            'mean',
            lambda s: jnp.zeros(s),
            x.shape[1:],  # per-feature shape, e.g. (10, 5) -> (5,)
        )
        bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])

        if is_initialized:
            # Reduce over the batch axis WITHOUT keepdims so the running mean
            # keeps the shape it was initialized with (x.shape[1:]). With
            # keepdims=True it would grow a leading size-1 axis after the first
            # update, e.g. (5,) -> (1, 5), changing the variable's shape.
            batch_mean = jnp.mean(x, axis=0)
            ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * batch_mean

        return x - ra_mean.value + bias
|
|
|
|
|
|
model = BiasAdderWithRunningMean()
key1, key2 = random.split(random.key(0), num=2)
x = jnp.ones((10, 5))

# init() creates both collections: a zero 'bias' in 'params' and a zero
# running 'mean' in 'batch_stats'.
variables = model.init(key1, x)
print('initialized variables:\n', variables)

# Marking 'batch_stats' as mutable makes apply() return a pair
# (output, updated mutable collections) instead of changing state in place.
# The module therefore stays a pure function: state goes in through
# ``variables`` and the new state comes back out as ``updated_state``.
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)
|
|
|
|
|
|
# %%
# Feed a few constant batches through the model, threading the updated
# running mean from one call into the next by hand.
for val in (1.0, 2.0, 3.0):
    x = val * jnp.ones((10, 5))
    y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
    # Split the unchanged trainable params off the old variables, then rebuild
    # the variables from those params plus the freshly returned batch_stats.
    old_state, params = flax.core.pop(variables, 'params')
    variables = flax.core.freeze({'params': params, **updated_state})
    print('updated state:\n', updated_state)  # Shows only the mutable part
|
|
|
|
# %%
|