# learn_jax/equinox/mnist.py
#
# 253 lines
# 7.4 KiB
# Python
# Raw Permalink Normal View History

# %%
import equinox as eqx
import jax
import jax.numpy as jnp
import optax # https://github.com/deepmind/optax
import torch # https://pytorch.org
import torchvision # https://pytorch.org
from jaxtyping import Array, Float, Int, PyTree # https://github.com/google/jaxtyping
# %%
# Hyperparameters: every training knob lives here so the rest of the script
# stays free of magic numbers.
SEED = 5678
# Root PRNG key; later random draws are split off from this one.
key = jax.random.PRNGKey(SEED)
BATCH_SIZE = 64  # minibatch size shared by the train and test dataloaders
LEARNING_RATE = 3e-4  # step size handed to optax.adamw
STEPS = 300  # number of optimiser steps performed by `train`
PRINT_EVERY = 30  # evaluate/print on the test set every this many steps
# %%
# ToTensor turns each image into a float tensor in [0, 1]; Normalize then maps
# it to [-1, 1] via (x - 0.5) / 0.5.
normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)
# Train split first, then test split. Both live under the "MNIST" root,
# are downloaded if missing and share the same normalisation.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(
        "MNIST",
        train=is_train,
        download=True,
        transform=normalise_data,
    )
    for is_train in (True, False)
)
# Identical loader settings for both splits: BATCH_SIZE samples, shuffled.
trainloader, testloader = (
    torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    for dataset in (train_dataset, test_dataset)
)
# %%
# Peek at one training batch to sanity-check shapes and labels
# (by now, everyone knows what the MNIST dataset looks like).
dummy_x, dummy_y = (tensor.numpy() for tensor in next(iter(trainloader)))
# With BATCH_SIZE = 64 this prints (64, 1, 28, 28), then (64,), then the labels.
for shown in (dummy_x.shape, dummy_y.shape, dummy_y):
    print(shown)
# %%
class CNN(eqx.Module):
    """Small convolutional classifier for single MNIST images.

    Standard CNN setup: one convolutional layer, followed by flattening, with
    a small MLP on top. The last layer is `jax.nn.log_softmax`, so calling the
    model on one (1, 28, 28) image yields 10 log-probabilities.
    """

    # Applied in order by `__call__`; mixes Equinox layers with plain JAX
    # functions (activations, flattening).
    layers: list

    def __init__(self, key):
        # One PRNG key for each layer that owns weights.
        keys = jax.random.split(key, 4)
        self.layers = [
            # (1, 28, 28) -> (3, 25, 25): 3 output channels, 4x4 kernel.
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=keys[0]),
            # (3, 25, 25) -> (3, 24, 24): 2x2 window, stride 1, which is what
            # makes the flattened size 3 * 24 * 24 = 1728 below.
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,  # plain JAX functions work as layers too
            jnp.ravel,  # flatten to a 1728-vector
            eqx.nn.Linear(1728, 512, key=keys[1]),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=keys[2]),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=keys[3]),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        """Push one image through every layer and return its log-probabilities."""
        out = x
        for fn in self.layers:
            out = fn(out)
        return out
