55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
# An example of stateful operations in Equinox, using eqx.nn.State and eqx.nn.StateIndex.
|
|
# %%
|
|
import equinox as eqx
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import jax.random as jr
|
|
import optax # https://github.com/deepmind/optax
|
|
from equinox.nn import State, StateIndex, StatefulLayer
|
|
from jaxtyping import Array
|
|
|
|
|
|
# %%
|
|
class Counter(eqx.Module):
    """A toy stateful layer that counts how many times it has been called.

    Each call adds the current count to the input and then stores the
    incremented count back into the external state object.
    """

    # A StateIndex bundles together (a) a unique key used for looking up a
    # stateful value and (b) the value that entry should be initialised with.
    index: StateIndex

    def __init__(self):
        # The count starts at zero; this scalar array is the initial value
        # associated with the index.
        self.index = StateIndex(jnp.array(0))

    def __call__(self, x: Array, state: State) -> tuple[Array, State]:
        """Add the current count to ``x`` and increment the stored count.

        ``state`` (an ``eqx.nn.State``) is essentially a dictionary mapping
        ``StateIndex`` objects to PyTrees of arrays. It is meant to be created
        via ``eqx.nn.make_with_state``. Beyond a plain dict, it (a) only accepts
        ``StateIndex`` keys, (b) works around a JAX bug that prevents
        flattening dicts with ``object()`` keys, and (c) checks that you are
        using the most up-to-date version of it.

        Args:
            x: value to which the current count is added.
            state: the model state holding this counter's value.

        Returns:
            A pair ``(x + count, new_state)``, where ``new_state`` holds
            ``count + 1`` under ``self.index``.
        """
        count = state.get(self.index)
        # ``State.set`` writes the new value for the index and returns the
        # updated state; callers must carry that returned state forward.
        return x + count, state.set(self.index, count + 1)
|
|
|
|
# `eqx.nn.make_with_state` is the recommended way to start a stateful model:
# it returns the model together with its initial state.
counter, state = eqx.nn.make_with_state(Counter)()
x = jnp.array(2.3)

num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 0

# Call the counter twice, always threading the returned state into the next
# call, and report the stored count after each call.
for _call_number in range(2):
    _, state = counter(x, state)
    num_calls = state.get(counter.index)
    print(f"Called {num_calls} times.")  # 1, then 2
|
|
|
|
|
|
# %%
|