253 lines
7.4 KiB
Python
253 lines
7.4 KiB
Python
|
# %%
|
||
|
import equinox as eqx
|
||
|
import jax
|
||
|
import jax.numpy as jnp
|
||
|
import optax # https://github.com/deepmind/optax
|
||
|
import torch # https://pytorch.org
|
||
|
import torchvision # https://pytorch.org
|
||
|
from jaxtyping import Array, Float, Int, PyTree # https://github.com/google/jaxtyping
|
||
|
# %%
|
||
|
# Hyperparameters

BATCH_SIZE = 64  # mini-batch size used by both DataLoaders below
LEARNING_RATE = 3e-4  # learning rate handed to the optax optimiser
STEPS = 300  # number of optimisation steps passed to train()
PRINT_EVERY = 30  # evaluation/printing interval (in steps) passed to train()
SEED = 5678  # seed for the root JAX PRNG key

# Root PRNG key; it is split before being handed to the model constructor.
key = jax.random.PRNGKey(SEED)
|
||
|
|
||
|
# %%
|
||
|
# Preprocessing applied to every image: convert it to a tensor, then normalise
# the single channel with mean 0.5 and standard deviation 0.5.
normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)
# Training split of MNIST, stored under (and downloaded to) the "MNIST" root.
train_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=True,
    download=True,
    transform=normalise_data,
)
# Test split of MNIST, with the same preprocessing as the training split.
test_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=False,
    download=True,
    transform=normalise_data,
)
# Shuffled mini-batch iterators over both splits.
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True
)
|
||
|
|
||
|
# %%
|
||
|
# Pull one batch from the training loader and convert both tensors to NumPy
# so we can look at shapes and labels (MNIST itself needs no introduction).
dummy_x, dummy_y = (tensor.numpy() for tensor in next(iter(trainloader)))
print(dummy_x.shape)  # 64x1x28x28
print(dummy_y.shape)  # 64
print(dummy_y)
|
||
|
|
||
|
# %%
|
||
|
class CNN(eqx.Module):
    """Small convolutional classifier for a single 1x28x28 image.

    The network is a flat list of callables applied one after another:
    a convolution, max-pooling and ReLU, a flattening step, and a three-layer
    MLP whose final output goes through ``log_softmax``. The model works on
    one unbatched image; use ``jax.vmap`` to apply it to a batch.
    """

    # Callables applied in order by __call__; a mix of Equinox layers (which
    # hold arrays) and plain JAX functions (which hold no parameters).
    layers: list

    def __init__(self, key):
        """Build the layers, giving every parameterised layer its own PRNG key."""
        conv_key, hidden1_key, hidden2_key, output_key = jax.random.split(key, 4)
        # Convolutional front end, ending with a flatten so the MLP sees a
        # vector.
        feature_extractor = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=conv_key),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,  # plain JAX functions can be used as layers
            jnp.ravel,
        ]
        # MLP head: 1728 flattened features -> 512 -> 64 -> 10 class scores,
        # returned as log-probabilities.
        classifier = [
            eqx.nn.Linear(1728, 512, key=hidden1_key),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=hidden2_key),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=output_key),
            jax.nn.log_softmax,
        ]
        self.layers = [*feature_extractor, *classifier]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        """Run one image through every layer and return the 10 log-probabilities."""
        activation = x
        for apply_layer in self.layers:
            activation = apply_layer(activation)
        return activation
