# learn_jax/equinox/handling_state_equinox.py (55 lines, 1.8 KiB, Python)

# an example of stateful operations
# %%
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax # https://github.com/deepmind/optax
from equinox.nn import State, StateIndex, StatefulLayer
from jaxtyping import Array
# %%
class Counter(eqx.Module):
    """Count how many times this module has been called.

    The running count is not kept as an ordinary field. It lives in a
    separate ``eqx.nn.State`` object that every call receives and hands
    back, updated. Build the module together with that state via
    ``eqx.nn.make_with_state(Counter)()``.
    """

    # Bundles (a) a unique key used to look up the stateful value inside an
    # ``eqx.nn.State`` and (b) the value that entry is initialised with.
    index: eqx.nn.StateIndex

    def __init__(self):
        # The counter's stateful value starts out as the scalar array 0.
        self.index = eqx.nn.StateIndex(jnp.array(0))

    def __call__(
        self, x: Array, state: eqx.nn.State
    ) -> tuple[Array, eqx.nn.State]:
        """Offset ``x`` by the current count, then advance the count by one.

        ``eqx.nn.State`` behaves essentially like a dictionary from
        ``eqx.nn.StateIndex`` objects to PyTrees of arrays. It accepts only
        ``StateIndex`` keys, works around a JAX issue with flattening dicts
        keyed by ``object()``, and checks that the most up-to-date version
        of the state is the one being used.

        Args:
            x: Array to which the current count is added.
            state: State object holding this module's counter.

        Returns:
            A pair ``(x + count, next_state)``. ``count`` is the value stored
            before this call, and ``next_state`` stores ``count + 1``. Pass
            ``next_state`` to the following call.
        """
        count = state.get(self.index)
        # ``State.set`` hands back the updated state, so it must be returned
        # to the caller rather than discarded.
        next_state = state.set(self.index, count + 1)
        return x + count, next_state
# ``eqx.nn.make_with_state`` is the recommended way to start a stateful model:
# it gives back the module and the State object that accompanies it.
counter, state = eqx.nn.make_with_state(Counter)()
x = jnp.array(2.3)

# Before any call, the counter still holds its initial value.
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 0

# Each call returns an updated State; rebinding ``state`` carries the count
# forward to the next call.
for _ in range(2):
    _, state = counter(x, state)
    num_calls = state.get(counter.index)
    print(f"Called {num_calls} times.")  # 1, then 2
# %%