# %%
import tensorflow_datasets as tfds  # TFDS for MNIST
import tensorflow as tf  # TensorFlow operations

tf.random.set_seed(0)  # set the random seed for reproducibility

num_epochs = 10
batch_size = 32

train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
test_ds: tf.data.Dataset = tfds.load('mnist', split='test')

train_ds = train_ds.map(
  lambda sample: {
    'image': tf.cast(sample['image'], tf.float32) / 255,
    'label': sample['label'],
  }
)  # normalize the train set
test_ds = test_ds.map(
  lambda sample: {
    'image': tf.cast(sample['image'], tf.float32) / 255,
    'label': sample['label'],
  }
)  # normalize the test set

# Repeat the train set for num_epochs, then shuffle it by randomly drawing
# elements from a 1024-element buffer.
train_ds = train_ds.repeat(num_epochs).shuffle(1024)
# Group into batches of batch_size, drop the final incomplete batch, and
# prefetch the next batch to improve latency.
train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)
# Shuffle the test set by randomly drawing elements from a 1024-element buffer.
test_ds = test_ds.shuffle(1024)
# Group into batches of batch_size, drop the final incomplete batch, and
# prefetch the next batch to improve latency.
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)
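
# %%
# A quick, illustrative sanity check of the input pipeline: draw one batch
# and confirm the shapes and dtypes the model below expects. With the
# settings above this should print (32, 28, 28, 1) float32 images and
# (32,) integer labels. `sample_batch` is just a throwaway name used here.
sample_batch = next(iter(test_ds.as_numpy_iterator()))
print(sample_batch['image'].shape, sample_batch['image'].dtype)
print(sample_batch['label'].shape, sample_batch['label'].dtype)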

# %%
from flax import nnx  # NNX API
from functools import partial

class CNN(nnx.Module):
  """A simple CNN model."""

  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    # Two 2x2 poolings reduce the 28x28 input to 7x7, and conv2 outputs 64
    # features, so the flattened size is 7 * 7 * 64 = 3136.
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x):
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x

model = CNN(rngs=nnx.Rngs(0))

# %%
nnx.display(model)

# %%
# Test the model by feeding it an example input.
import jax.numpy as jnp  # JAX NumPy

y = model(jnp.ones((1, 28, 28, 1)))
nnx.display(y)

# %%
import optax

learning_rate = 0.005
momentum = 0.9

# Note: the second positional argument of optax.adamw is b1, the first-moment
# decay, so `momentum` is passed as b1 here.
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)

nnx.display(optimizer)

# %%
def loss_fn(model: CNN, batch):
  logits = model(batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']
  ).mean()
  return loss, logits

# %%
@nnx.jit
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  """Train for a single step."""
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])
  optimizer.update(grads)

# %%
# Evaluation step: compute loss and accuracy on a batch without updating
# the model.
@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
  loss, logits = loss_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])

# %%
# Re-seed TF so the dataset shuffling below is reproducible.
tf.random.set_seed(0)

# %%
# train_ds holds num_epochs worth of batches, so dividing its cardinality by
# num_epochs gives the number of steps in one epoch.
num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs

metrics_history = {
  'train_loss': [],
  'train_accuracy': [],
  'test_loss': [],
  'test_accuracy': [],
}

for step, batch in enumerate(train_ds.as_numpy_iterator()):
  # Run the optimization for one step and make a stateful update to:
  # - the model parameters
  # - the optimizer state
  # - the training loss and accuracy batch metrics
  train_step(model, optimizer, metrics, batch)

  if (step + 1) % num_steps_per_epoch == 0:  # one training epoch has passed
    # Log the training metrics.
    for metric, value in metrics.compute().items():  # compute the metrics
      metrics_history[f'train_{metric}'].append(value)  # record the metrics
    metrics.reset()  # reset the metrics for the test set

    # Compute the metrics on the test set after each training epoch.
    for test_batch in test_ds.as_numpy_iterator():
      eval_step(model, metrics, test_batch)

    # Log the test metrics.
    for metric, value in metrics.compute().items():
      metrics_history[f'test_{metric}'].append(value)
    metrics.reset()  # reset the metrics for the next training epoch

    print(
      f"train epoch: {(step + 1) // num_steps_per_epoch}, "
      f"loss: {metrics_history['train_loss'][-1]}, "
      f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
    )
    print(
      f"test epoch: {(step + 1) // num_steps_per_epoch}, "
      f"loss: {metrics_history['test_loss'][-1]}, "
      f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
    )

# %%
# Visualize the metrics.
import matplotlib.pyplot as plt  # visualization

# Plot the loss and accuracy in side-by-side subplots.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')
for dataset in ('train', 'test'):
  ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
  ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
ax1.legend()
ax2.legend()
plt.show()
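
# %%
# A short, illustrative follow-up: run the trained model on one test batch
# and eyeball the predictions against the true labels. `model(...)` returns
# logits, so the predicted class is the argmax over the last axis. The names
# `test_batch` and `pred` are local to this sketch.
test_batch = next(iter(test_ds.as_numpy_iterator()))
pred = model(test_batch['image']).argmax(axis=-1)

fig, axs = plt.subplots(3, 3, figsize=(6, 6))
for i, ax in enumerate(axs.flatten()):
  ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')
  ax.set_title(f'pred: {pred[i]}, label: {test_batch["label"][i]}')
  ax.axis('off')
plt.show()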