Feat: learn flax
This commit is contained in: parent d2dd72227f · commit 005a1a5735
				|  | @ -3,3 +3,5 @@ t5_*/ | |||
| exports/ | ||||
| modified_t5_model/ | ||||
| traces/ | ||||
| ruff.toml | ||||
| settings.json | ||||
|  |  | |||
|  | @ -0,0 +1,332 @@ | |||
| # --- | ||||
| # jupyter: | ||||
| #   jupytext: | ||||
| #     formats: ipynb,py:percent | ||||
| #     text_representation: | ||||
| #       extension: .py | ||||
| #       format_name: percent | ||||
| #       format_version: '1.3' | ||||
| #       jupytext_version: 1.16.4 | ||||
| # --- | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # # Flax Basics | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # # Linear regression with Flax | ||||
| # In this example we perform linear regression with gradient descent. We | ||||
| # first define a model, then generate data from known ground-truth parameters | ||||
| # (so the data matches the model's functional form, up to a little noise), and | ||||
| # finally recover those parameters by minimizing a mean squared error loss. | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # import packages | ||||
| import jax | ||||
| from typing import Any, Callable, Sequence | ||||
| from jax import random, numpy as jnp | ||||
| import flax | ||||
| from flax import linen as nn | ||||
| 
 | ||||
| # %% | ||||
| # define model | ||||
| # notice there is no input size declaration | ||||
| model = nn.Dense(features=5) | ||||
| 
 | ||||
| # model parameters & initialization | ||||
| key1, key2 = random.split(random.key(0)) | ||||
| # dummy input to trigger shape inference | ||||
| x = random.normal(key1, (10,)) | ||||
| params = model.init(key2, x)  # initialization call | ||||
| 
 | ||||
| # check output shapes | ||||
| print(jax.tree_util.tree_map(lambda x: x.shape, params)) | ||||
| 
 | ||||
| model.apply(params, x) | ||||
| 
 | ||||
| 
 | ||||
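| # %% | ||||
| # a quick sanity check (added sketch): the parameter shapes were inferred from | ||||
| # the dummy input, so a 10-dim input gives a (10, 5) kernel and a (5,) bias | ||||
| kernel_shape = params['params']['kernel'].shape | ||||
| bias_shape = params['params']['bias'].shape | ||||
| print('kernel:', kernel_shape, 'bias:', bias_shape) | ||||
|  | ||||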
| # %% | ||||
| # example of gradient descent | ||||
| 
 | ||||
| # set problem dimensions | ||||
| n_samples: int = 20 | ||||
| x_dim: int = 10 | ||||
| y_dim: int = 5 | ||||
| 
 | ||||
| # generate random ground truth W and b | ||||
| key = random.key(0) | ||||
| k1, k2 = random.split(key) | ||||
| W = random.normal(k1, (x_dim, y_dim)) | ||||
| b = random.normal(k2, (y_dim,)) | ||||
| # store the ground-truth parameters in a FrozenDict pytree | ||||
| true_params = flax.core.freeze({"params": {"bias": b, "kernel": W}}) | ||||
| print(jax.tree_util.tree_map(lambda x: x.shape, true_params)) | ||||
| 
 | ||||
| # Generate samples with additional noise. | ||||
| key_sample, key_noise = random.split(k1) | ||||
| x_samples = random.normal(key_sample, (n_samples, x_dim)) | ||||
| # Wx + b + epsilon (noise) | ||||
| y_samples = ( | ||||
|     jnp.dot(x_samples, W) + b + 0.1 * | ||||
|     random.normal(key_noise, (n_samples, y_dim)) | ||||
| ) | ||||
| print('x shape:', x_samples.shape, '; y shape:', y_samples.shape) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # define the loss in terms of the model output | ||||
| @jax.jit | ||||
| def mse(params, x_batched, y_batched): | ||||
|     # define squared loss for a single pair | ||||
|     def squared_error(x, y): | ||||
|         # use model.apply to get model_f(x) | ||||
|         pred = model.apply(params, x) | ||||
|         return jnp.inner(y - pred, y - pred) / 2.0 | ||||
|     # vectorize squared_error over the batch (axis 0) with vmap, then average | ||||
|     return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0) | ||||
| 
 | ||||
| 
 | ||||
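| # %% | ||||
| # sanity check (added sketch): the vmapped loss should match an explicitly | ||||
| # batched computation, since nn.Dense broadcasts over leading axes | ||||
| def mse_reference(params, x_batched, y_batched): | ||||
|     preds = model.apply(params, x_batched) | ||||
|     return jnp.mean(jnp.sum((y_batched - preds) ** 2, axis=-1) / 2.0) | ||||
|  | ||||
| print('vmapped mse:  ', mse(params, x_samples, y_samples)) | ||||
| print('reference mse:', mse_reference(params, x_samples, y_samples)) | ||||
|  | ||||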
| # %% | ||||
| 
 | ||||
| # perform gradient descent | ||||
| learning_rate = 0.3  # Gradient step size. | ||||
| print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples)) | ||||
| # get the gradient function for backprop | ||||
| loss_grad_fn = jax.value_and_grad(mse) | ||||
| 
 | ||||
| 
 | ||||
| @jax.jit | ||||
| def update_params(params, learning_rate, grads): | ||||
|     params = jax.tree_util.tree_map( | ||||
|         lambda p, g: p - learning_rate * g, params, grads) | ||||
|     return params | ||||
| 
 | ||||
| 
 | ||||
| for i in range(101): | ||||
|     # Perform one gradient update. | ||||
|     loss_val, grads = loss_grad_fn(params, x_samples, y_samples) | ||||
|     params = update_params(params, learning_rate, grads) | ||||
|     if i % 10 == 0: | ||||
|         print(f'Loss step {i}: ', loss_val) | ||||
| 
 | ||||
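| # %% | ||||
| # after training, the learned parameters should be close to the ground truth | ||||
| # (an added check, not part of the original notebook) | ||||
| print('max |kernel - W|:', jnp.max(jnp.abs(params['params']['kernel'] - W))) | ||||
| print('max |bias - b|:  ', jnp.max(jnp.abs(params['params']['bias'] - b))) | ||||
|  | ||||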
| # %% [markdown] | ||||
| # # Optimization with Optax | ||||
| 
 | ||||
| # %% | ||||
| # init optax object | ||||
| import optax | ||||
| tx = optax.adam(learning_rate=learning_rate) | ||||
| # we initialize the optimizer with model params | ||||
| opt_state = tx.init(params) | ||||
| loss_grad_fn = jax.value_and_grad(mse) | ||||
| 
 | ||||
| # %% | ||||
| # perform gradient descent with optax optimizer instead of update_params function | ||||
| for i in range(101): | ||||
|     loss_val, grads = loss_grad_fn(params, x_samples, y_samples) | ||||
|     updates, opt_state = tx.update(grads, opt_state) | ||||
|     params = optax.apply_updates(params, updates) | ||||
|     if i % 10 == 0: | ||||
|         print(f'Loss step {i}: ', loss_val) | ||||
| 
 | ||||
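| # %% | ||||
| # optax optimizers are plain GradientTransformations, so they compose; this is | ||||
| # an illustrative sketch (not used below) chaining gradient clipping with Adam | ||||
| tx_clipped = optax.chain( | ||||
|     optax.clip_by_global_norm(1.0),  # clip the global gradient norm first | ||||
|     optax.adam(learning_rate=learning_rate), | ||||
| ) | ||||
| opt_state_clipped = tx_clipped.init(params) | ||||
|  | ||||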
| # %% [markdown] | ||||
| # # Serializing the result (i.e. saving the model) | ||||
|  | ||||
| # %% | ||||
| from flax import serialization | ||||
| bytes_output = serialization.to_bytes(params) | ||||
| dict_output = serialization.to_state_dict(params) | ||||
| print('Dict output') | ||||
| print(dict_output) | ||||
| print('Bytes output') | ||||
| print(bytes_output) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # load back from the serialized output | ||||
| # we first need a target pytree with the right structure, so we re-create | ||||
| # the model and its initialized params to use as a template | ||||
| model = nn.Dense(features=5) | ||||
| key1, key2 = random.split(random.key(0)) | ||||
| # dummy input to trigger shape inference | ||||
| x = random.normal(key1, (10,)) | ||||
| params = model.init(key2, x)  # initialization call | ||||
| 
 | ||||
| # then we restore the saved weights; note these functions do not mutate | ||||
| # `params` in place, they return a new pytree with the template's structure | ||||
| # from state_dict | ||||
| params = serialization.from_state_dict(params, dict_output) | ||||
| print(params) | ||||
| # from bytes | ||||
| params = serialization.from_bytes(params, bytes_output) | ||||
| print(params) | ||||
| 
 | ||||
| 
 | ||||
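| # %% | ||||
| # the byte string can be persisted with plain Python file I/O; the file name | ||||
| # below is just an illustrative choice (flax serializes to msgpack bytes) | ||||
| ckpt_path = 'dense_params.msgpack'  # hypothetical path | ||||
| with open(ckpt_path, 'wb') as f: | ||||
|     f.write(bytes_output) | ||||
| with open(ckpt_path, 'rb') as f: | ||||
|     params = serialization.from_bytes(params, f.read()) | ||||
|  | ||||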
| # %% [markdown] | ||||
| # # Defining your own models | ||||
| # An nn.Module subclass is composed of: | ||||
| # | ||||
| # 1. a collection of data fields (the module's hyperparameters) | ||||
| # 2. a setup() method that creates the submodules used in the forward pass | ||||
| # 3. a __call__() method that implements the forward pass | ||||
| # | ||||
| # the model's parameters live in a separate pytree, returned by init() and | ||||
| # passed to apply() | ||||
| 
 | ||||
| # %% | ||||
| # module basics | ||||
| class ExplicitMLP(nn.Module): | ||||
|     # model arguments: | ||||
|     # a list of output feature sizes, one per Dense layer | ||||
|     features: Sequence[int] | ||||
| 
 | ||||
|     def setup(self): | ||||
|         self.layers = [nn.Dense(feat) for feat in self.features] | ||||
|      | ||||
|     def __call__(self, inputs): | ||||
|         x = inputs | ||||
|         # go through each layer | ||||
|         for i, lyr in enumerate(self.layers): | ||||
|             x = lyr(x) | ||||
|             # apply a ReLU after every layer except the last | ||||
|             if i != len(self.layers) - 1: | ||||
|                 x = nn.relu(x) | ||||
|         return x | ||||
| 
 | ||||
| # init the model | ||||
| key1, key2 = random.split(random.key(117), 2) | ||||
| # a dummy input; the layers' input sizes are inferred from its last axis | ||||
| x = random.uniform(key1, (4,4)) | ||||
| print(x) | ||||
| 
 | ||||
| # instantiate the model | ||||
| model = ExplicitMLP(features=[3,4,5]) | ||||
| # init the model with an rng key and a dummy input | ||||
| params = model.init(key2, x) | ||||
| y = model.apply(params, x) | ||||
| 
 | ||||
| print( | ||||
|     'initialized parameter shapes:\n', | ||||
|     jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)), | ||||
| ) | ||||
| print('output:\n', y) | ||||
| 
 | ||||
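| # %% | ||||
| # note (added): the module instance itself is not "bound" to any parameters, | ||||
| # so calling it directly instead of via apply() raises an error | ||||
| try: | ||||
|     model(x) | ||||
| except Exception as e: | ||||
|     print('direct call failed as expected:', type(e).__name__) | ||||
|  | ||||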
| # %% | ||||
| # using nn.Module with @nn.compact | ||||
| 
 | ||||
| class SimpleMLP(nn.Module): | ||||
|     features: Sequence[int] | ||||
| 
 | ||||
|     # the nn.compact style declares submodules inline, inside __call__ | ||||
|     @nn.compact | ||||
|     def __call__(self, inputs): | ||||
|         x = inputs | ||||
|         for i, feat in enumerate(self.features): | ||||
|             x = nn.Dense(feat, name=f'layers_{i}')(x) | ||||
|             if i != len(self.features) - 1: | ||||
|                 x = nn.relu(x) | ||||
|             # providing a name is optional though! | ||||
|             # the default autonames would be "Dense_0", "Dense_1", ... | ||||
|         return x | ||||
| 
 | ||||
| key1, key2 = random.split(random.key(117), 2) | ||||
| x = random.uniform(key1, (4,4)) | ||||
| 
 | ||||
| model = SimpleMLP(features=[3,4,5]) | ||||
| params = model.init(key2, x) | ||||
| y = model.apply(params, x) | ||||
| 
 | ||||
| print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params))) | ||||
| print('output:\n', y) | ||||
| 
 | ||||
| 
 | ||||
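| # %% | ||||
| # recent Flax versions can also summarize a module without a manual init via | ||||
| # Module.tabulate (an added sketch; the exact output depends on the Flax version) | ||||
| print(SimpleMLP(features=[3, 4, 5]).tabulate(random.key(0), x)) | ||||
|  | ||||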
| # %% | ||||
| # creating your own layers | ||||
| # as an example, here is how to implement a Dense layer from scratch | ||||
| class SimpleDense(nn.Module): | ||||
|     features: int | ||||
|     kernel_init: Callable = nn.initializers.lecun_normal() | ||||
|     bias_init: Callable = nn.initializers.zeros_init() | ||||
| 
 | ||||
|     @nn.compact | ||||
|     def __call__(self, inputs): | ||||
|         # use the nn.Module.param method to declare a parameter | ||||
|         # * because params are in __call__, shape inference is possible * | ||||
|         # the 'W' | ||||
|         kernel = self.param( | ||||
|             'kernel',  # name | ||||
|             self.kernel_init,  # Initialization function | ||||
|             (inputs.shape[-1], self.features),  # shape | ||||
|         ) | ||||
|         # 'Wx' | ||||
|         y = jnp.dot(inputs, kernel) | ||||
|         # 'Wx + b' | ||||
|         bias = self.param('bias', self.bias_init, (self.features,)) | ||||
|         y = y + bias | ||||
|         return y | ||||
| 
 | ||||
| key1, key2 = random.split(random.key(0), 2) | ||||
| x = random.uniform(key1, (4,4)) | ||||
| 
 | ||||
| model = SimpleDense(features=3) | ||||
| # init_fn takes (prng key, *init args, **init kwargs) | ||||
| params = model.init(key2, x) | ||||
| y = model.apply(params, x) | ||||
| 
 | ||||
| print('initialized parameters:\n', params) | ||||
| print('output:\n', y) | ||||
| 
 | ||||
| 
 | ||||
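| # %% | ||||
| # the same layer written with setup() instead of @nn.compact (an added sketch | ||||
| # for comparison): because the params are declared before any input is seen, | ||||
| # the input dimension has to be passed in explicitly | ||||
| class SimpleDenseExplicit(nn.Module): | ||||
|     in_features: int | ||||
|     features: int | ||||
|     kernel_init: Callable = nn.initializers.lecun_normal() | ||||
|     bias_init: Callable = nn.initializers.zeros_init() | ||||
|  | ||||
|     def setup(self): | ||||
|         self.kernel = self.param( | ||||
|             'kernel', self.kernel_init, (self.in_features, self.features)) | ||||
|         self.bias = self.param('bias', self.bias_init, (self.features,)) | ||||
|  | ||||
|     def __call__(self, inputs): | ||||
|         return jnp.dot(inputs, self.kernel) + self.bias | ||||
|  | ||||
|  | ||||
| params_explicit = SimpleDenseExplicit(in_features=4, features=3).init(key2, x) | ||||
| print(jax.tree_util.tree_map(jnp.shape, params_explicit)) | ||||
|  | ||||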
| # %% | ||||
| # handling state in your module | ||||
| # jax does not work well with side effects and hidden state: jax | ||||
| # transformations expect pure functions whose outputs depend only on their inputs | ||||
| # flax's "mutable" mechanism separates the (pure) function from its state | ||||
| 
 | ||||
| # note: this module uses 2 variable collections: | ||||
| # 'params' for trainable parameters, 'batch_stats' for non-trainable state | ||||
| class BiasAdderWithRunningMean(nn.Module): | ||||
|     decay: float = 0.99 | ||||
| 
 | ||||
|     @nn.compact | ||||
|     def __call__(self, x): | ||||
|         # a common pattern to detect whether we are initializing (the variable | ||||
|         # tree is still empty): check for 'mean' in the 'batch_stats' collection | ||||
|         is_initialized = self.has_variable('batch_stats', 'mean') | ||||
|         # declare a variable outside the model parameters: 'mean' lives in the | ||||
|         # 'batch_stats' collection, so ra_mean refers to batch_stats/mean | ||||
|         ra_mean = self.variable( | ||||
|             'batch_stats',  | ||||
|             'mean',  | ||||
|             lambda s: jnp.zeros(s),  | ||||
|             x.shape[1:]  # feature dims after the batch axis, e.g. (10, 5) -> (5,) | ||||
|         ) | ||||
|         bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:]) | ||||
|         if is_initialized: | ||||
|             ra_mean.value = ( | ||||
|                 self.decay * ra_mean.value  | ||||
|                 + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True) | ||||
|             ) | ||||
| 
 | ||||
|         return x - ra_mean.value + bias | ||||
| 
 | ||||
| 
 | ||||
| key1, key2 = random.split(random.key(0), 2) | ||||
| x = jnp.ones((10,5)) | ||||
| model = BiasAdderWithRunningMean() | ||||
| variables = model.init(key1, x) | ||||
| print('initialized variables:\n', variables) | ||||
| # passing mutable means apply() also returns the updated values of those | ||||
| # collections, so there is no hidden state inside the function itself | ||||
| # this is the escape hatch for implementing stateful computation with pure | ||||
| # functions: the state lives outside the module and is threaded through explicitly | ||||
| y, updated_state = model.apply(variables, x, mutable=['batch_stats']) | ||||
| print('updated state:\n', updated_state) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| for val in [1.0, 2.0, 3.0]: | ||||
|     x = val * jnp.ones((10, 5)) | ||||
|     y, updated_state = model.apply(variables, x, mutable=['batch_stats']) | ||||
|     old_state, params = flax.core.pop(variables, 'params') | ||||
|     variables = flax.core.freeze({'params': params, **updated_state}) | ||||
|     print('updated state:\n', updated_state)  # Shows only the mutable part | ||||
| 
 | ||||
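| # %% | ||||
| # an added sketch of combining mutable state with an optimizer: take gradients | ||||
| # w.r.t. the 'params' collection only, and thread the updated 'batch_stats' | ||||
| # through as an auxiliary output of the loss function | ||||
| tx_state = optax.sgd(learning_rate=0.02)  # optax was imported above | ||||
|  | ||||
|  | ||||
| def update_step(apply_fn, x, opt_state, params, state): | ||||
|     def loss_fn(params): | ||||
|         y, updated_state = apply_fn( | ||||
|             {'params': params, **state}, x, mutable=list(state.keys())) | ||||
|         loss = jnp.mean(y ** 2)  # a dummy loss, purely for illustration | ||||
|         return loss, updated_state | ||||
|  | ||||
|     (loss, updated_state), grads = jax.value_and_grad( | ||||
|         loss_fn, has_aux=True)(params) | ||||
|     updates, opt_state = tx_state.update(grads, opt_state) | ||||
|     params = optax.apply_updates(params, updates) | ||||
|     return opt_state, params, updated_state | ||||
|  | ||||
|  | ||||
| state, params = flax.core.pop(variables, 'params') | ||||
| opt_state = tx_state.init(params) | ||||
| opt_state, params, state = update_step(model.apply, x, opt_state, params, state) | ||||
| print('updated params:\n', params) | ||||
| print('updated state:\n', state) | ||||
|  | ||||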
| # %% | ||||

t5_jax.py (32 changed lines)
							|  | @ -96,7 +96,7 @@ split_datasets = load_from_disk(file_path) | |||
| training_size = len(split_datasets['train']) | ||||
| # Store some constant | ||||
| seed = 117 | ||||
| num_epochs = 80 | ||||
| num_epochs = 5 | ||||
| batch_size = 384  # 384 is the best | ||||
| num_train_epochs = num_epochs | ||||
| per_device_train_batch_size = batch_size | ||||
|  | @ -314,16 +314,16 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf | |||
| 
 | ||||
|     for idx in batch_idx: | ||||
|         batch = dataset[idx] | ||||
|         batch = {k: np.array(v) for k, v in batch.items()} | ||||
|         batch = {k: jnp.array(v) for k, v in batch.items()} | ||||
| 
 | ||||
|         yield batch | ||||
| 
 | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # # Model | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # | ||||
| # | ||||
| # | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
|  | @ -442,16 +442,17 @@ def train_step(state, batch, label_smoothing_factor=0.0): | |||
|     num_labels = jax.lax.psum(num_labels, "batch") | ||||
| 
 | ||||
|     # true loss = total loss / total samples | ||||
|     loss = jax.lax.psum(loss, "batch") | ||||
|     loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) | ||||
|     # loss = jax.lax.psum(loss, "batch") | ||||
|     # loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) | ||||
| 
 | ||||
|     # true grad = total grad / total samples | ||||
|     grad = jax.lax.psum(grad, "batch") | ||||
|     grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) | ||||
|     new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) | ||||
| 
 | ||||
|     metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} | ||||
|     return new_state, metrics | ||||
|     # metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} | ||||
|     # return new_state, metrics | ||||
|     return new_state | ||||
| 
 | ||||
| # Define generation function | ||||
| max_length = ( | ||||
|  | @ -505,19 +506,20 @@ for epoch in epochs: | |||
|     for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): | ||||
|         batch = next(train_loader) | ||||
|         batch = shard(batch) | ||||
|         state, train_metric = p_train_step(state, batch) | ||||
|         train_metrics.append(train_metric) | ||||
|         state = p_train_step(state, batch) | ||||
|         # train_metrics.append(train_metric) | ||||
| 
 | ||||
|     train_time = time.time() - train_start | ||||
| 
 | ||||
|     train_metric = unreplicate(train_metric) | ||||
|     train_metric['loss'].block_until_ready() | ||||
|     # train_metric = unreplicate(train_metric) | ||||
|     # train_metric['loss'].block_until_ready() | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     epochs.write( | ||||
|         f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, " | ||||
|         f"Learning Rate:{train_metric['learning_rate']}, " | ||||
|         # f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, " | ||||
|         f"Epoch... ({epoch + 1}/{num_epochs} |  " | ||||
|         # f"Learning Rate:{train_metric['learning_rate']}, " | ||||
|         f"Last train time: {train_time})" | ||||
|     ) | ||||
| # jax.profiler.stop_trace() | ||||
|  |  | |||
|  | @ -30,7 +30,7 @@ from typing import Callable, Optional | |||
| import math | ||||
| 
 | ||||
| # jax.config.update("jax_default_matmul_precision", "tensorfloat32") | ||||
| jax.config.update("jax_default_matmul_precision", "high") | ||||
| jax.config.update("jax_default_matmul_precision", "bfloat16") | ||||
| 
 | ||||
| jax.config.update("jax_enable_x64", False) | ||||
| 
 | ||||
|  |  | |||