"""An example of stateful operations in Equinox.

``Counter`` is a module that adds to its input the number of times it has
previously been called. The call count is not stored on the module itself:
it lives in an ``eqx.nn.State`` object that the caller passes in and receives
back (updated) from every call.
"""

# %%
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax  # https://github.com/deepmind/optax
from equinox.nn import State, StateIndex, StatefulLayer
from jaxtyping import Array


# %%
class Counter(eqx.Module):
    """Add the number of previous calls to the input and bump the count.

    The count is read from and written to an externally supplied
    ``eqx.nn.State`` via ``self.index``; it starts at ``jnp.array(0)``.
    """

    # Wraps together (a) a unique key used for looking up a stateful value,
    # and (b) how that stateful value should be initialised.
    index: eqx.nn.StateIndex

    def __init__(self):
        init_state = jnp.array(0)
        self.index = eqx.nn.StateIndex(init_state)

    # eqx.nn.State stores the state of the model. It is essentially a
    # dictionary mapping `eqx.nn.StateIndex` objects to PyTrees of arrays,
    # and should be created via `eqx.nn.make_with_state`.
    #
    # Compared with a plain dict it (a) works only with StateIndex keys,
    # (b) works around a JAX bug that prevents flattening dicts with
    # `object()` keys, and (c) checks that you are using the most
    # up-to-date version of it.
    def __call__(
        self, x: Array, state: eqx.nn.State
    ) -> tuple[Array, eqx.nn.State]:
        """Return ``x`` plus the current count, and the state with count + 1.

        Args:
            x: Input array; the current count is added to it.
            state: State object holding the value for ``self.index``.

        Returns:
            A ``(new_x, new_state)`` pair, where ``new_state`` holds the
            incremented count. Callers should carry ``new_state`` forward.
        """
        value = state.get(self.index)
        new_x = x + value
        # `State.set` stores a new value for `self.index` and returns the
        # updated state object.
        new_state = state.set(self.index, value + 1)
        return new_x, new_state


# `make_with_state` is the recommended way to start a stateful model: it
# builds the model and returns it together with its initial state.
counter, state = eqx.nn.make_with_state(Counter)()
x = jnp.array(2.3)

num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 0

_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 1

_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 2

# %%