# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.16.4
# ---

# %% [markdown]
# # Flax Basics

# %% [markdown]
# # Linear regression with Flax
# In this example we perform linear regression with gradient descent. We first
# define a model, then generate training data from known ground-truth
# parameters (so that, up to a little added noise, the data matches the model's
# predictions exactly), and finally recover those parameters by minimizing a
# mean squared error loss.

# %%
# import packages
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn

# %%
# define the model
# notice there is no input size declaration
model = nn.Dense(features=5)

# model parameters & initialization
key1, key2 = random.split(random.key(0))
# dummy input to trigger shape inference
x = random.normal(key1, (10,))
params = model.init(key2, x)  # initialization call

# check parameter shapes
print(jax.tree_util.tree_map(lambda x: x.shape, params))
model.apply(params, x)

# %%
# example of gradient descent

# set problem dimensions
n_samples: int = 20
x_dim: int = 10
y_dim: int = 5

# generate random ground-truth W and b
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))

# store the parameters in a FrozenDict pytree
true_params = flax.core.freeze({"params": {"bias": b, "kernel": W}})
print(jax.tree_util.tree_map(lambda x: x.shape, true_params))

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
# Wx + b + epsilon (noise)
y_samples = (
    jnp.dot(x_samples, W) + b
    + 0.1 * random.normal(key_noise, (n_samples, y_dim))
)
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

# %%
# define the mean squared error loss
@jax.jit
def mse(params, x_batched, y_batched):
    # squared loss for a single (x, y) pair
    def squared_error(x, y):
        # use model.apply to get the model prediction for x
        pred = model.apply(params, x)
        return jnp.inner(y - pred, y - pred) / 2.0
    # vmap squared_error across the batch (axis 0) and take the mean
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

# %%
# perform gradient descent
learning_rate = 0.3  # gradient step size
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))

# get the gradient function for backprop
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)
    return params

for i in range(101):
    # perform one gradient update
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, learning_rate, grads)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)
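
# %% [markdown]
# As a quick check (an illustrative addition, not part of the original script):
# after training, the learned parameters should be close to the ground-truth
# W and b used to generate the data, up to the injected noise and the finite
# number of steps.

# %%
print(jax.tree_util.tree_map(
    lambda learned, true: jnp.max(jnp.abs(learned - true)),
    flax.core.unfreeze(params),
    flax.core.unfreeze(true_params),
))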

# %% [markdown]
# # Optimization with Optax

# %%
# initialize the optax optimizer
import optax
tx = optax.adam(learning_rate=learning_rate)
# the optimizer state is initialized from the model params
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

# %%
# perform gradient descent with the optax optimizer
# instead of the hand-written update_params function
for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)

# %% [markdown]
# # Serializing the result, i.e. saving the model

# %%
from flax import serialization

bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)

# %%
# load back from the serialized output
# we first rebuild the param structure
model = nn.Dense(features=5)
key1, key2 = random.split(random.key(0))
# dummy input to trigger shape inference
x = random.normal(key1, (10,))
params = model.init(key2, x)  # initialization call

# then we restore the saved weights;
# note that from_state_dict and from_bytes return new pytrees
# rather than modifying params in place

# from state_dict
params = serialization.from_state_dict(params, dict_output)
print(params)
# from bytes
params = serialization.from_bytes(params, bytes_output)
print(params)

# %% [markdown]
# # Defining your own models
# An nn.Module subclass is composed of:
#
# 1. a collection of data fields (constructor arguments such as layer sizes)
# 2. a setup() method that creates the submodules and variables used in the forward pass
# 3. a __call__() method that implements the forward pass
#
# The model parameters are not stored on the module itself; they live in a
# separate params structure, a pytree of parameters passed to init/apply.

# %%
# module basics
class ExplicitMLP(nn.Module):
    # model arguments:
    # a list of feature sizes for the dense layers
    features: Sequence[int]

    def setup(self):
        self.layers = [nn.Dense(feat) for feat in self.features]

    def __call__(self, inputs):
        x = inputs
        # go through each layer
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            # apply an activation function after every layer except the last
            if i != len(self.layers) - 1:
                x = nn.relu(x)
        return x

# init the model
key1, key2 = random.split(random.key(117), 2)
# any input shape works thanks to shape inference
x = random.uniform(key1, (4, 4))
print(x)

# instantiate the model
model = ExplicitMLP(features=[3, 4, 5])
# init the model with an rng key and some test data
params = model.init(key2, x)
y = model.apply(params, x)

print(
    'initialized parameter shapes:\n',
    jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)),
)
print('output:\n', y)

# %%
# using nn.Module with @nn.compact
class SimpleMLP(nn.Module):
    features: Sequence[int]

    # the nn.compact style declares submodules inline
    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, name=f'layers_{i}')(x)
            if i != len(self.features) - 1:
                x = nn.relu(x)
            # providing a name is optional though!
            # the default autonames would be "Dense_0", "Dense_1", ...
        return x

key1, key2 = random.split(random.key(117), 2)
x = random.uniform(key1, (4, 4))

model = SimpleMLP(features=[3, 4, 5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n',
      jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
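
# %% [markdown]
# Sanity check (an illustrative addition, not part of the original script): the
# setup-style ExplicitMLP and the compact-style SimpleMLP above describe the
# same computation, so initializing both with the same key and input yields
# parameter pytrees with identical structure and shapes.

# %%
explicit_params = ExplicitMLP(features=[3, 4, 5]).init(key2, x)
compact_params = SimpleMLP(features=[3, 4, 5]).init(key2, x)
print(jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(explicit_params)))
print(jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(compact_params)))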

# %%
# creating your own layers
# as an example, this shows how to implement your own dense layer from scratch
class SimpleDense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, inputs):
        # use the nn.Module.param method to declare a trainable parameter
        # * because params are declared in __call__, shape inference is possible *
        # the 'W'
        kernel = self.param(
            'kernel',                           # name
            self.kernel_init,                   # initialization function
            (inputs.shape[-1], self.features),  # shape
        )
        # 'Wx'
        y = jnp.dot(inputs, kernel)
        # 'Wx + b'
        bias = self.param('bias', self.bias_init, (self.features,))
        y = y + bias
        return y

key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4, 4))

model = SimpleDense(features=3)
# the init_fn takes (prng key, *init args, **init kwargs)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)

# %%
# handling state in your module
# jax does not work well with side effects and hidden state:
# jax functions are supposed to be pure, returning deterministic outputs for
# given inputs, so we introduce "mutable" collections to separate the function
# from its state
# note: there are 2 collections here:
#   'params' for trainable parameters, 'batch_stats' for non-trainable state
class BiasAdderWithRunningMean(nn.Module):
    decay: float = 0.99

    @nn.compact
    def __call__(self, x):
        # easy pattern to detect whether we're initializing via an empty variable tree:
        # check for the variable 'mean' in the 'batch_stats' collection
        is_initialized = self.has_variable('batch_stats', 'mean')
        # declare a variable outside the model parameters:
        # the variable 'mean' in the 'batch_stats' collection,
        # so ra_mean refers to batch_stats -> mean
        ra_mean = self.variable(
            'batch_stats', 'mean',
            lambda s: jnp.zeros(s),
            x.shape[1:]  # shape of the data after the 0th axis, e.g. (10, 5) -> (5,)
        )
        bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
        if is_initialized:
            ra_mean.value = (
                self.decay * ra_mean.value
                + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)
            )
        return x - ra_mean.value + bias

key1, key2 = random.split(random.key(0), 2)
x = jnp.ones((10, 5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)

# passing mutable=[...] means apply also returns the updated state of those
# collections, so there is no true hidden state inside the function;
# this serves as an escape hatch for implementing stateful computations with
# pure functions: the state is kept outside the module
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)

# %%
for val in [1.0, 2.0, 3.0]:
    x = val * jnp.ones((10, 5))
    y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
    old_state, params = flax.core.pop(variables, 'params')
    variables = flax.core.freeze({'params': params, **updated_state})
    print('updated state:\n', updated_state)  # shows only the mutable part
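
# %% [markdown]
# Putting the pieces together: a minimal sketch (an illustrative addition, not
# part of the original guide) of one training step for a module that carries
# both trainable 'params' and mutable 'batch_stats', reusing the
# BiasAdderWithRunningMean model above. The regression target, the fresh SGD
# optimizer, and the names train_step / x_batch / y_batch are made up for this
# example.

# %%
model = BiasAdderWithRunningMean()
variables = model.init(random.key(0), jnp.ones((10, 5)))
# split the variables into trainable params and the remaining (mutable) state
state, params = flax.core.pop(variables, 'params')

tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

def train_step(params, state, opt_state, x, y):
    def loss_fn(params):
        # run the model, collecting the updated mutable collections as aux output
        preds, updated_state = model.apply(
            {'params': params, **state}, x, mutable=list(state.keys())
        )
        return jnp.mean((preds - y) ** 2), updated_state

    (loss, updated_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, updated_state, opt_state, loss

x_batch = jnp.ones((10, 5))
y_batch = jnp.zeros((10, 5))
params, state, opt_state, loss = train_step(params, state, opt_state, x_batch, y_batch)
print('loss after one step:', loss)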