|
||
|
|
||
|
|
||
|
# Split the root key: `subkey` initialises the model, `key` stays available
# for later use.
key, subkey = jax.random.split(key, 2)
model = CNN(subkey)
# %%
# print the whole model
print(model)
# %%
# print the first layer: Conv2d
print(model.layers[0])
|
||
|
# %%
|
||
|
# illustrated inference
|
||
|
def loss(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """Return the scalar cross-entropy loss of ``model`` on one batch.

    ``x`` has shape (batch, 1, 28, 28), but the model operates on a single
    image of shape (1, 28, 28). ``jax.vmap`` bridges the gap: the model is
    written for one input and JAX vectorises it over the leading (batch) axis.
    """
    batched_model = jax.vmap(model)
    return cross_entropy(y, batched_model(x))
|
||
|
|
||
|
|
||
|
def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    """Return the mean negative log-likelihood of the true classes.

    ``y`` holds the true targets (integers 0-9) and ``pred_y`` holds the
    log-softmax'd predictions, one row per sample.
    """
    # Turn the targets into a column of shape (batch, 1) so that, when
    # gathering along axis 1, each target picks exactly one entry out of the
    # corresponding row of `pred_y`.
    target_column = jnp.asarray(y)[:, None]
    true_class_log_probs = jnp.take_along_axis(pred_y, target_column, axis=1)
    # Negative mean of the selected log-probabilities.
    return -jnp.mean(true_class_log_probs)
|
||
|
|
||
|
|
||
|
# Example loss: evaluate the loss on the dummy batch.
loss_value = loss(model, dummy_x, dummy_y)
print(loss_value)
print(loss_value.shape)  # scalar loss
# Example inference: vmap the single-image model over the dummy batch.
output = jax.vmap(model)(dummy_x)
print(output.shape)  # batch of predictions
|
||
|
|
||
|
# %%
|
||
|
# This is an error!
# jax.value_and_grad differentiates with respect to its first argument and
# requires every leaf of that argument to be an array. `model` also contains
# non-array leaves (for example the activation functions in `model.layers`),
# so the call is rejected. The expected error is caught and printed so that
# the demonstration does not abort a top-to-bottom run of this file.
try:
    jax.value_and_grad(loss)(model, dummy_x, dummy_y)
except TypeError as err:
    print(f"jax.value_and_grad(loss) failed as expected: {err}")
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# Separate the parameters from everything else. Parameters are the array
# leaves of the model, so we partition with `eqx.is_array`.
# partition is doing filter(...) and filter(..., inverse=True)
params, static = eqx.partition(model, eqx.is_array)

# %%
# Compare the same object (the first layer) in both halves of the partition.
print(static.layers[0])
print(params.layers[0])
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# in the loss, we recombine both to form back our model
|
||
|
def loss2(params, static, x, y):
    """Compute ``loss`` for a model given as separate ``params``/``static`` halves.

    The two halves are recombined into a full model first, so a caller can
    differentiate with respect to ``params`` (the first argument) alone.
    """
    return loss(eqx.combine(params, static), x, y)
|
||
|
|
||
|
|
||
|
# Now this will work!
# jax.value_and_grad only differentiates with respect to the first argument
# (`params`), so the non-array `static` half is never differentiated.
loss_value, grads = jax.value_and_grad(loss2)(params, static, dummy_x, dummy_y)
print(loss_value)

# %%
# This will work too!
# eqx.filter_value_and_grad can be applied to `loss` and the full `model`
# directly, without partitioning by hand; it works the same as the previous
# cell.
value, grads = eqx.filter_value_and_grad(loss)(model, dummy_x, dummy_y)
print(value)
|
||
|
|
||
|
# %%
|
||
|
# evaluation
# Rebind the module-level name `loss` to its filter-JIT-wrapped version, so
# every later lookup of `loss` uses the wrapped function.
loss = eqx.filter_jit(loss)  # JIT our loss function from earlier!
|
||
|
|
||
|
|
||
|
@eqx.filter_jit
def compute_accuracy(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """Return the fraction of the batch that ``model`` classifies correctly.

    The model is vmapped over the batch, the highest-scoring class of each
    sample is taken as the prediction, and the predictions are compared with
    the true labels ``y``.
    """
    class_scores = jax.vmap(model)(x)
    predicted_labels = jnp.argmax(class_scores, axis=1)
    is_correct = y == predicted_labels
    return jnp.mean(is_correct)
|
||
|
|
||
|
# %%
|
||
|
def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
    """Evaluate the model on the test dataset.

    Args:
        model: The model to evaluate.
        testloader: Loader yielding ``(images, labels)`` batches of PyTorch
            tensors.

    Returns:
        A ``(average_loss, average_accuracy)`` pair averaged over all samples
        yielded by ``testloader``.

    Raises:
        ZeroDivisionError: If ``testloader`` yields no samples.
    """
    total_loss = 0
    total_acc = 0
    num_samples = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        batch_size = len(y)
        # `loss` and `compute_accuracy` return per-batch means. Weight them by
        # the batch size so that a smaller final batch does not count as much
        # as a full one in the overall average.
        # Note that all the JAX operations happen inside `loss` and
        # `compute_accuracy`, and both have JIT wrappers, so this is fast.
        total_loss += loss(model, x, y) * batch_size
        total_acc += compute_accuracy(model, x, y) * batch_size
        num_samples += batch_size
    return total_loss / num_samples, total_acc / num_samples
|
||
|
|
||
|
# %%
|
||
|
# Baseline: loss and accuracy of the model before `train` is called on it.
# The returned pair is not stored or printed here.
evaluate(model, testloader)
# %%
# training
# AdamW optimiser with the learning rate defined in the hyperparameters.
optim = optax.adamw(LEARNING_RATE)
|
||
|
|
||
|
def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    """Train ``model`` for ``steps`` optimisation steps and return it.

    Args:
        model: The model to train.
        trainloader: Loader yielding training ``(images, labels)`` batches; it
            is restarted whenever it is exhausted.
        testloader: Loader used for the periodic evaluation.
        optim: The optax gradient transformation used to update the model.
        steps: Total number of optimisation steps to run.
        print_every: Evaluate on ``testloader`` and print the metrics every
            this many steps (and also on the final step).

    Returns:
        The trained model.
    """
    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser,
    # updating the model -- into a single JIT region. This ensures things run
    # as fast as possible.
    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        # Hand the optimiser only the array leaves as `params`, matching what
        # `optim.init` received above. Optimisers that read the parameters
        # (such as weight decay in AdamW) then never see the non-array leaves
        # of the model.
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Loop over our training dataset as many times as we need. Re-iterating
    # the loader (rather than caching its batches) starts a fresh pass each
    # time it is exhausted.
    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            test_loss, test_accuracy = evaluate(model, testloader)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    return model
|
||
|
|
||
|
|
||
|
# %%
|
||
|
# Run training with the hyperparameters defined at the top of the file and
# rebind `model` to the trained result.
model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY)
|
||
|
|
||
|
# %%
|