# Source: learn_jax/equinox/handling_state_flax.py (107 lines, 3.4 KiB, Python)

# %%
# introduction to how flax does stateful operations
import flax.linen as nn
import jax.numpy as jnp
import jax
import flax
from jaxtyping import Array
# %%
class BiasAdderWithRunningMean(nn.Module):
    """Toy module that keeps a running mean of its inputs and adds a bias.

    The running mean is stored as the variable ``'mean'`` in the ``'hehe'``
    collection; the bias is a regular parameter named ``'bias'``.

    Attributes:
        momentum: weight of the previous running mean in each update; the
            new input mean is weighted by ``1 - momentum``.
    """

    momentum: float = 0.9

    @nn.compact
    def __call__(self, x):
        """Update the running mean from ``x`` and return ``mean + bias``.

        Args:
            x: input array. The mean is taken over its leading axis; the
                state and the bias both have shape ``x.shape[1:]``.

        Returns:
            The running mean plus the bias, with shape ``x.shape[1:]``.
        """
        # Query this BEFORE self.variable(...) creates the entry: it is False
        # when the variable does not exist yet (e.g. during `init`).
        is_initialized = self.has_variable('hehe', 'mean')
        print(is_initialized)
        mean = self.variable('hehe', 'mean', jnp.zeros, x.shape[1:])
        bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
        if is_initialized:
            print(mean.value)  # notice that value retains after first call
            # Reduce WITHOUT keepdims so the update has the same shape as the
            # stored state (x.shape[1:]). With keepdims=True the state would
            # silently change to (1, *x.shape[1:]) after the first update.
            batch_mean = jnp.mean(x, axis=0)
            mean.value = self.momentum * mean.value + (1.0 - self.momentum) * batch_mean
            print(mean.value)
        return mean.value + bias
# %%
# JAX PRNG keys must not be reused: derive one key for the example data and
# an independent one for module initialisation from a single root key.
root_key = jax.random.PRNGKey(0)
input_key, init_key = jax.random.split(root_key)
model = BiasAdderWithRunningMean()
inputs = jax.random.normal(input_key, (10, 5))  # Generate random normal values
variables = model.init(init_key, inputs)
# Split state and params (which are updated by optimizer).
state, params = flax.core.pop(variables, 'params')
print(f"first init: {state}")
# %%
for i in range(5):
    # A different key per iteration gives a different batch each time.
    step_key = jax.random.PRNGKey(i + 1)
    new_inputs = jax.random.normal(step_key, (10, 5))
    # The state is threaded by hand: the current state is merged with the
    # params into one variables dict, and the state returned by apply
    # replaces it for the next iteration.
    apply_variables = {'params': params, **state}
    # Only the state collections (everything except 'params') are writable.
    writable_collections = list(state)
    output, state = model.apply(
        apply_variables, new_inputs, mutable=writable_collections
    )
    print(f"updated state {state}")
    print(f"Output after input {i + 1}: {output}")
# %%
###########################################################
# example 2
from flax.linen.initializers import lecun_normal, variance_scaling, zeros, normal
import jax.random as random
class Foo(nn.Module):
    """Module whose only content is a randomly initialised state variable.

    The variable ``'u0'`` lives in the ``'spectral_norm_stats'`` collection
    and has shape ``(1, features)``.

    Attributes:
        features: size of the trailing axis of ``u0``.
    """

    features: int

    @nn.compact
    def __call__(self):
        """Return the current value of ``u0``.

        A key is drawn from the ``'spectral_norm_stats'`` RNG stream on every
        call, which is why the callers in this file pass ``rngs`` to ``apply``
        as well as to ``init``.
        """
        rng = self.make_rng('spectral_norm_stats')
        print(rng)
        init_fn = normal()
        shape = (1, self.features)
        # init_fn(rng, shape) supplies the initial value of 'u0'.
        u0 = self.variable('spectral_norm_stats', 'u0', init_fn, rng, shape)
        return u0.value
# One PRNG key per named RNG stream handed to init.
init_rngs = {
    'params': random.PRNGKey(0),
    'spectral_norm_stats': random.PRNGKey(1),
}
foovars = Foo(3).init(init_rngs)
# Apply a fresh Foo(3) instance to the initialised variables.
Foo(3).apply(foovars, rngs={'spectral_norm_stats': random.PRNGKey(1)})
# --> DeviceArray([[0.00711277, 0.0107195 , 0.019903 ]], dtype=float32)
# %%
model = Foo(3)
# Arguments shared by every apply call below: the RNG stream that
# Foo.__call__ draws from, and the collections apply may write back.
stats_rngs = {'spectral_norm_stats': random.PRNGKey(1)}
mutable_collections = list(foovars)
# %%
# The state is not stored on `model`; it lives in the variables (`foovars`)
# and is reached through self.variable inside the layer.
output = model.apply(foovars, rngs=stats_rngs)
# %%
# With `mutable` given, apply returns (output, updated collections).
output, state = model.apply(
    foovars, rngs=stats_rngs, mutable=mutable_collections
)
print(output, state)
# %%
# Feed the returned state straight back in.
output, state = model.apply(
    state, rngs=stats_rngs, mutable=mutable_collections
)
# no change because input state is the same
print(output, state)
# %%
# Overwrite u0[0, 0] with 0.9 by hand so the next apply shows that the module
# reads whatever state it is handed.
state_array = state['spectral_norm_stats']['u0']
# JAX arrays are immutable; .at[...].set(...) returns an updated copy.
modified_array = state_array.at[0, 0].set(0.9)
# Rebuild the nested mapping instead of assigning into it: item assignment
# raises if apply returned an immutable FrozenDict (depends on the Flax
# version/config), whereas unpacking works for dict and FrozenDict alike.
state = {
    **state,
    'spectral_norm_stats': {
        **state['spectral_norm_stats'],
        'u0': modified_array,
    },
}
# %%
# %%
# Same call as before, now fed the hand-edited state.
apply_kwargs = dict(
    rngs={'spectral_norm_stats': random.PRNGKey(1)},
    mutable=list(foovars),
)
output, state = model.apply(state, **apply_kwargs)
# The output is read from the state that was passed in:
# - the modified 0.9 value shows up in it
# - the state is not re-initialized, even though an RNG key is supplied
print(output, state)
# %%