# %%
# Introduction to how Flax handles stateful operations.
import flax.linen as nn
import jax.numpy as jnp
import jax
import flax
from jaxtyping import Array


# %%
class BiasAdderWithRunningMean(nn.Module):
    """Toy module that keeps a running mean of its inputs and adds a bias.

    The running mean is stored as the variable ``'mean'`` in the (non-param)
    collection ``'hehe'``; the bias is a regular parameter initialised to
    zeros. Both have shape ``x.shape[1:]``.
    """

    # Weight given to the previous running mean on each update.
    momentum: float = 0.9

    @nn.compact
    def __call__(self, x: Array) -> Array:
        """Update the running mean from ``x`` and return ``mean + bias``.

        Args:
            x: Input batch; the mean is taken over axis 0.

        Returns:
            The current running mean plus the bias, with shape ``x.shape[1:]``.
        """
        # Must be queried *before* `self.variable` creates the entry: it is
        # False when no 'mean' variable was supplied (as in the `init` call
        # below) and True when existing state is passed to `apply`.
        is_initialized = self.has_variable('hehe', 'mean')
        print(is_initialized)

        mean = self.variable('hehe', 'mean', jnp.zeros, x.shape[1:])
        bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])

        if is_initialized:
            print(mean.value)
            # Notice that the value is retained after the first call.
            # Reduce without `keepdims` so the stored state keeps the shape it
            # was created with instead of silently growing a leading axis.
            batch_mean = jnp.mean(x, axis=0)
            mean.value = self.momentum * mean.value + (1.0 - self.momentum) * batch_mean
            print(mean.value)

        return mean.value + bias


# %%
# Use separate keys for drawing the dummy inputs and initialising the model
# rather than reusing a single key for both.
input_key, init_key = jax.random.split(jax.random.PRNGKey(0))
model = BiasAdderWithRunningMean()
inputs = jax.random.normal(input_key, (10, 5))  # Generate random normal values
variables = model.init(init_key, inputs)
# Split state and params (which are updated by the optimizer).
state, params = flax.core.pop(variables, 'params')
print(f"first init: {state}")

# %%
for i in range(5):
    new_inputs = jax.random.normal(jax.random.PRNGKey(i + 1), (10, 5))  # New random inputs
    # Notice how we are threading the state: the collections returned by one
    # call are unpacked into the variables passed to the next call.
    output, state = model.apply(
        {'params': params, **state}, new_inputs, mutable=list(state)
    )
    print(f"updated state {state}")
    print(f"Output after input {i + 1}: {output}")

# %%
###########################################################
# example 2
from flax.linen.initializers import lecun_normal, variance_scaling, zeros, normal
import jax.random as random


class Foo(nn.Module):
    """Module whose only state is a randomly initialised ``u0`` variable.

    ``u0`` lives in the ``'spectral_norm_stats'`` collection and has shape
    ``(1, features)``.
    """

    features: int

    @nn.compact
    def __call__(self) -> Array:
        """Return the value of ``u0``, creating it on first use."""
        # An rng is requested on every call, so callers must always pass a
        # 'spectral_norm_stats' rng, even when the variable already exists.
        key = self.make_rng('spectral_norm_stats')
        print(key)
        u0_variable = self.variable(
            'spectral_norm_stats', 'u0', normal(), key, (1, self.features)
        )
        return u0_variable.value


foovars = Foo(3).init(
    {'params': random.PRNGKey(0), 'spectral_norm_stats': random.PRNGKey(1)}
)
Foo(3).apply(foovars, rngs={'spectral_norm_stats': random.PRNGKey(1)})
# --> DeviceArray([[0.00711277, 0.0107195 , 0.019903  ]], dtype=float32)
# (example output; the exact repr depends on the installed JAX version)

# %%
model = Foo(3)

# %%
# State is kept in self.variable, tied to the layer.
output = model.apply(foovars, rngs={'spectral_norm_stats': random.PRNGKey(1)})

# %%
output, state = model.apply(
    foovars,
    mutable=list(foovars),
    rngs={'spectral_norm_stats': random.PRNGKey(1)},
)
print(output, state)

# %%
output, state = model.apply(
    state,
    mutable=list(foovars),
    rngs={'spectral_norm_stats': random.PRNGKey(1)},
)
# No change because the input state is the same.
print(output, state)

# %%
# `apply` may return an immutable FrozenDict depending on the Flax
# version/configuration, so take a mutable copy before editing an entry.
state = flax.core.unfreeze(state)
state_array = state['spectral_norm_stats']['u0']
modified_array = jax.lax.dynamic_update_slice(state_array, jnp.array([[0.9]]), (0, 0))
state['spectral_norm_stats']['u0'] = modified_array

# %%
output, state = model.apply(
    state,
    mutable=list(foovars),
    rngs={'spectral_norm_stats': random.PRNGKey(1)},
)
# The module reads from the state it is given:
# note the modified 0.9 value, and that the state is not re-initialized.
print(output, state)
# %%