# %%
# Introduction to how Flax does stateful operations with mutable
# variable collections.

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from jaxtyping import Array

# %%

class BiasAdderWithRunningMean(nn.Module):
    momentum: float = 0.9

    @nn.compact
    def __call__(self, x):
        # An empty variable collection at init time is an easy way to detect
        # whether we are initializing. The collection name ('hehe') is
        # arbitrary; 'batch_stats' is the conventional choice.
        is_initialized = self.has_variable('hehe', 'mean')
        print(is_initialized)
        mean = self.variable('hehe', 'mean', jnp.zeros, x.shape[1:])
        bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
        if is_initialized:
            print(mean.value)  # notice that the value is retained after the first call
            # Exponential moving average of the per-feature batch mean.
            mean.value = self.momentum * mean.value + (1.0 - self.momentum) * jnp.mean(
                x, axis=0, keepdims=True
            )
        print(mean.value)
        return mean.value + bias

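
# %%
# Sanity check (not in the original): after `init`, the variable dict holds
# both the custom 'hehe' collection and the 'params' collection. The shapes
# below assume a (10, 5) input batch.
demo_vars = BiasAdderWithRunningMean().init(jax.random.PRNGKey(0), jnp.ones((10, 5)))
print(jax.tree_util.tree_map(jnp.shape, demo_vars))
# expected: {'hehe': {'mean': (5,)}, 'params': {'bias': (5,)}}
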
# %%

input_key = jax.random.PRNGKey(0)
model = BiasAdderWithRunningMean()
inputs = jax.random.normal(input_key, (10, 5))  # generate random normal inputs
variables = model.init(input_key, inputs)
# Split state from params (which are updated by the optimizer).
# flax.core.pop returns (remaining_variables, popped_collection).
state, params = flax.core.pop(variables, 'params')
print(f"first init: {state}")

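
# %%
# Aside (not in the original): if the module writes to a collection that is
# not listed in `mutable`, apply raises an error. A minimal sketch, assuming
# the already-initialized state above (so the running-mean update runs):
try:
    model.apply({'params': params, **state}, inputs)  # forgot mutable=['hehe']
except flax.errors.FlaxError as e:
    print(type(e).__name__)
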
# %%

for i in range(5):
    new_inputs = jax.random.normal(jax.random.PRNGKey(i + 1), (10, 5))  # new random inputs
    # Notice how we are threading the state: the updated state returned by
    # `apply` is fed back in on the next iteration. `**state` unpacks the
    # mutable collections into the variable dict next to 'params'.
    output, state = model.apply({'params': params, **state},
                                new_inputs, mutable=list(state.keys()))

    mean_state = state['hehe']['mean']  # access the updated running mean
    print(f"updated state {state}")
    print(f"Output after input {i + 1}: {output}")
    print(f"Updated running mean state: {mean_state}")

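
# %%
# Sketch (not in the original): in real training code this threaded apply is
# typically wrapped in jax.jit; `mutable` is baked in as a static argument.
@jax.jit
def stateful_step(params, state, x):
    return model.apply({'params': params, **state}, x, mutable=['hehe'])

output, state = stateful_step(params, state,
                              jax.random.normal(jax.random.PRNGKey(99), (10, 5)))
print(f"state after jitted step: {state}")
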
# %%

###########################################################
# Example 2: seeding a variable's init_fn from an RNG stream via make_rng.
from flax.linen.initializers import normal
import jax.random as random

class Foo(nn.Module):
    features: int

    @nn.compact
    def __call__(self):
        # Draw a key from the 'spectral_norm_stats' RNG stream. The init_fn
        # passed to self.variable below only runs when the variable is
        # first created.
        key = self.make_rng('spectral_norm_stats')
        print(key)
        u0_variable = self.variable('spectral_norm_stats', 'u0', normal(), key, (1, self.features))
        return u0_variable.value

foovars = Foo(3).init({'params': random.PRNGKey(0), 'spectral_norm_stats': random.PRNGKey(1)})
Foo(3).apply(foovars, rngs={'spectral_norm_stats': random.PRNGKey(1)})
# --> DeviceArray([[0.00711277, 0.0107195 , 0.019903  ]], dtype=float32)

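
# %%
# Check (not in the original): the stored 'u0' depends on the
# 'spectral_norm_stats' key supplied at init time, since that stream feeds
# the variable's init_fn; a different seed yields different values.
other_vars = Foo(3).init({'params': random.PRNGKey(0), 'spectral_norm_stats': random.PRNGKey(2)})
print(other_vars['spectral_norm_stats']['u0'])
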
# %%

model = Foo(3)

# %%

# State is kept in self.variable, tied to the layer. Without `mutable`,
# apply reads the stored value and returns only the output.
output = model.apply(foovars, rngs={'spectral_norm_stats': random.PRNGKey(1)})

# %%

# With `mutable=...`, apply returns (output, updated_state).
output, state = model.apply(
    foovars,
    mutable=list(foovars.keys()),
    rngs={'spectral_norm_stats': random.PRNGKey(1)}
)
print(output, state)

# %%

output, state = model.apply(
    state,
    mutable=list(foovars.keys()),
    rngs={'spectral_norm_stats': random.PRNGKey(1)}
)
# No change because the input state is the same: 'u0' already exists, so its
# init_fn is not re-run.
print(output, state)

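
# %%
# Check (not in the original): once 'u0' exists, the rng passed at apply time
# no longer affects the value, because the init_fn only runs at creation.
other_out, _ = model.apply(
    state,
    mutable=list(foovars.keys()),
    rngs={'spectral_norm_stats': random.PRNGKey(42)}
)
print(other_out)  # same values as `output` above
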
# %%

# Manually overwrite one entry of the stored state. Recent Flax versions
# return plain (mutable) dicts from apply, so item assignment works; with a
# FrozenDict you would need flax.core.copy / unfreeze instead.
state_array = state['spectral_norm_stats']['u0']
modified_array = jax.lax.dynamic_update_slice(state_array, jnp.array([[0.9]]), (0, 0))
state['spectral_norm_stats']['u0'] = modified_array

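
# %%
# Equivalent sketch (not in the original): the .at[].set() functional-update
# syntax is the more common way to write the dynamic_update_slice above.
alt_array = state_array.at[0, 0].set(0.9)
assert jnp.allclose(alt_array, modified_array)
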
# %%

output, state = model.apply(
    state,
    mutable=list(foovars.keys()),
    rngs={'spectral_norm_stats': random.PRNGKey(1)}
)
# The returned state is carried over from the state we passed in: note the
# modified 0.9 value, and note that the variable is not re-initialized.
print(output, state)

# %%