# %%
import tensorflow_datasets as tfds  # TFDS for MNIST
import tensorflow as tf  # TensorFlow operations

tf.random.set_seed(0)  # set random seed for reproducibility

num_epochs = 10
batch_size = 32

train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
test_ds: tf.data.Dataset = tfds.load('mnist', split='test')

train_ds = train_ds.map(
  lambda sample: {
    'image': tf.cast(sample['image'], tf.float32) / 255,
    'label': sample['label'],
  }
)  # normalize train set pixel values to [0, 1]
test_ds = test_ds.map(
  lambda sample: {
    'image': tf.cast(sample['image'], tf.float32) / 255,
    'label': sample['label'],
  }
)  # normalize test set pixel values to [0, 1]

# create a shuffled dataset by allocating a buffer of 1024 elements to randomly draw from
train_ds = train_ds.repeat(num_epochs).shuffle(1024)
# group into batches of batch_size, drop the incomplete final batch, and prefetch the next batch to improve latency
train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)
# apply the same shuffling and batching to the test set
test_ds = test_ds.shuffle(1024)
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)
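# %%
# Optional sanity check (a minimal sketch; `example_batch` is illustrative): pull one
# batch from the pipeline and confirm it has the shapes the model below expects —
# images of shape (batch_size, 28, 28, 1) in float32 and integer labels of shape (batch_size,).
example_batch = next(iter(train_ds.as_numpy_iterator()))
print(example_batch['image'].shape, example_batch['image'].dtype)
print(example_batch['label'].shape, example_batch['label'].dtype)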
# %%
from flax import nnx  # NNX API
from functools import partial


class CNN(nnx.Module):
  """A simple CNN model."""

  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    # two 2x2 poolings reduce a 28x28 input to 7x7, so the flattened feature size is 7 * 7 * 64 = 3136
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x):
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x


model = CNN(rngs=nnx.Rngs(0))
# %%
nnx.display(model)
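# %%
# A minimal sketch: count the trainable parameters by summing the sizes of the
# nnx.Param leaves in the model state (assumes the extracted State is a pytree of arrays).
import jax

params = nnx.state(model, nnx.Param)
print('trainable parameters:', sum(p.size for p in jax.tree_util.tree_leaves(params)))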
# %%
# test the model by feeding an example input; the output is a (1, 10) array of class logits
import jax.numpy as jnp  # JAX NumPy

y = model(jnp.ones((1, 28, 28, 1)))
nnx.display(y)
# %%
import optax

learning_rate = 0.005
momentum = 0.9  # passed to optax.adamw as b1, the first-moment decay rate

optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)

nnx.display(optimizer)
# %%
def loss_fn(model: CNN, batch):
  logits = model(batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']
  ).mean()
  return loss, logits  # return logits as an auxiliary output for computing accuracy
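# %%
# Quick check (a sketch; `dummy_batch` is illustrative, not from the dataset): at
# random initialization the cross-entropy over 10 classes should be roughly
# -log(1/10) ≈ 2.3, since the untrained logits carry little information.
dummy_batch = {
  'image': jnp.ones((batch_size, 28, 28, 1), dtype=jnp.float32),
  'label': jnp.zeros((batch_size,), dtype=jnp.int32),
}
initial_loss, _ = loss_fn(model, dummy_batch)
print(initial_loss)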
# %%
@nnx.jit  # automatic state management for JAX transforms
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  """Train for a single step."""
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])  # in-place metric updates
  optimizer.update(grads)  # in-place parameter and optimizer-state updates
# %%
# evaluation step: compute loss and metrics on a batch without updating the model
@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
  loss, logits = loss_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])
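# %%
# Optional smoke test (a sketch; `smoke_batch` is illustrative): run eval_step once
# on a dummy batch to confirm it compiles. It only touches the metrics, so resetting
# them afterwards leaves the upcoming training run unaffected.
import numpy as np

smoke_batch = {
  'image': np.zeros((batch_size, 28, 28, 1), dtype=np.float32),
  'label': np.zeros((batch_size,), dtype=np.int64),
}
eval_step(model, metrics, smoke_batch)
print(metrics.compute())
metrics.reset()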
# %%
# re-seed TF so the dataset shuffling order during training is reproducible
tf.random.set_seed(0)
# %%
# train_ds was repeated num_epochs times above, so dividing its cardinality by
# num_epochs recovers the number of training steps per epoch
num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs

metrics_history = {
  'train_loss': [],
  'train_accuracy': [],
  'test_loss': [],
  'test_accuracy': [],
}

for step, batch in enumerate(train_ds.as_numpy_iterator()):
  # Run the optimization for one step and make a stateful update to the following:
  # - the train state's model parameters
  # - the optimizer state
  # - the training loss and accuracy batch metrics
  train_step(model, optimizer, metrics, batch)

  if (step + 1) % num_steps_per_epoch == 0:  # one training epoch has passed
    # Log training metrics
    for metric, value in metrics.compute().items():  # compute metrics
      metrics_history[f'train_{metric}'].append(value)  # record metrics
    metrics.reset()  # reset metrics for test set

    # Compute metrics on the test set after each training epoch
    for test_batch in test_ds.as_numpy_iterator():
      eval_step(model, metrics, test_batch)

    # Log test metrics
    for metric, value in metrics.compute().items():
      metrics_history[f'test_{metric}'].append(value)
    metrics.reset()  # reset metrics for next training epoch

    print(
      f"train epoch: {(step + 1) // num_steps_per_epoch}, "
      f"loss: {metrics_history['train_loss'][-1]}, "
      f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
    )
    print(
      f"test epoch: {(step + 1) // num_steps_per_epoch}, "
      f"loss: {metrics_history['test_loss'][-1]}, "
      f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
    )
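# %%
# Inference sketch (`inference_batch` and `pred_labels` are illustrative names):
# grab one test batch and compare the trained model's predicted classes against
# the true labels.
inference_batch = next(iter(test_ds.as_numpy_iterator()))
pred_labels = model(inference_batch['image']).argmax(axis=-1)
print('predicted:', pred_labels[:10])
print('actual:   ', inference_batch['label'][:10])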
# %%
# visualize metrics
import matplotlib.pyplot as plt  # visualization

# plot loss and accuracy in side-by-side subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')
for dataset in ('train', 'test'):
  ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
  ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
ax1.legend()
ax2.legend()
plt.show()