# %%
import equinox as eqx
import jax
import jax.numpy as jnp
import optax  # https://github.com/deepmind/optax
import torch  # https://pytorch.org
import torchvision  # https://pytorch.org
from jaxtyping import Array, Float, Int, PyTree  # https://github.com/google/jaxtyping

# %%
# Hyperparameters
BATCH_SIZE = 64
LEARNING_RATE = 3e-4
STEPS = 300
PRINT_EVERY = 30
SEED = 5678

key = jax.random.PRNGKey(SEED)

# %%
normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=True,
    download=True,
    transform=normalise_data,
)
test_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=False,
    download=True,
    transform=normalise_data,
)
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True
)

# %%
# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like).
dummy_x, dummy_y = next(iter(trainloader))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  # 64x1x28x28
print(dummy_y.shape)  # 64
print(dummy_y)


# %%
class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,  # jax functions!!!
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x


key, subkey = jax.random.split(key, 2)
model = CNN(subkey)

# %%
print(model)

# %%
# Print the first layer: Conv2d.
print(model.layers[0])


# %%
# Loss function, plus an example of running inference on a batch.
def loss(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    # Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operates on
    # a single input image of shape (1, 28, 28).
    #
    # Therefore, we have to use jax.vmap, which in this case maps our model over the
    # leading (batch) axis.
    #
    # This is an example of writing a function for one input, then letting JAX
    # automatically vectorize it over the batch dimension.
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    #
    # take_along_axis: take from pred_y along axis 1, according to the 2nd argument.
    # expand_dims on axis 1 makes y of shape (batch, 1); since we take along axis 1,
    # each entry of y selects the corresponding entry of its row in pred_y.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)  # negative mean of the relevant log-probabilities


# Example loss
loss_value = loss(model, dummy_x, dummy_y)
print(loss_value)
print(loss_value.shape)  # scalar loss

# Example inference
output = jax.vmap(model)(dummy_x)
print(output.shape)  # batch of predictions
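# %%
# A tiny worked sketch (with made-up numbers) of what take_along_axis does inside
# cross_entropy: for 2 samples and 3 classes, each target index picks out the
# matching log-probability from its row, and the loss is their negative mean.
example_log_probs = jnp.log(jnp.array([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]))
example_targets = jnp.array([0, 2])
picked = jnp.take_along_axis(
    example_log_probs, jnp.expand_dims(example_targets, 1), axis=1
)
print(picked)  # log(0.7) and log(0.8), one entry per sample
print(-jnp.mean(picked))  # the same quantity that cross_entropy returns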
# %%
# This is an error!
# The reason is that jax.value_and_grad expects its first argument to be a PyTree
# of arrays only, but our model also contains non-parameter components (the
# activation functions), so this call raises an error.
jax.value_and_grad(loss)(model, dummy_x, dummy_y)

# %%
# We separate out the parameters from everything else: the parameters are exactly
# the things that are arrays.
# partition is doing filter(...) and filter(..., inverse=True) in one go.
params, static = eqx.partition(model, eqx.is_array)

# %%
# Let's compare the same object in both halves.
print(static.layers[0])
print(params.layers[0])


# %%
# In the loss, we recombine both halves to form our model again.
def loss2(params, static, x, y):
    model = eqx.combine(params, static)
    return loss(model, x, y)


# Now this will work!
# Since grad only looks at the first argument, this works out.
loss_value, grads = jax.value_and_grad(loss2)(params, static, dummy_x, dummy_y)
print(loss_value)

# %%
# This will work too!
# It does the same as the previous cell, with the partition/combine handled for us.
value, grads = eqx.filter_value_and_grad(loss)(model, dummy_x, dummy_y)
print(value)

# %%
# Evaluation
loss = eqx.filter_jit(loss)  # JIT our loss function from earlier!


@eqx.filter_jit
def compute_accuracy(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """This function takes as input the current model
    and computes the average accuracy on a batch.
    """
    pred_y = jax.vmap(model)(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)


# %%
def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        avg_loss += loss(model, x, y)
        avg_acc += compute_accuracy(model, x, y)
    return avg_loss / len(testloader), avg_acc / len(testloader)


# %%
evaluate(model, testloader)

# %%
# Training
optim = optax.adamw(LEARNING_RATE)


def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    # Just like earlier: it only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        # Filter the model again here, so that the parameters passed to optax have
        # the same structure as grads and opt_state (arrays only).
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            test_loss, test_accuracy = evaluate(model, testloader)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    return model


# %%
model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY)

# %%
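# A minimal follow-up sketch: take one test batch, run the (vmapped) trained model
# over it, and compare the predicted classes with the true labels.
sample_x, sample_y = next(iter(testloader))
sample_x = sample_x.numpy()
sample_y = sample_y.numpy()
sample_pred = jnp.argmax(jax.vmap(model)(sample_x), axis=1)
print("predicted:", sample_pred[:10])
print("true:     ", sample_y[:10])
print("batch accuracy:", jnp.mean(sample_pred == sample_y))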