# Split the root key: `key` is carried forward, `subkey` initialises the model.
key, subkey = jax.random.split(key)
model = CNN(subkey)
# %%
print(model)
# %%
# The first layer of the model is the Conv2d.
print(model.layers[0])
# %%
# illustrated inference
def loss(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """Mean cross-entropy of `model` on a batch of images `x` with labels `y`.

    `x` has shape (BATCH_SIZE, 1, 28, 28), but `model` operates on a single
    image of shape (1, 28, 28). `jax.vmap` maps the model over the leading
    (batch) axis: write the function for one input, and let JAX vectorise it
    over the batch dimension.
    """
    batched_model = jax.vmap(model)
    return cross_entropy(y, batched_model(x))
def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    """Negative mean log-probability assigned to the true class.

    Args:
        y: true integer targets in 0-9, one per row of `pred_y`.
        pred_y: log-softmax'd predictions, one row per example.
    """
    # Adding axis 1 makes the labels a (batch, 1) column, so take_along_axis
    # along axis 1 picks pred_y[i, y[i]] from each row i.
    label_column = jnp.expand_dims(y, axis=1)
    true_class_log_probs = jnp.take_along_axis(pred_y, label_column, axis=1)
    return -true_class_log_probs.mean()
# Example loss: one scalar for the whole dummy batch.
loss_value = loss(model, dummy_x, dummy_y)
print(loss_value)
print(loss_value.shape)  # () -- a scalar loss
# Example inference: one row of predictions per image in the batch.
batched_model = jax.vmap(model)
output = batched_model(dummy_x)
print(output.shape)
# %%
# This is an error! jax.value_and_grad needs every leaf of the argument it
# differentiates to be a JAX-compatible array. `model`, however, also holds
# non-parameters, e.g. the plain functions (jax.nn.relu, jnp.ravel, ...)
# stored in `model.layers`.
# The error is caught and printed, so running this file as a script does not
# stop here.
try:
    jax.value_and_grad(loss)(model, dummy_x, dummy_y)
except TypeError as err:
    print(f"jax.value_and_grad failed as expected: {err}")
# %%
# Split the model in two: `params` keeps the array leaves (the trainable
# parameters), `static` keeps everything else. eqx.partition amounts to
# filter(...) together with filter(..., inverse=True).
params, static = eqx.partition(model, eqx.is_array)
# %%
# Compare the same layer as it appears in each half.
for half in (static, params):
    print(half.layers[0])
# %%
# Inside the loss, the two halves are recombined into the full model.
def loss2(params, static, x, y):
    """Rebuild the model from `params` and `static`, then evaluate `loss`."""
    return loss(eqx.combine(params, static), x, y)


# Now this works: jax.value_and_grad differentiates only with respect to its
# first argument, `params`, which contains arrays only.
loss_and_grad = jax.value_and_grad(loss2)
loss_value, grads = loss_and_grad(params, static, dummy_x, dummy_y)
print(loss_value)
# %%
# This works too, and gives the same result as the partition/combine version
# above: eqx.filter_value_and_grad takes the whole model directly.
filtered_loss_and_grad = eqx.filter_value_and_grad(loss)
value, grads = filtered_loss_and_grad(model, dummy_x, dummy_y)
print(value)
# %%
# evaluation
# JIT-compile the loss from earlier; from here on `loss` names the
# compiled version.
loss = eqx.filter_jit(loss)


@eqx.filter_jit
def compute_accuracy(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """Return the fraction of the batch whose arg-max prediction equals `y`."""
    predicted_class = jax.vmap(model)(x).argmax(axis=1)
    return jnp.mean(y == predicted_class)
# %%
def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
    """Evaluate `model` on every batch of `testloader`.

    Each batch's mean loss/accuracy is weighted by the number of samples in
    that batch. A smaller final batch (when the dataset size is not a
    multiple of the batch size) therefore does not count as much as a full
    one, and the result is the true per-sample average.

    Returns:
        (avg_loss, avg_acc): mean loss and mean accuracy over all samples
        yielded by `testloader`, as JAX scalars.

    Raises:
        ValueError: if `testloader` yields no samples.
    """
    total_loss = 0
    total_acc = 0
    num_samples = 0
    for x, y in testloader:
        # The dataloader yields torch tensors; hand NumPy arrays to JAX.
        x = x.numpy()
        y = y.numpy()
        batch_size = y.shape[0]
        # All the JAX work happens inside `loss` and `compute_accuracy`, both
        # JIT-wrapped. A differently sized final batch presumably costs one
        # extra compilation.
        total_loss += loss(model, x, y) * batch_size
        total_acc += compute_accuracy(model, x, y) * batch_size
        num_samples += batch_size
    if num_samples == 0:
        raise ValueError("testloader yielded no samples to evaluate on")
    return total_loss / num_samples, total_acc / num_samples
# %%
# Baseline performance of the untrained model. The result is printed so it is
# visible when this file runs as a script, not only as a notebook cell.
test_loss, test_accuracy = evaluate(model, testloader)
print(f"untrained: test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}")
# %%
# training
optim = optax.adamw(LEARNING_RATE)


def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    """Run `steps` optimiser steps on `model` and return the trained model.

    Batches come from `trainloader`, which is cycled as often as needed.
    Every `print_every` steps, and on the final step, the model is evaluated
    on `testloader`. The train loss, test loss and test accuracy are then
    printed.
    """
    # Only the arrays in the model are trainable, so filter out everything
    # else before building the optimiser state.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Computing gradients, running the optimiser and updating the model all
    # live in one JIT region so the whole step is compiled together.
    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        # Hand the optimiser the same array-only view of the model that
        # `optim.init` saw above. AdamW reads `params` for its weight-decay
        # term, and the non-array leaves (plain functions) are not parameters.
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    def infinite_trainloader():
        # Loop over the training dataset as many times as needed.
        while True:
            yield from trainloader

    # zip stops as soon as range(steps) is exhausted.
    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders yield PyTorch tensors; convert to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            test_loss, test_accuracy = evaluate(model, testloader)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    return model
# %%
# Run the full training loop and keep the trained model.
model = train(
    model,
    trainloader,
    testloader,
    optim,
    steps=STEPS,
    print_every=PRINT_EVERY,
)
# %%