Feat: increased learning rate for effective large batch size learning
parent aca80720c8
commit a817fe16cc
				|  | @ -1,8 +1,6 @@ | |||
| *.ipynb | ||||
| t5_*/ | ||||
| model_checkpoints/ | ||||
| exports/ | ||||
| modified_t5_model/ | ||||
| traces/ | ||||
| ruff.toml | ||||
| settings.json | ||||
|  |  | |||
|  | @ -0,0 +1 @@ | |||
| __pycache__ | ||||
|  | @ -0,0 +1,163 @@ | |||
| # coding=utf-8 | ||||
| # Copyright 2020, The T5 Authors and HuggingFace Inc. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| """T5 model configuration""" | ||||
| 
 | ||||
| from typing import Mapping | ||||
| 
 | ||||
| from transformers import PretrainedConfig | ||||
| from transformers import logging | ||||
| 
 | ||||
| 
 | ||||
| logger = logging.get_logger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class T5Config(PretrainedConfig): | ||||
|     r""" | ||||
|     This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to | ||||
|     instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a | ||||
|     configuration with the defaults will yield a similar configuration to that of the T5 | ||||
|     [google-t5/t5-small](https://huggingface.co/google-t5/t5-small) architecture. | ||||
| 
 | ||||
|     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the | ||||
|     documentation from [`PretrainedConfig`] for more information. | ||||
| 
 | ||||
|     Arguments: | ||||
|         vocab_size (`int`, *optional*, defaults to 32128): | ||||
|             Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the | ||||
|             `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`]. | ||||
|         d_model (`int`, *optional*, defaults to 512): | ||||
|             Size of the encoder layers and the pooler layer. | ||||
|         d_kv (`int`, *optional*, defaults to 64): | ||||
|             Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will | ||||
|             be defined as `num_heads * d_kv`. | ||||
|         d_ff (`int`, *optional*, defaults to 2048): | ||||
|             Size of the intermediate feed forward layer in each `T5Block`. | ||||
|         num_layers (`int`, *optional*, defaults to 6): | ||||
|             Number of hidden layers in the Transformer encoder. | ||||
|         num_decoder_layers (`int`, *optional*): | ||||
|             Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. | ||||
|         num_heads (`int`, *optional*, defaults to 8): | ||||
|             Number of attention heads for each attention layer in the Transformer encoder. | ||||
|         relative_attention_num_buckets (`int`, *optional*, defaults to 32): | ||||
|             The number of buckets to use for each attention layer. | ||||
|         relative_attention_max_distance (`int`, *optional*, defaults to 128): | ||||
|             The maximum distance used for the relative position buckets; relative positions beyond this distance all map to the same bucket. | ||||
|         dropout_rate (`float`, *optional*, defaults to 0.1): | ||||
|             The ratio for all dropout layers. | ||||
|         classifier_dropout (`float`, *optional*, defaults to 0.0): | ||||
|             The dropout ratio for the classifier. | ||||
|         layer_norm_epsilon (`float`, *optional*, defaults to 1e-6): | ||||
|             The epsilon used by the layer normalization layers. | ||||
|         initializer_factor (`float`, *optional*, defaults to 1): | ||||
|             A factor for initializing all weight matrices (should be kept to 1, used internally for initialization | ||||
|             testing). | ||||
|         feed_forward_proj (`string`, *optional*, defaults to `"relu"`): | ||||
|             Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the | ||||
|             `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`. | ||||
|         use_cache (`bool`, *optional*, defaults to `True`): | ||||
|             Whether or not the model should return the last key/values attentions (not used by all models). | ||||
|     """ | ||||
| 
 | ||||
|     model_type = "t5" | ||||
|     keys_to_ignore_at_inference = ["past_key_values"] | ||||
|     attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|     vocab_size=32128, | ||||
|         d_model=512, | ||||
|         d_kv=64, | ||||
|         d_ff=2048, | ||||
|         num_layers=6, | ||||
|         num_decoder_layers=None, | ||||
|         num_heads=8, | ||||
|         relative_attention_num_buckets=32, | ||||
|         relative_attention_max_distance=128, | ||||
|         dropout_rate=0.1, | ||||
|         layer_norm_epsilon=1e-6, | ||||
|         initializer_factor=1.0, | ||||
|         feed_forward_proj="relu", | ||||
|         is_encoder_decoder=True, | ||||
|         use_cache=True, | ||||
|         pad_token_id=0, | ||||
|         eos_token_id=1, | ||||
|         classifier_dropout=0.0, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         self.vocab_size = vocab_size | ||||
|         self.d_model = d_model | ||||
|         self.d_kv = d_kv | ||||
|         self.d_ff = d_ff | ||||
|         self.num_layers = num_layers | ||||
|         self.num_decoder_layers = ( | ||||
|             num_decoder_layers if num_decoder_layers is not None else self.num_layers | ||||
|         )  # default = symmetry | ||||
|         self.num_heads = num_heads | ||||
|         self.relative_attention_num_buckets = relative_attention_num_buckets | ||||
|         self.relative_attention_max_distance = relative_attention_max_distance | ||||
|         self.dropout_rate = dropout_rate | ||||
|         self.classifier_dropout = classifier_dropout | ||||
|         self.layer_norm_epsilon = layer_norm_epsilon | ||||
|         self.initializer_factor = initializer_factor | ||||
|         self.feed_forward_proj = feed_forward_proj | ||||
|         self.use_cache = use_cache | ||||
|         self.use_bfloat16 = True | ||||
| 
 | ||||
|         act_info = self.feed_forward_proj.split("-") | ||||
|         self.dense_act_fn = act_info[-1] | ||||
|         self.is_gated_act = act_info[0] == "gated" | ||||
| 
 | ||||
|         if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: | ||||
|             raise ValueError( | ||||
|                 f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. " | ||||
|                 "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. " | ||||
|                 "'gated-gelu' or 'relu'" | ||||
|             ) | ||||
| 
 | ||||
|         # for backwards compatibility | ||||
|         if feed_forward_proj == "gated-gelu": | ||||
|             self.dense_act_fn = "gelu_new" | ||||
| 
 | ||||
|         super().__init__( | ||||
|             pad_token_id=pad_token_id, | ||||
|             eos_token_id=eos_token_id, | ||||
|             is_encoder_decoder=is_encoder_decoder, | ||||
|             **kwargs, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| # class T5OnnxConfig(OnnxSeq2SeqConfigWithPast): | ||||
| #     @property | ||||
| #     def inputs(self) -> Mapping[str, Mapping[int, str]]: | ||||
| #         common_inputs = { | ||||
| #             "input_ids": {0: "batch", 1: "encoder_sequence"}, | ||||
| #             "attention_mask": {0: "batch", 1: "encoder_sequence"}, | ||||
| #         } | ||||
| #         if self.use_past: | ||||
| #             common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" | ||||
| #             common_inputs["decoder_input_ids"] = {0: "batch"} | ||||
| #             common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} | ||||
| #         else: | ||||
| #             common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} | ||||
| #             common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} | ||||
| #  | ||||
| #         if self.use_past: | ||||
| #             self.fill_with_past_key_values_(common_inputs, direction="inputs") | ||||
| #  | ||||
| #         return common_inputs | ||||
| #  | ||||
| #     @property | ||||
| #     def default_onnx_opset(self) -> int: | ||||
| #         return 13 | ||||
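For reference, a minimal usage sketch of the configuration class above. This is not part of the commit, and it assumes the file is importable as t5_model.configuration_t5, the package path referenced by the training scripts later in this diff:

# Hypothetical usage sketch; the import path is an assumption.
from t5_model.configuration_t5 import T5Config

cfg = T5Config(feed_forward_proj="gated-gelu")
print(cfg.dense_act_fn)   # "gelu_new" via the backwards-compatibility branch
print(cfg.is_gated_act)   # True
print(cfg.use_bfloat16)   # True, hardcoded in this modified config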
										
											
(Two file diffs suppressed because they are too large.)

t5_jax.py (175 changed lines):
							|  | @ -25,6 +25,7 @@ import numpy as np | |||
| from functools import partial | ||||
| from typing import Callable, Optional | ||||
| import math | ||||
| import flax.linen as nn | ||||
| 
 | ||||
| # jax.config.update("jax_default_matmul_precision", "tensorfloat32") | ||||
| jax.config.update("jax_default_matmul_precision", "bfloat16") | ||||
|  | @ -83,7 +84,7 @@ os.environ.update({ | |||
|     "NCCL_LL128_BUFFSIZE": "-2", | ||||
|     "NCCL_LL_BUFFSIZE": "-2", | ||||
|     "NCCL_PROTO": "SIMPLE,LL,LL128", | ||||
|     "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.99", | ||||
|     "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.8", | ||||
|     # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" | ||||
|  }) | ||||
| 
 | ||||
|  | @ -103,14 +104,14 @@ except (LookupError, OSError): | |||
| # %% | ||||
| # config options | ||||
| file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval' | ||||
| save_path = 't5_5e_1_pmap' | ||||
| save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/original/' | ||||
| # file_path = 'combined_data' | ||||
| split_datasets = load_from_disk(file_path) | ||||
| training_size = len(split_datasets['train']) | ||||
| # Store some constant | ||||
| seed = 117 | ||||
| num_epochs = 5 | ||||
| batch_size = 32  # 384 is the best | ||||
| num_epochs = 40 | ||||
| batch_size = 64  # 384 is the best | ||||
| num_train_epochs = num_epochs | ||||
| per_device_train_batch_size = batch_size | ||||
| train_batch_size = per_device_train_batch_size * jax.device_count() | ||||
|  | @ -120,7 +121,7 @@ steps_per_epoch = training_size // train_batch_size | |||
| total_train_steps = steps_per_epoch * num_epochs | ||||
| 
 | ||||
| warmup_steps = 0 | ||||
| learning_rate = 2e-5 | ||||
| learning_rate = 2e-4 | ||||
| 
 | ||||
| weight_decay = 0.01 | ||||
| adam_beta1 = 0.9 | ||||
|  | @ -129,7 +130,7 @@ adam_epsilon = 1e-8 | |||
| label_smoothing_factor = 0.0 | ||||
| 
 | ||||
| num_beams = 1 | ||||
| val_max_target_length = 86 | ||||
| val_max_target_length = 128 | ||||
| 
 | ||||
| predict_with_generate = True | ||||
| 
 | ||||
|  | @ -143,7 +144,7 @@ additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", | |||
| # Add the additional special tokens to the tokenizer | ||||
| tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens}) | ||||
| 
 | ||||
| max_length = 86 | ||||
| max_length = 128 | ||||
| 
 | ||||
| # %% | ||||
| len(tokenizer) | ||||
|  | @ -176,51 +177,64 @@ from transformers import FlaxT5ForConditionalGeneration | |||
| from transformers import T5Config | ||||
| 
 | ||||
| config = T5Config() | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # If you don't want to cast certain parameters (for example layer norm bias and scale) | ||||
| # then pass the mask as follows | ||||
| from flax import traverse_util | ||||
| 
 | ||||
| model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base") | ||||
| # useful for transformer model | ||||
| # model.enable_gradient_checkpointing() | ||||
| model = FlaxT5ForConditionalGeneration.from_pretrained( | ||||
|     "t5-base", | ||||
|     dtype=jnp.bfloat16, | ||||
|     gradient_checkpointing=True | ||||
| ) | ||||
| params = model.params | ||||
| 
 | ||||
| # enable bf16 except for layer_norm | ||||
| # flat_params = traverse_util.flatten_dict(model.params) | ||||
| # mask = { | ||||
| #     path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params | ||||
| # } | ||||
| # mask = traverse_util.unflatten_dict(mask) | ||||
| # # borrowed from transformers modeling_flax_utils | ||||
| # def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: | ||||
| #     """ | ||||
| #     Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. | ||||
| #     """ | ||||
| #  | ||||
| #     # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 | ||||
| #     def conditional_cast(param): | ||||
| #         if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): | ||||
| #             param = param.astype(dtype) | ||||
| #         return param | ||||
| #  | ||||
| #     if mask is None: | ||||
| #         return jax.tree_util.tree_map(conditional_cast, params) | ||||
| #  | ||||
| #     flat_params = traverse_util.flatten_dict(params) | ||||
| #     flat_mask, _ = jax.tree_util.tree_flatten(mask) | ||||
| #  | ||||
| #     for masked, key in zip(flat_mask, sorted(flat_params.keys())): | ||||
| #         if masked: | ||||
| #             flat_params[key] = conditional_cast(flat_params[key]) | ||||
| #      | ||||
| #     return traverse_util.unflatten_dict(flat_params) | ||||
| #  | ||||
| # # Cast parameters to bfloat16 if desired | ||||
| # # params = jax.tree.tree_map(lambda x: x.astype(jnp.bfloat16), params) | ||||
| # # instead of casting the whole thing, we cast only certain parts of the tree | ||||
| # params = cast_floating_to(model.params, jnp.bfloat16, mask) | ||||
| # enable bf16 only for the dense (wi/wo), attention q/k/v and shared-embedding weights; | ||||
| # layer norms and the attention output projection stay in fp32 | ||||
| def create_mask_for_layer_norm(params): | ||||
|     flat_params = traverse_util.flatten_dict(params) | ||||
|     mask = { | ||||
|         # path: not ( | ||||
|         #     (path[-2] == "layer_norm" and path[-1] == "weight") or  | ||||
|         #     (path[-2] == "final_layer_norm" and path[-1] == "weight") or  | ||||
|         #     (path[-2] == "o" and path[-1] == "kernel") | ||||
|         # )  | ||||
|         # for path in flat_params | ||||
|         path: ( | ||||
|             (path[-2] == "wi" and path[-1] == "weight") or | ||||
|             (path[-2] == "wo" and path[-1] == "weight") or | ||||
|             (path[-2] == "k" and path[-1] == "kernel") or | ||||
|             (path[-2] == "q" and path[-1] == "kernel") or | ||||
|             (path[-2] == "v" and path[-1] == "kernel") or | ||||
|             (path[-2] == "shared" and path[-1] == "embedding") | ||||
|         ) for path in flat_params | ||||
|     } | ||||
|     mask = traverse_util.unflatten_dict(mask) | ||||
|     return mask | ||||
| 
 | ||||
| # borrowed from transformers modeling_flax_utils | ||||
| def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: | ||||
|     """ | ||||
|     Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. | ||||
|     """ | ||||
| 
 | ||||
|     # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 | ||||
|     def conditional_cast(param): | ||||
|         if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): | ||||
|             param = param.astype(dtype) | ||||
|         return param | ||||
| 
 | ||||
|     if mask is None: | ||||
|         return jax.tree_util.tree_map(conditional_cast, params) | ||||
| 
 | ||||
|     flat_params = traverse_util.flatten_dict(params) | ||||
|     flat_mask, _ = jax.tree_util.tree_flatten(mask) | ||||
| 
 | ||||
|     for masked, key in zip(flat_mask, sorted(flat_params.keys())): | ||||
|         if masked: | ||||
|             flat_params[key] = conditional_cast(flat_params[key]) | ||||
|          | ||||
|     return traverse_util.unflatten_dict(flat_params) | ||||
| 
 | ||||
| mask = create_mask_for_layer_norm(params) | ||||
| # override params with bfloat version | ||||
| params = cast_floating_to(params, jnp.bfloat16, mask) | ||||
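A standalone sketch (not part of the commit) of what the mask plus cast_floating_to above do on a tiny parameter tree: masked leaves are cast to bfloat16, everything else (here a layer norm scale) stays in float32:

# Illustrative only; reuses cast_floating_to defined above on a toy pytree.
toy_params = {
    "block": {
        "wi": {"weight": jnp.ones((4, 8), dtype=jnp.float32)},
        "layer_norm": {"weight": jnp.ones((4,), dtype=jnp.float32)},
    }
}
toy_flat = traverse_util.flatten_dict(toy_params)
toy_mask = traverse_util.unflatten_dict({path: path[-2] == "wi" for path in toy_flat})
toy_half = cast_floating_to(toy_params, jnp.bfloat16, toy_mask)
print(toy_half["block"]["wi"]["weight"].dtype)          # bfloat16
print(toy_half["block"]["layer_norm"]["weight"].dtype)  # float32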
| 
 | ||||
| 
 | ||||
| # %% | ||||
|  | @ -307,31 +321,6 @@ token_datasets.set_format( | |||
|     'labels', 'decoder_input_ids',  | ||||
|     'decoder_attention_mask'] | ||||
| ) | ||||
| # %% | ||||
| # check values | ||||
| for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']: | ||||
|     int_array = train_dataset[name] | ||||
|     if np.all((int_array >= 0) & (int_array <= 65535)): | ||||
|         uint16_array = int_array.astype(np.uint16) | ||||
|     else: | ||||
|         raise ValueError("Values are out of range for uint16") | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
| from datasets import ClassLabel, Value, Sequence | ||||
| features = train_dataset.features.copy() | ||||
| features['input_ids'] = Sequence(Value('uint16')) | ||||
| features['attention_mask'] = Sequence(Value('bool')) | ||||
| features['labels'] = Sequence(Value('uint16')) | ||||
| features['decoder_input_ids'] = Sequence(Value('uint16')) | ||||
| features['decoder_attention_mask'] = Sequence(Value('bool')) | ||||
| train_dataset = train_dataset.cast(features) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # temp | ||||
| print('data type check: ', train_dataset['decoder_attention_mask'].dtype) | ||||
| 
 | ||||
| # %% | ||||
| def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True): | ||||
|  | @ -355,17 +344,11 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf | |||
| 
 | ||||
|     for idx in batch_idx: | ||||
|         batch = dataset[idx] | ||||
|         batch = {k: jnp.array(v) for k, v in batch.items()} | ||||
|         batch = {k: v for k, v in batch.items()} | ||||
| 
 | ||||
|         yield batch | ||||
| 
 | ||||
| 
 | ||||
| # %% [markdown] | ||||
| # # Model | ||||
| # | ||||
| # | ||||
| # | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
| # Initialize our training | ||||
|  | @ -406,7 +389,7 @@ linear_decay_lr_schedule_fn = create_learning_rate_fn( | |||
| def decay_mask_fn(params): | ||||
|     flat_params = traverse_util.flatten_dict(params) | ||||
|     # find out all LayerNorm parameters | ||||
|     layer_norm_candidates = ["layernorm", "layer_norm", "ln"] | ||||
|     layer_norm_candidates = ["final_layer_norm", "layer_norm"] | ||||
|     layer_norm_named_params = { | ||||
|         layer[-2:] | ||||
|         for layer_norm_name in layer_norm_candidates | ||||
|  | @ -437,13 +420,10 @@ class TrainState(train_state.TrainState): | |||
|     def replicate(self): | ||||
|         return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) | ||||
| 
 | ||||
| # set bf16 for model params | ||||
| # model.params = model.to_bf16(model.params) | ||||
| # Cast parameters to bfloat16 if desired | ||||
| # params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) | ||||
| 
 | ||||
| # Setup train state | ||||
| state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng) | ||||
| # build the train state from the bf16-cast params defined above | ||||
| state = TrainState.create(apply_fn=model.__call__, params=params, tx=adamw, dropout_rng=dropout_rng) | ||||
| 
 | ||||
| # label smoothed cross entropy | ||||
| def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): | ||||
|  | @ -485,17 +465,17 @@ def train_step(state, batch, label_smoothing_factor=0.0): | |||
|     num_labels = jax.lax.psum(num_labels, "batch") | ||||
| 
 | ||||
|     # true loss = total loss / total samples | ||||
|     # loss = jax.lax.psum(loss, "batch") | ||||
|     # loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) | ||||
|     loss = jax.lax.psum(loss, "batch") | ||||
|     loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) | ||||
| 
 | ||||
|     # true grad = total grad / total samples | ||||
|     grad = jax.lax.psum(grad, "batch") | ||||
|     grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) | ||||
|     new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) | ||||
| 
 | ||||
|     # metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} | ||||
|     # return new_state, metrics | ||||
|     return new_state | ||||
|     metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} | ||||
|     return new_state, metrics | ||||
|     # return new_state | ||||
| 
 | ||||
| # Define generation function | ||||
| max_length = ( | ||||
|  | @ -549,25 +529,24 @@ for epoch in epochs: | |||
|     for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): | ||||
|         batch = next(train_loader) | ||||
|         batch = shard(batch) | ||||
|         state = p_train_step(state, batch) | ||||
|         state, train_metric = p_train_step(state, batch) | ||||
|         # train_metrics.append(train_metric) | ||||
| 
 | ||||
|     train_time = time.time() - train_start | ||||
| 
 | ||||
|     # train_metric = unreplicate(train_metric) | ||||
|     # train_metric['loss'].block_until_ready() | ||||
|     train_metric = unreplicate(train_metric) | ||||
|     train_metric['loss'].block_until_ready() | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     epochs.write( | ||||
|         # f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, " | ||||
|         f"Epoch... ({epoch + 1}/{num_epochs} |  " | ||||
|         # f"Learning Rate:{train_metric['learning_rate']}, " | ||||
|         f"Learning Rate:{train_metric['learning_rate']}, " | ||||
|         f"Last train time: {train_time})" | ||||
|     ) | ||||
| # jax.profiler.stop_trace() | ||||
| # %% | ||||
| 
 | ||||
| output_dir = save_path | ||||
| # save checkpoint after each epoch and push checkpoint to the hub | ||||
| if jax.process_index() == 0: | ||||
|  |  | |||
|  | @ -66,7 +66,8 @@ import orbax.checkpoint as ocp | |||
| 
 | ||||
| # data_path = f"../make_data/select_db/data_mapping_filtered.csv" | ||||
| # data_path = f"../make_data_2/select_db/dataset/1/train_all.csv" | ||||
| data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv' | ||||
| model_name_or_path = "./model_checkpoints/simple"  # Replace with your specific model name | ||||
| data_path = '/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv' | ||||
| # data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv' | ||||
| 
 | ||||
| # Ensure to include 'ships_idx' in the fields list | ||||
|  | @ -97,7 +98,6 @@ test_dataset = Dataset.from_list(process_df(df)) | |||
| # from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration | ||||
| from transformers import FlaxT5ForConditionalGeneration | ||||
| # model_name_or_path = "./t5_80_1"  # Replace with your specific model name | ||||
| model_name_or_path = "./model_checkpoints/simple_test"  # Replace with your specific model name | ||||
| model = FlaxT5ForConditionalGeneration.from_pretrained(model_name_or_path) | ||||
| params = model.params | ||||
| 
 | ||||
|  | @ -275,11 +275,12 @@ df['p_property_correct'] = df['p_property'] == df['property'] | |||
| print("thing accuracy", sum(df['p_thing_correct'])/len(df)) | ||||
| print("property accuracy", sum(df['p_property_correct'])/len(df)) | ||||
| print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct'])/len(df)) | ||||
| # %% | ||||
| df[~df["p_property_correct"]] | ||||
| 
 | ||||
| # %% | ||||
| df['p_thing'] | ||||
| # df[~df["p_property_correct"]] | ||||
| 
 | ||||
| # %% | ||||
| # df['p_thing'] | ||||
| # %% | ||||
| # Save the DataFrame as a Parquet file (using pyarrow or fastparquet) | ||||
| # df.to_parquet("exports/output_file.parquet", engine="pyarrow")  # or use engine="fastparquet" | ||||
|  |  | |||
|  | @ -0,0 +1,610 @@ | |||
| # %% | ||||
| import os | ||||
| # Set this to True to run the model on CPU only. | ||||
| USE_CPU_ONLY = False | ||||
| 
 | ||||
| flags = os.environ.get("XLA_FLAGS", "") | ||||
| if USE_CPU_ONLY: | ||||
|     flags += " --xla_force_host_platform_device_count=4"  # Simulate 4 devices | ||||
|     # Enforce CPU-only execution | ||||
|     os.environ["CUDA_VISIBLE_DEVICES"] = "" | ||||
|     os.environ["JAX_PLATFORMS"] = "cpu" | ||||
| else: | ||||
|     # GPU flags | ||||
|     flags = ( | ||||
|         '--xla_gpu_enable_triton_softmax_fusion=true ' | ||||
|         '--xla_gpu_triton_gemm_any=True ' | ||||
|         # '--xla_gpu_enable_async_collectives=true ' | ||||
|         '--xla_gpu_enable_latency_hiding_scheduler=true ' | ||||
|         '--xla_gpu_enable_highest_priority_async_stream=true ' | ||||
|     ) | ||||
|     os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" | ||||
| 
 | ||||
| os.environ["XLA_FLAGS"] = flags | ||||
| os.environ.update({ | ||||
|     "TOKENIZERS_PARALLELISM" : "false", | ||||
|     "CUDA_DEVICE_MAX_CONNECTIONS" : "1", | ||||
|     "NCCL_LL128_BUFFSIZE": "-2", | ||||
|     "NCCL_LL_BUFFSIZE": "-2", | ||||
|     "NCCL_PROTO": "SIMPLE,LL,LL128", | ||||
|     "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90", | ||||
|     # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" | ||||
|  }) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| import functools | ||||
| from functools import partial | ||||
| from pprint import pprint | ||||
| from typing import Any, Dict, Tuple, Callable, Sequence, Union | ||||
| 
 | ||||
| import flax.linen as nn | ||||
| import jax | ||||
| import jax.numpy as jnp | ||||
| import numpy as np | ||||
| from jax.experimental.shard_map import shard_map | ||||
| from jax.sharding import Mesh, NamedSharding | ||||
| # from jax.experimental.pjit import pjit  # superseded by jax.jit | ||||
| from jax.experimental import mesh_utils | ||||
| from jax.sharding import PartitionSpec | ||||
| from ml_collections import ConfigDict | ||||
| import optax | ||||
| import logging | ||||
| import time | ||||
| from datasets import Dataset, load_from_disk | ||||
| 
 | ||||
| from flax import jax_utils, traverse_util | ||||
| from flax.training import train_state | ||||
| from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key | ||||
| from flax.core.frozen_dict import freeze, unfreeze, FrozenDict | ||||
| import flax.core | ||||
| 
 | ||||
| # model checkpointing and saving utilities | ||||
| from flax import linen as nn | ||||
| from flax.training import checkpoints, train_state | ||||
| from flax import struct, serialization | ||||
| 
 | ||||
| from parallel.partitions import set_partitions | ||||
| 
 | ||||
| from tqdm import tqdm | ||||
| 
 | ||||
| from parallel.dataload import DataPrepare | ||||
| 
 | ||||
| # for memory tracking | ||||
| # from jax_smi import initialise_tracking | ||||
| # initialise_tracking() | ||||
| 
 | ||||
| 
 | ||||
| PyTree = Any | ||||
| Metrics = Dict[str, Tuple[jax.Array, ...]] | ||||
| 
 | ||||
| if USE_CPU_ONLY: | ||||
|     jax.config.update('jax_platform_name', 'cpu') | ||||
| else: | ||||
|     jax.config.update("jax_default_matmul_precision", "bfloat16") | ||||
| 
 | ||||
| jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") | ||||
| jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) | ||||
| jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| ## get platform type | ||||
| from jax.extend.backend import get_backend | ||||
| print(get_backend().platform) | ||||
| print(jax.devices()) | ||||
|   | ||||
| # %% | ||||
| # config options | ||||
| file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/' | ||||
| save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple_test/' | ||||
| # file_path = 'combined_data' | ||||
| split_datasets = load_from_disk(file_path) | ||||
| training_size = len(split_datasets['train']) | ||||
| # Store some constant | ||||
| seed = 117 | ||||
| num_epochs = 5 | ||||
| batch_size = 64 | ||||
| num_train_epochs = num_epochs | ||||
| per_device_train_batch_size = batch_size | ||||
| train_batch_size = per_device_train_batch_size * jax.device_count() | ||||
| per_device_eval_batch_size = batch_size | ||||
| eval_batch_size = per_device_eval_batch_size * jax.device_count() | ||||
| steps_per_epoch = training_size // train_batch_size | ||||
| total_train_steps = steps_per_epoch * num_epochs | ||||
| 
 | ||||
| warmup_steps = 0 | ||||
| learning_rate = 5e-5 | ||||
| 
 | ||||
| weight_decay = 0.01 | ||||
| adam_beta1 = 0.9 | ||||
| adam_beta2 = 0.999 | ||||
| adam_epsilon = 1e-8 | ||||
| label_smoothing_factor = 0.0 | ||||
| num_beams = 1 | ||||
| val_max_target_length = 128 | ||||
| predict_with_generate = True | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # prepare data | ||||
| # init object | ||||
| # e.g. Config | ||||
| print("preparing data") | ||||
| data_config = ConfigDict( | ||||
|     dict( | ||||
|         max_length=128, | ||||
|         pad_token_id=0, | ||||
|         decoder_start_token_id=0 | ||||
|     ) | ||||
| ) | ||||
| 
 | ||||
| dataprep = DataPrepare(split_datasets['train'], data_config) | ||||
| # # example usage | ||||
| # # %% | ||||
| seed = 117 | ||||
| rng = jax.random.PRNGKey(seed) | ||||
| train_loader = dataprep.data_loader(rng, batch_size=batch_size) | ||||
| batch = next(iter(train_loader)) | ||||
| 
 | ||||
| # %% | ||||
| # from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom | ||||
| from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model | ||||
| main_model = custom_model.from_pretrained( | ||||
|     "t5-base", | ||||
|     dtype=jnp.bfloat16, | ||||
|     gradient_checkpointing=True,  | ||||
| ) | ||||
| params = main_model.params | ||||
| # pretrained_params = model.params | ||||
| model = main_model.module | ||||
| 
 | ||||
| # %% | ||||
| # # testing config hashability | ||||
| # # some explanation: | ||||
| # # The PreTrainedModel class loads a T5Config model that is not hashable because | ||||
| # # it is a complicated class that pretends to be a dataclass. | ||||
| # # The solution is to extract a dict from it, then make a ConfigDict from | ||||
| # # ml_collections library so that we can get values via the "." operator. | ||||
| # # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to | ||||
| # # modify the config before passing to the next layer | ||||
| # from transformers import T5Config | ||||
| # from t5_model.configuration_t5 import FrozenT5Config | ||||
| # from ml_collections import ConfigDict, FrozenConfigDict | ||||
| #  | ||||
| # config = T5Config.from_pretrained("t5-base").to_dict() | ||||
| # config.pop('architectures') | ||||
| # config.pop('id2label') | ||||
| # # test if it works  | ||||
| # frozen_config = FrozenConfigDict(config) | ||||
| # # test hash | ||||
| # hash(frozen_config) | ||||
| 
 | ||||
| # %% | ||||
| # # print model | ||||
| # rng, input_rng = jax.random.split(rng) | ||||
| # model.tabulate( | ||||
| #     input_rng, | ||||
| #     input_ids=batch['input_ids'], | ||||
| #     attention_mask=batch['attention_mask'], | ||||
| #     decoder_input_ids=batch['decoder_input_ids'], | ||||
| #     decoder_attention_mask=batch['decoder_attention_mask'], | ||||
| #     console_kwargs={"force_jupyter": True} | ||||
| # ) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # create mesh | ||||
| print("creating mesh") | ||||
| device_mesh = mesh_utils.create_device_mesh((2,2)) | ||||
| print(device_mesh) | ||||
| 
 | ||||
| mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) | ||||
| print(mesh) | ||||
| 
 | ||||
| def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: | ||||
|     return NamedSharding(mesh, pspec, memory_kind="device") | ||||
| 
 | ||||
| x_sharding = mesh_sharding(PartitionSpec('data', None))  # shard the batch dim across the 'data' axis, replicate the rest | ||||
| model_sharding=mesh_sharding(PartitionSpec(None, 'model')) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # optimizers | ||||
| 
 | ||||
| def create_learning_rate_fn( | ||||
|     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float | ||||
| ) -> Callable[[int], jnp.ndarray]: | ||||
|     """Returns a linear warmup, linear_decay learning rate function.""" | ||||
|     steps_per_epoch = train_ds_size // train_batch_size | ||||
|     num_train_steps = steps_per_epoch * num_train_epochs | ||||
|     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) | ||||
|     decay_fn = optax.linear_schedule( | ||||
|         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps | ||||
|     ) | ||||
|     schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) | ||||
|     return schedule_fn | ||||
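A quick sketch (not in the commit, with made-up sizes) of how the joined schedule behaves: it warms up linearly to the peak rate and then decays linearly to zero:

# Hypothetical numbers, for illustration only.
sched = create_learning_rate_fn(
    train_ds_size=1000, train_batch_size=8, num_train_epochs=2,
    num_warmup_steps=10, learning_rate=1e-3,
)
for step in (0, 10, 130, 250):
    print(step, float(sched(step)))  # 0.0 at step 0, 1e-3 at step 10, then decaying to 0.0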
| 
 | ||||
| 
 | ||||
| # Create learning rate schedule | ||||
| linear_decay_lr_schedule_fn = create_learning_rate_fn( | ||||
|     training_size, | ||||
|     train_batch_size, | ||||
|     num_train_epochs, | ||||
|     warmup_steps, | ||||
|     learning_rate, | ||||
| ) | ||||
| 
 | ||||
| # We use Optax's "masking" functionality to not apply weight decay | ||||
| # to bias and LayerNorm scale parameters. decay_mask_fn returns a | ||||
| # mask boolean with the same structure as the parameters. | ||||
| # The mask is True for parameters that should be decayed. | ||||
| def decay_mask_fn(params): | ||||
|     flat_params = traverse_util.flatten_dict(params) | ||||
|     # find out all LayerNorm parameters | ||||
|     layer_norm_candidates = ["layernorm", "layer_norm", "ln"] | ||||
|     layer_norm_named_params = { | ||||
|         layer[-2:] | ||||
|         for layer_norm_name in layer_norm_candidates | ||||
|         for layer in flat_params.keys() | ||||
|         if layer_norm_name in "".join(layer).lower() | ||||
|     } | ||||
|     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} | ||||
|     return traverse_util.unflatten_dict(flat_mask) | ||||
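A toy check (not part of the commit) of the mask this produces: kernels get weight decay, biases and layer-norm scales do not:

# Illustrative pytree; decay_mask_fn is the function defined just above.
toy = {
    "dense": {"kernel": jnp.zeros((2, 2)), "bias": jnp.zeros((2,))},
    "final_layer_norm": {"weight": jnp.zeros((2,))},
}
print(decay_mask_fn(toy))
# {'dense': {'kernel': True, 'bias': False}, 'final_layer_norm': {'weight': False}}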
| 
 | ||||
| # create adam optimizer | ||||
| adamw = optax.adamw( | ||||
|     learning_rate=linear_decay_lr_schedule_fn, | ||||
|     b1=adam_beta1, | ||||
|     b2=adam_beta2, | ||||
|     eps=adam_epsilon, | ||||
|     weight_decay=weight_decay, | ||||
|     mask=decay_mask_fn, | ||||
| ) | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
| print("compile") | ||||
| 
 | ||||
| # enable bf16 only for the dense (wi/wo), attention q/k/v and shared-embedding weights; | ||||
| # layer norms and the attention output projection stay in fp32 | ||||
| def create_mask_for_layer_norm(params): | ||||
|     flat_params = traverse_util.flatten_dict(params) | ||||
|     mask = { | ||||
|         # path: not ( | ||||
|         #     (path[-2] == "layer_norm" and path[-1] == "weight") or  | ||||
|         #     (path[-2] == "final_layer_norm" and path[-1] == "weight") or  | ||||
|         #     (path[-2] == "o" and path[-1] == "kernel") | ||||
|         # )  | ||||
|         # for path in flat_params | ||||
|         path: ( | ||||
|             (path[-2] == "wi" and path[-1] == "weight") or | ||||
|             (path[-2] == "wo" and path[-1] == "weight") or | ||||
|             (path[-2] == "k" and path[-1] == "kernel") or | ||||
|             (path[-2] == "q" and path[-1] == "kernel") or | ||||
|             (path[-2] == "v" and path[-1] == "kernel") or | ||||
|             (path[-2] == "shared" and path[-1] == "embedding") | ||||
|         ) for path in flat_params | ||||
|     } | ||||
|     mask = traverse_util.unflatten_dict(mask) | ||||
|     return mask | ||||
| 
 | ||||
| # borrowed from transformers modeling_flax_utils | ||||
| def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: | ||||
|     """ | ||||
|     Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. | ||||
|     """ | ||||
| 
 | ||||
|     # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 | ||||
|     def conditional_cast(param): | ||||
|         if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): | ||||
|             param = param.astype(dtype) | ||||
|         return param | ||||
| 
 | ||||
|     if mask is None: | ||||
|         return jax.tree_util.tree_map(conditional_cast, params) | ||||
| 
 | ||||
|     flat_params = traverse_util.flatten_dict(params) | ||||
|     flat_mask, _ = jax.tree_util.tree_flatten(mask) | ||||
| 
 | ||||
|     for masked, key in zip(flat_mask, sorted(flat_params.keys())): | ||||
|         if masked: | ||||
|             flat_params[key] = conditional_cast(flat_params[key]) | ||||
|          | ||||
|     return traverse_util.unflatten_dict(flat_params) | ||||
| 
 | ||||
| # create init_fn to produce sharded state | ||||
| def init_fn(params, model, optimizer): | ||||
|     # do be careful with the model init | ||||
|     # imported models might have complicated init methods | ||||
| 
 | ||||
|     # mask = create_mask_for_layer_norm(params) | ||||
|     # override params with bfloat version | ||||
|     # params= cast_floating_to(params, jnp.bfloat16, mask) | ||||
| 
 | ||||
|     state = train_state.TrainState.create(  # Create a `TrainState`. | ||||
|         apply_fn=model.apply, | ||||
|         params=params, | ||||
|         tx=optimizer) | ||||
|     return state | ||||
| 
 | ||||
| 
 | ||||
| abstract_variables = jax.eval_shape( | ||||
|     functools.partial(init_fn, model=model, optimizer=adamw), params) | ||||
| 
 | ||||
| # jax.sharding: describes how a jax.Array is laid out across devices | ||||
| state_sharding = nn.get_sharding(abstract_variables, mesh) | ||||
| # print(state_sharding) | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
| # replace the params tree with the new modified tree | ||||
| # create partitions for model | ||||
| from parallel.partitions import set_partitions | ||||
| # set_partitions freezes the params on return | ||||
| model_part_spec = set_partitions(unfreeze(params)) | ||||
| # p is already a partition spec | ||||
| model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec) | ||||
| 
 | ||||
| # get pspec for opt_state | ||||
| def get_opt_spec(x): | ||||
|     if isinstance(x, dict): | ||||
|         return unfreeze(model_named_sharding) | ||||
|     # return an empty partspec | ||||
|     return mesh_sharding((PartitionSpec())) | ||||
| 
 | ||||
| # this function replaces the empty model params spec with the 'model_named_shard' | ||||
| state_sharding = jax.tree.map( | ||||
|     get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)) | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| jit_init_fn = jax.jit( | ||||
|     init_fn, | ||||
|     static_argnames=('model', 'optimizer'),  # skip model and optimizer | ||||
|     in_shardings=mesh_sharding(PartitionSpec()),  # we don't shard params explicitly | ||||
|     out_shardings=state_sharding  # but returned initialized_state is sharded | ||||
| ) | ||||
| initialized_state = jit_init_fn(params, model, adamw) | ||||
| 
 | ||||
| # %% | ||||
| # train step | ||||
| def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): | ||||
|     """ | ||||
|     The label smoothing implementation is adapted from Flax's official example: | ||||
|     https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 | ||||
|     """ | ||||
|     vocab_size = logits.shape[-1] | ||||
|     confidence = 1.0 - label_smoothing_factor | ||||
|     low_confidence = (1.0 - confidence) / (vocab_size - 1) | ||||
|     normalizing_constant = -( | ||||
|         confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) | ||||
|     ) | ||||
|     soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) | ||||
|     logits = jnp.asarray(logits, dtype=jnp.float32) | ||||
|     soft_labels = soft_labels.astype(jnp.float32) | ||||
|     loss = optax.softmax_cross_entropy(logits, soft_labels) | ||||
|     loss = loss - normalizing_constant | ||||
| 
 | ||||
|     # ignore padded tokens from loss | ||||
|     loss = loss * padding_mask | ||||
|     loss = loss.mean() | ||||
|     # num_labels = padding_mask.mean() | ||||
|     return loss # , num_labels | ||||
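A shape/value sanity sketch (not in the commit): with all-zero logits and no label smoothing the per-token loss is log(vocab_size), and padded positions are zeroed before the mean:

# Dummy tensors, illustration only; loss_fn is the function defined above.
toy_logits = jnp.zeros((2, 3, 32128))             # (batch, seq, vocab)
toy_labels = jnp.array([[5, 7, 0], [1, 2, 0]])
toy_mask = jnp.array([[1, 1, 0], [1, 1, 0]])      # last position treated as padding
print(loss_fn(toy_logits, toy_labels, toy_mask))  # ~= log(32128) * 4 / 6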
| 
 | ||||
| # %% | ||||
| # gradient accumulation | ||||
| def accumulate_gradients_loop( | ||||
|     state, | ||||
|     batch, | ||||
|     minibatch_size: int, | ||||
|     loss_fn: Callable, | ||||
| ) -> Tuple[PyTree, Metrics]: | ||||
|     """Calculate gradients and metrics for a batch using gradient accumulation. | ||||
| 
 | ||||
|     Args: | ||||
|         state: Current training state. | ||||
|         batch: Full training batch. | ||||
|         minibatch_size: Size of each minibatch; the number of accumulation steps is batch_size // minibatch_size. | ||||
|         loss_fn: Loss function used to compute the per-minibatch loss and gradients. | ||||
| 
 | ||||
|     Returns: | ||||
|         Tuple with minibatch-averaged gradients and metrics. | ||||
|     """ | ||||
|     batch_size = batch['input_ids'].shape[0] | ||||
|     # minibatch_size = batch_size // num_minibatches | ||||
|     num_minibatches = batch_size // minibatch_size | ||||
|     # Define gradient function for single minibatch. | ||||
|     # If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned. | ||||
|     # otherwise it returns (value, gradient), where value is the function's | ||||
|     # plain output (hence the name value_and_grad) | ||||
|     grad_fn = jax.value_and_grad(loss_fn, has_aux=False) | ||||
|     # Prepare loop variables. | ||||
|     grads = None | ||||
|     metrics = None | ||||
|     for minibatch_idx in range(num_minibatches): | ||||
|         with jax.named_scope(f"minibatch_{minibatch_idx}"): | ||||
|             # Split the batch into minibatches. | ||||
|             start = minibatch_idx * minibatch_size | ||||
|             end = start + minibatch_size | ||||
|             minibatch = jax.tree.map(lambda x: x[start:end], batch)  # noqa: B023 | ||||
|             # Calculate gradients and metrics for the minibatch. | ||||
|             # the returned value is the mean loss over this minibatch | ||||
|             loss, step_grads = grad_fn( | ||||
|                 state.params, minibatch | ||||
|             ) | ||||
|             with jax.named_scope("sync_metrics"): | ||||
|                 step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} | ||||
| 
 | ||||
|             # Accumulate gradients and metrics across minibatches. | ||||
|             if grads is None: | ||||
|                 grads = step_grads | ||||
|                 metrics = step_metrics | ||||
|             else: | ||||
|                 # accumulation adder | ||||
|                 grads = jax.tree.map(jnp.add, grads, step_grads) | ||||
|                 metrics = jax.tree.map(jnp.add, metrics, step_metrics) | ||||
|     # Average gradients (and metrics) over minibatches. | ||||
|     grads = jax.tree.map(lambda g: g / num_minibatches, grads) | ||||
|     metrics = jax.tree.map(lambda m: m / num_minibatches, metrics) | ||||
|     return grads, metrics | ||||
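A tiny standalone check (not part of the commit) of the idea behind accumulate_gradients_loop: for a mean-style loss and equal-sized minibatches, averaging per-minibatch gradients reproduces the full-batch gradient (with the padding-masked mean used above this holds only approximately):

# Self-contained toy example with a scalar parameter.
def toy_loss(w, x):
    return jnp.mean((x * w) ** 2)

w0, data = 2.0, jnp.arange(8.0)
full_grad = jax.grad(toy_loss)(w0, data)
half_grads = [jax.grad(toy_loss)(w0, data[:4]), jax.grad(toy_loss)(w0, data[4:])]
print(full_grad, sum(half_grads) / 2)  # both 70.0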
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # single device code annotated with jax.jit | ||||
| @functools.partial( | ||||
|     jax.jit, | ||||
|     # state is state_sharding initialized from init_fn | ||||
|     # x_sharding is data sharded explicitly later | ||||
|     in_shardings=(state_sharding, x_sharding), | ||||
|     out_shardings=(state_sharding, mesh_sharding(PartitionSpec())), | ||||
|     donate_argnames=('state'), | ||||
| ) | ||||
| def train_step(state, batch): | ||||
| 
 | ||||
|     label_smoothing_factor=0.0 | ||||
|     # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) | ||||
|     def compute_loss(params, batch): | ||||
|         # check constraints | ||||
|         # frozen dict not allowed as sharding object | ||||
|         params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding)) | ||||
|         batch = jax.lax.with_sharding_constraint(batch, x_sharding) | ||||
|         logits = state.apply_fn( | ||||
|             {'params': params}, | ||||
|             input_ids=batch['input_ids'], | ||||
|             attention_mask=batch['attention_mask'], | ||||
|             decoder_input_ids=batch['decoder_input_ids'], | ||||
|             decoder_attention_mask=batch['decoder_attention_mask'], | ||||
|         )[0]  # index 0: the first element of the output structure is the logits | ||||
|         # use labels here | ||||
|         # loss, num_labels = loss_fn( | ||||
|         loss = loss_fn( | ||||
|             logits, | ||||
|             batch["labels"], | ||||
|             batch["decoder_attention_mask"], | ||||
|             label_smoothing_factor) | ||||
|         return loss # , num_labels | ||||
| 
 | ||||
|     # # compute gradients through computational graph | ||||
|     # # allow values to pass through | ||||
|     # grad_fn = jax.value_and_grad(compute_loss, has_aux=False) | ||||
|     # (loss), grad = grad_fn(state.params, batch) | ||||
|     # # num_labels = jax.lax.psum(num_labels, "batch") | ||||
| 
 | ||||
| 
 | ||||
|     # new_state = state.apply_gradients(grads=grad) | ||||
|     # with jax.named_scope("sync_metrics"): | ||||
|     #     step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} | ||||
| 
 | ||||
|     # use gradient accumulation | ||||
|     grads, step_metrics = accumulate_gradients_loop( | ||||
|         state=state, | ||||
|         batch=batch, | ||||
|         minibatch_size=32, | ||||
|         loss_fn=compute_loss | ||||
|     ) | ||||
|     new_state = state.apply_gradients(grads=grads) | ||||
| 
 | ||||
|     return new_state, step_metrics | ||||
| 
 | ||||
| # %% | ||||
| # explore data sharding | ||||
| sharded_batch = next(iter(train_loader)) | ||||
| sharded_batch = jax.device_put(sharded_batch, x_sharding) | ||||
| # jax.debug.visualize_array_sharding(sharded_batch['input_ids']) | ||||
| # jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding']) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # # prep 1 step | ||||
| # print("1 step for jit-ting") | ||||
| # with mesh: | ||||
| #     state, metrics = train_step(initialized_state, sharded_batch) | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
| # %% | ||||
| # training run summary | ||||
| print("***** Running training *****") | ||||
| print(f"  Num examples = {training_size}") | ||||
| print(f"  Num Epochs = {num_epochs}") | ||||
| print(f"  Instantaneous batch size per device = {per_device_train_batch_size}") | ||||
| print(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}") | ||||
| print(f"  Total optimization steps = {total_train_steps}") | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # jax.profiler.start_trace("./traces") | ||||
| 
 | ||||
| 
 | ||||
| print("*" * 10) | ||||
| print("training start") | ||||
| rng, input_rng = jax.random.split(rng) | ||||
| train_time = 0 | ||||
| state = initialized_state | ||||
| epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) | ||||
| for epoch in epochs: | ||||
|     train_start = time.time() | ||||
| 
 | ||||
|     # Create sampling rng | ||||
|     train_metrics = [] | ||||
|     steps_per_epoch = training_size // train_batch_size | ||||
|     train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True) | ||||
|     # Generate an epoch by shuffling sampling indices from the train dataset | ||||
|     for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): | ||||
|         batch = next(train_loader) | ||||
|         batch = jax.device_put(batch, x_sharding) | ||||
|         with mesh: | ||||
|             state, train_metric = train_step(state, batch) | ||||
| 
 | ||||
|         # train_metrics.append(train_metric) | ||||
| 
 | ||||
| 
 | ||||
|     # this is for more accurate time stats, but slows down training | ||||
|     # train_metric['loss'].block_until_ready() | ||||
|     train_time = time.time() - train_start | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     epochs.write( | ||||
|         f"Epoch... ({epoch + 1}/{num_epochs} |  " | ||||
|         f"Loss: {train_metric['loss']}, " | ||||
|         f"Learning Rate:{train_metric['learning_rate']}, " | ||||
|         f"Last train time: {train_time})" | ||||
|     ) | ||||
| # jax.profiler.stop_trace() | ||||
| 
 | ||||
| # %% | ||||
| # try out | ||||
| gather_state = jax.device_get(state) | ||||
| gather_batch = jax.device_get(batch) | ||||
| logits = gather_state.apply_fn( | ||||
|     {'params': gather_state.params}, | ||||
|     input_ids=gather_batch['input_ids'], | ||||
|     attention_mask=gather_batch['attention_mask'], | ||||
|     decoder_input_ids=gather_batch['decoder_input_ids'], | ||||
|     decoder_attention_mask=gather_batch['decoder_attention_mask'], | ||||
| )[0]  # index 0: the first element of the output structure is the logits | ||||
| 
 | ||||
| probs = nn.softmax(logits, axis=-1) | ||||
| predicted = jnp.argmax(probs, axis=-1) | ||||
| print("sample output") | ||||
| print(predicted[1]) | ||||
| 
 | ||||
| # %% | ||||
| main_model = custom_model.from_pretrained('t5-base') | ||||
| output_dir = save_path | ||||
| 
 | ||||
| # save checkpoint after each epoch and push checkpoint to the hub | ||||
| if jax.process_index() == 0: | ||||
|     params = jax.device_get(state.params) | ||||
|     main_model.save_pretrained(output_dir, params=params) | ||||
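As a follow-up (not in the commit), the checkpoint written here can be reloaded the same way the prediction script loads its model directory:

# Assumes the training run above has completed and save_path exists on disk.
reloaded = custom_model.from_pretrained(save_path)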
| 
 | ||||
| 
 | ||||
| # %% | ||||
|  | @ -0,0 +1,550 @@ | |||
| # %% | ||||
| import os | ||||
| # Set this to True to run the model on CPU only. | ||||
| USE_CPU_ONLY = False | ||||
| 
 | ||||
| flags = os.environ.get("XLA_FLAGS", "") | ||||
| if USE_CPU_ONLY: | ||||
|     flags += " --xla_force_host_platform_device_count=4"  # Simulate 4 devices | ||||
|     # Enforce CPU-only execution | ||||
|     os.environ["CUDA_VISIBLE_DEVICES"] = "" | ||||
|     os.environ["JAX_PLATFORMS"] = "cpu" | ||||
| else: | ||||
|     # GPU flags | ||||
|     flags = ( | ||||
|         '--xla_gpu_enable_triton_softmax_fusion=true ' | ||||
|         '--xla_gpu_triton_gemm_any=True ' | ||||
|         # '--xla_gpu_enable_async_collectives=true ' | ||||
|         '--xla_gpu_enable_latency_hiding_scheduler=true ' | ||||
|         '--xla_gpu_enable_highest_priority_async_stream=true ' | ||||
|     ) | ||||
|     os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" | ||||
| 
 | ||||
| os.environ["XLA_FLAGS"] = flags | ||||
| os.environ.update({ | ||||
|     "TOKENIZERS_PARALLELISM" : "false", | ||||
|     "CUDA_DEVICE_MAX_CONNECTIONS" : "1", | ||||
|     "NCCL_LL128_BUFFSIZE": "-2", | ||||
|     "NCCL_LL_BUFFSIZE": "-2", | ||||
|     "NCCL_PROTO": "SIMPLE,LL,LL128", | ||||
|     "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.5", | ||||
|     # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" | ||||
|  }) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| import functools | ||||
| from functools import partial | ||||
| from pprint import pprint | ||||
| from typing import Any, Dict, Tuple, Callable, Sequence, Union | ||||
| 
 | ||||
| import flax.linen as nn | ||||
| import jax | ||||
| import jax.numpy as jnp | ||||
| import numpy as np | ||||
| from jax.experimental.shard_map import shard_map | ||||
| from jax.sharding import Mesh, NamedSharding | ||||
| # from jax.experimental.pjit import pjit  # superseded by jax.jit | ||||
| from jax.experimental import mesh_utils | ||||
| from jax.sharding import PartitionSpec | ||||
| from ml_collections import ConfigDict | ||||
| import optax | ||||
| import logging | ||||
| import time | ||||
| from datasets import Dataset, load_from_disk | ||||
| 
 | ||||
| from flax import jax_utils, traverse_util | ||||
| from flax.training import train_state | ||||
| from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key | ||||
| from flax.core.frozen_dict import freeze, unfreeze, FrozenDict | ||||
| import flax.core | ||||
| 
 | ||||
| # model checkpointing and saving utilities | ||||
| from flax import linen as nn | ||||
| from flax.training import checkpoints, train_state | ||||
| from flax import struct, serialization | ||||
| 
 | ||||
| from parallel.partitions import set_partitions | ||||
| 
 | ||||
| from tqdm import tqdm | ||||
| 
 | ||||
| from parallel.dataload import DataPrepare | ||||
| 
 | ||||
| # for memory tracking | ||||
| # from jax_smi import initialise_tracking | ||||
| # initialise_tracking() | ||||
| 
 | ||||
| 
 | ||||
| PyTree = Any | ||||
| Metrics = Dict[str, Tuple[jax.Array, ...]] | ||||
| 
 | ||||
| if USE_CPU_ONLY: | ||||
|     jax.config.update('jax_platform_name', 'cpu') | ||||
| else: | ||||
|     jax.config.update("jax_default_matmul_precision", "bfloat16") | ||||
| 
 | ||||
| jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") | ||||
| jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) | ||||
| jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| ## get platform type | ||||
| from jax.extend.backend import get_backend | ||||
| print(get_backend().platform) | ||||
| print(jax.devices()) | ||||
|   | ||||
| # %% | ||||
| # config options | ||||
| file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/' | ||||
| save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/shmap/' | ||||
| # file_path = 'combined_data' | ||||
| split_datasets = load_from_disk(file_path) | ||||
| training_size = len(split_datasets['train']) | ||||
| # Store some constant | ||||
| seed = 117 | ||||
| num_epochs = 40 | ||||
| batch_size = 32  # do not go beyond 128, 64 is good | ||||
| num_train_epochs = num_epochs | ||||
| per_device_train_batch_size = batch_size | ||||
| train_batch_size = per_device_train_batch_size * jax.device_count() | ||||
| per_device_eval_batch_size = batch_size | ||||
| eval_batch_size = per_device_eval_batch_size * jax.device_count() | ||||
| steps_per_epoch = training_size // train_batch_size | ||||
| total_train_steps = steps_per_epoch * num_epochs | ||||
| 
 | ||||
| warmup_steps = 0 | ||||
| learning_rate = 2e-5 | ||||
| 
 | ||||
| weight_decay = 0.01 | ||||
| adam_beta1 = 0.9 | ||||
| adam_beta2 = 0.999 | ||||
| adam_epsilon = 1e-8 | ||||
| label_smoothing_factor = 0.0 | ||||
| num_beams = 1 | ||||
| val_max_target_length = 128 | ||||
| predict_with_generate = True | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # prepare data | ||||
| # init object | ||||
| # e.g. Config | ||||
| print("preparing data") | ||||
| data_config = ConfigDict( | ||||
|     dict( | ||||
|         max_length=128, | ||||
|         pad_token_id=0, | ||||
|         decoder_start_token_id=0 | ||||
|     ) | ||||
| ) | ||||
| 
 | ||||
| dataprep = DataPrepare(split_datasets['train'], data_config) | ||||
| # # example usage | ||||
| # # %% | ||||
| seed = 117 | ||||
| rng = jax.random.PRNGKey(seed) | ||||
| train_loader = dataprep.data_loader(rng, batch_size=batch_size) | ||||
| batch = next(iter(train_loader)) | ||||
| 
 | ||||
| # %% | ||||
| # from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom | ||||
| from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model | ||||
| main_model = custom_model.from_pretrained( | ||||
|     "t5-base", | ||||
|     dtype=jnp.bfloat16, | ||||
|     gradient_checkpointing=True,  | ||||
| ) | ||||
| params = main_model.params | ||||
| # pretrained_params = model.params | ||||
| model = main_model.module | ||||
| 
 | ||||
| # %% | ||||
| # # testing config hashability | ||||
| # # some explanation: | ||||
| # # The PreTrainedModel class loads a T5Config model that is not hashable because | ||||
| # # it is a complicated class that pretends to be a dataclass. | ||||
| # # The solution is to extract a dict from it, then make a ConfigDict from | ||||
| # # ml_collections library so that we can get values via the "." operator. | ||||
| # # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to | ||||
| # # modify the config before passing to the next layer | ||||
| # from transformers import T5Config | ||||
| # from t5_model.configuration_t5 import FrozenT5Config | ||||
| # from ml_collections import ConfigDict, FrozenConfigDict | ||||
| #  | ||||
| # config = T5Config.from_pretrained("t5-base").to_dict() | ||||
| # config.pop('architectures') | ||||
| # config.pop('id2label') | ||||
| # # test if it works  | ||||
| # frozen_config = FrozenConfigDict(config) | ||||
| # # test hash | ||||
| # hash(frozen_config) | ||||
| 
 | ||||
| # %% | ||||
| # # print model | ||||
| # rng, input_rng = jax.random.split(rng) | ||||
| # model.tabulate( | ||||
| #     input_rng, | ||||
| #     input_ids=batch['input_ids'], | ||||
| #     attention_mask=batch['attention_mask'], | ||||
| #     decoder_input_ids=batch['decoder_input_ids'], | ||||
| #     decoder_attention_mask=batch['decoder_attention_mask'], | ||||
| #     console_kwargs={"force_jupyter": True} | ||||
| # ) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # create mesh | ||||
| print("creating mesh") | ||||
| device_mesh = mesh_utils.create_device_mesh((2,2)) | ||||
| print(device_mesh) | ||||
| 
 | ||||
| mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) | ||||
| print(mesh) | ||||
| 
 | ||||
| def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: | ||||
|     if USE_CPU_ONLY: | ||||
|         return NamedSharding(mesh, pspec, memory_kind="unpinned_host") | ||||
|     else: | ||||
|         # if gpu | ||||
|         return NamedSharding(mesh, pspec, memory_kind="device") | ||||
| 
 | ||||
| x_sharding = mesh_sharding(PartitionSpec('data', None))  # shard the batch dim across the 'data' axis, replicate the rest | ||||
| model_sharding = mesh_sharding(PartitionSpec(None, 'model'))  # shard the second dim across the 'model' axis | ||||
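| # A minimal sketch (assuming the 2x2 'data' x 'model' mesh above) of what x_sharding | ||||
| # does: an array of shape (8, 128) under PartitionSpec('data', None) is split along | ||||
| # dim 0 across the 'data' axis and replicated along the 'model' axis, so each device | ||||
| # holds a (4, 128) slice. The shapes here are illustrative only. | ||||
| # example = jax.device_put(jnp.zeros((8, 128)), x_sharding) | ||||
| # jax.debug.visualize_array_sharding(example) | ||||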
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # optimizers | ||||
| 
 | ||||
| def create_learning_rate_fn( | ||||
|     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float | ||||
| ) -> Callable[[int], jnp.ndarray]: | ||||
|     """Returns a linear warmup, linear_decay learning rate function.""" | ||||
|     steps_per_epoch = train_ds_size // train_batch_size | ||||
|     num_train_steps = steps_per_epoch * num_train_epochs | ||||
|     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) | ||||
|     decay_fn = optax.linear_schedule( | ||||
|         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps | ||||
|     ) | ||||
|     schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) | ||||
|     return schedule_fn | ||||
| 
 | ||||
| 
 | ||||
| # Create learning rate schedule | ||||
| linear_decay_lr_schedule_fn = create_learning_rate_fn( | ||||
|     training_size, | ||||
|     train_batch_size, | ||||
|     num_train_epochs, | ||||
|     warmup_steps, | ||||
|     learning_rate, | ||||
| ) | ||||
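| # Rough sanity check (approximate values, assuming warmup_steps == 0 as configured | ||||
| # above): the schedule starts near learning_rate and decays linearly to zero. | ||||
| # print(linear_decay_lr_schedule_fn(0))                       # ~learning_rate | ||||
| # print(linear_decay_lr_schedule_fn(total_train_steps // 2))  # ~learning_rate / 2 | ||||
| # print(linear_decay_lr_schedule_fn(total_train_steps))       # ~0.0 | ||||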
| 
 | ||||
| # We use Optax's "masking" functionality to not apply weight decay | ||||
| # to bias and LayerNorm scale parameters. decay_mask_fn returns a | ||||
| # mask boolean with the same structure as the parameters. | ||||
| # The mask is True for parameters that should be decayed. | ||||
| def decay_mask_fn(params): | ||||
|     flat_params = traverse_util.flatten_dict(params) | ||||
|     # find out all LayerNorm parameters | ||||
|     layer_norm_candidates = ["final_layer_norm", "layer_norm"] | ||||
|     layer_norm_named_params = { | ||||
|         layer[-2:] | ||||
|         for layer_norm_name in layer_norm_candidates | ||||
|         for layer in flat_params.keys() | ||||
|         if layer_norm_name in "".join(layer).lower() | ||||
|     } | ||||
|     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} | ||||
|     return traverse_util.unflatten_dict(flat_mask) | ||||
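| # Illustration of the mask with hypothetical parameter paths (actual names depend | ||||
| # on the loaded model): | ||||
| #   ('encoder', ..., 'layer_norm', 'weight')  -> False  (no weight decay) | ||||
| #   ('encoder', ..., 'q', 'kernel')           -> True   (decayed) | ||||
| #   any path ending in 'bias'                 -> False  (no weight decay) | ||||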
| 
 | ||||
| # create adam optimizer | ||||
| adamw = optax.adamw( | ||||
|     learning_rate=linear_decay_lr_schedule_fn, | ||||
|     b1=adam_beta1, | ||||
|     b2=adam_beta2, | ||||
|     eps=adam_epsilon, | ||||
|     weight_decay=weight_decay, | ||||
|     mask=decay_mask_fn, | ||||
| ) | ||||
| 
 | ||||
| # %% | ||||
| def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): | ||||
|     """ | ||||
|     The label smoothing implementation is adapted from Flax's official example: | ||||
|     https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 | ||||
|     """ | ||||
|     vocab_size = logits.shape[-1] | ||||
|     confidence = 1.0 - label_smoothing_factor | ||||
|     low_confidence = (1.0 - confidence) / (vocab_size - 1) | ||||
|     normalizing_constant = -( | ||||
|         confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) | ||||
|     ) | ||||
|     soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) | ||||
|     logits = jnp.asarray(logits, dtype=jnp.float32) | ||||
|     soft_labels = soft_labels.astype(jnp.float32) | ||||
|     loss = optax.softmax_cross_entropy(logits, soft_labels) | ||||
|     loss = loss - normalizing_constant | ||||
| 
 | ||||
|     # ignore padded tokens from loss | ||||
|     loss = loss * padding_mask | ||||
|     mean_loss = loss.mean() | ||||
|     # num_labels = padding_mask.mean() | ||||
|     return mean_loss # , num_labels | ||||
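| # Worked example (illustrative numbers): with label_smoothing_factor=0.1 and | ||||
| # vocab_size=32128, confidence = 0.9 and low_confidence = 0.1 / 32127 ≈ 3.1e-6; | ||||
| # normalizing_constant is the entropy of that smoothed target distribution, so a | ||||
| # prediction that matches the soft labels exactly scores roughly zero. | ||||
| # Note that .mean() above averages over all positions: padded positions contribute | ||||
| # zero to the numerator but still count in the denominator. | ||||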
| 
 | ||||
| 
 | ||||
| # %% | ||||
| ################################################################ | ||||
| # old jit in_shardings method | ||||
| 
 | ||||
| # create init_fn to produce sharded state | ||||
| def init_fn(params, model, optimizer): | ||||
|     # do be careful with the model init | ||||
|     # imported models might have complicated init methods | ||||
| 
 | ||||
|     # mask = create_mask_for_layer_norm(params) | ||||
|     # override params with bfloat version | ||||
|     # params= cast_floating_to(params, jnp.bfloat16, mask) | ||||
| 
 | ||||
|     state = train_state.TrainState.create(  # Create a `TrainState`. | ||||
|         apply_fn=model.apply, | ||||
|         params=params, | ||||
|         tx=optimizer) | ||||
|     return state | ||||
| 
 | ||||
| 
 | ||||
| abstract_variables = jax.eval_shape( | ||||
|     functools.partial(init_fn, model=model, optimizer=adamw), params) | ||||
| 
 | ||||
| # jax.sharding: describes how a jax.Array is laid out across devices | ||||
| state_sharding = nn.get_sharding(abstract_variables, mesh) | ||||
| 
 | ||||
| from parallel.partitions import set_partitions | ||||
| # set_partitions freezes the params on return | ||||
| model_part_spec = set_partitions(unfreeze(params)) | ||||
| # p is already a partition spec | ||||
| model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec) | ||||
| 
 | ||||
| # get pspec for opt_state | ||||
| def get_opt_spec(x): | ||||
|     if isinstance(x, dict): | ||||
|         return unfreeze(model_named_sharding) | ||||
|     # return an empty partspec | ||||
|     return mesh_sharding((PartitionSpec())) | ||||
| 
 | ||||
| # replace the placeholder dict leaves in state_sharding with model_named_sharding | ||||
| state_sharding = jax.tree.map( | ||||
|     get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)) | ||||
| ) | ||||
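| # Optional sanity check (a rough sketch): after the tree map, every leaf of | ||||
| # state_sharding should be a NamedSharding, e.g. | ||||
| # print(jax.tree_util.tree_leaves(state_sharding)[:3]) | ||||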
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| jit_init_fn = jax.jit( | ||||
|     init_fn, | ||||
|     static_argnames=('model', 'optimizer'),  # skip model and optimizer | ||||
|     in_shardings=mesh_sharding(PartitionSpec()),  # we don't shard params explicitly | ||||
|     out_shardings=state_sharding  # but returned initialized_state is sharded | ||||
| ) | ||||
| initialized_state = jit_init_fn(params, model, adamw) | ||||
| 
 | ||||
| # %% | ||||
| # train step | ||||
| # loss_fn is already defined above; it is reused by train_step below | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
| # single device code annotated with jax.jit | ||||
| @functools.partial( | ||||
|     jax.jit, | ||||
|     # state is state_sharding initialized from init_fn | ||||
|     # x_sharding is data sharded explicitly later | ||||
|     in_shardings=(state_sharding, x_sharding), | ||||
|     out_shardings=(state_sharding, mesh_sharding(PartitionSpec())), | ||||
|     donate_argnames=('state',), | ||||
| ) | ||||
| def train_step(state, batch): | ||||
| 
 | ||||
|     label_smoothing_factor=0.0 | ||||
|     # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) | ||||
|     # computes loss per shard  | ||||
|     def compute_loss(params, batch): | ||||
|         # check constraints | ||||
|         # frozen dict not allowed as sharding object | ||||
|         params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding)) | ||||
|         batch = jax.lax.with_sharding_constraint(batch, x_sharding) | ||||
|         logits = state.apply_fn( | ||||
|             {'params': params}, | ||||
|             input_ids=batch['input_ids'], | ||||
|             attention_mask=batch['attention_mask'], | ||||
|             decoder_input_ids=batch['decoder_input_ids'], | ||||
|             decoder_attention_mask=batch['decoder_attention_mask'], | ||||
|         )[0]  # index 0: the output is a tuple whose first element is the logits | ||||
| 
 | ||||
|         # logits sharding | ||||
|         # data, None, model | ||||
|         #  | ||||
|         print("logits") | ||||
|         jax.debug.inspect_array_sharding(logits, callback=print) | ||||
|         # use labels here | ||||
|         # loss, num_labels = loss_fn( | ||||
|         loss = loss_fn( | ||||
|             logits, | ||||
|             batch["labels"], | ||||
|             batch["decoder_attention_mask"], | ||||
|             label_smoothing_factor) | ||||
|         # loss sharding | ||||
|         # it gives PartitionSpec(), which implies a reduction already happened | ||||
|         print("loss") | ||||
|         jax.debug.inspect_array_sharding(loss, callback=print) | ||||
| 
 | ||||
|         return loss # , num_labels | ||||
| 
 | ||||
|     # compute gradients through the computational graph; | ||||
|     # value_and_grad returns the loss value together with the gradients | ||||
|     grad_fn = jax.value_and_grad(compute_loss, has_aux=False) | ||||
|     batch = jax.tree.map(lambda x: jax.lax.with_sharding_constraint(x, x_sharding), batch) | ||||
|     (loss), grads = grad_fn(state.params, batch) | ||||
|     # num_labels = jax.lax.psum(num_labels, "batch") | ||||
| 
 | ||||
|     # so far we have been operating within each shard; | ||||
|     # the gradients must be synced across devices, so constrain them to an empty | ||||
|     # PartitionSpec, i.e. fully replicated on every device | ||||
|     # jax.debug.inspect_array_sharding(grads, callback=print) | ||||
|     grads = jax.lax.with_sharding_constraint(grads, mesh_sharding(PartitionSpec())) | ||||
|     # grads = jax.lax.with_sharding_constraint(grads, state_sharding) | ||||
|     # jax.debug.visualize_array_sharding(grad) | ||||
|     # jax.debug.inspect_array_sharding(grad, callback=print) | ||||
|     # check the output grad tree from mean | ||||
|     # print(jax.tree.map(jnp.shape, grad)) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     new_state = state.apply_gradients(grads=grads) | ||||
|     with jax.named_scope("sync_metrics"): | ||||
|         step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} | ||||
| 
 | ||||
|     return new_state, step_metrics | ||||
| 
 | ||||
| # %% | ||||
| # explore data sharding | ||||
| sharded_batch = next(iter(train_loader)) | ||||
| sharded_batch = jax.device_put(sharded_batch, x_sharding) | ||||
| jax.debug.visualize_array_sharding(sharded_batch['input_ids']) | ||||
| # jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding']) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # # prep 1 step | ||||
| print("1 step for jit-ting") | ||||
| with mesh: | ||||
|     state, metrics = train_step(initialized_state, sharded_batch) | ||||
| 
 | ||||
| # %% | ||||
| 
 | ||||
| # %% | ||||
| # training run summary | ||||
| print("***** Running training *****") | ||||
| print(f"  Num examples = {training_size}") | ||||
| print(f"  Num Epochs = {num_epochs}") | ||||
| print(f"  Instantaneous batch size per device = {per_device_train_batch_size}") | ||||
| print(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}") | ||||
| print(f"  Total optimization steps = {total_train_steps}") | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # jax.profiler.start_trace("./traces") | ||||
| 
 | ||||
| 
 | ||||
| print("*" * 20) | ||||
| print("training start") | ||||
| rng, input_rng = jax.random.split(rng) | ||||
| train_time = 0 | ||||
| state = initialized_state | ||||
| epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) | ||||
| for epoch in epochs: | ||||
|     train_start = time.time() | ||||
| 
 | ||||
|     # Create sampling rng | ||||
|     train_metrics = [] | ||||
|     steps_per_epoch = training_size // train_batch_size | ||||
|     train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True) | ||||
|     # Generate an epoch by shuffling sampling indices from the train dataset | ||||
|     for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): | ||||
|         batch = next(train_loader) | ||||
|         batch = jax.device_put(batch, x_sharding) | ||||
|         with mesh: | ||||
|             state, train_metric = train_step(state, batch) | ||||
| 
 | ||||
|         # train_metrics.append(train_metric) | ||||
| 
 | ||||
| 
 | ||||
|     # this is for more accurate time stats, but slows down training | ||||
|     # train_metric['loss'].block_until_ready() | ||||
|     train_time = time.time() - train_start | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     epochs.write( | ||||
|         f"Epoch... ({epoch + 1}/{num_epochs}) | " | ||||
|         f"Loss: {train_metric['loss']}, " | ||||
|         f"Learning Rate: {train_metric['learning_rate']}, " | ||||
|         f"Last train time: {train_time}" | ||||
|     ) | ||||
| # jax.profiler.stop_trace() | ||||
| 
 | ||||
| # %% | ||||
| # try out | ||||
| # gather_state = jax.device_get(state) | ||||
| # gather_batch = jax.device_get(batch) | ||||
| # logits = gather_state.apply_fn( | ||||
| #     {'params': gather_state.params}, | ||||
| #     input_ids=gather_batch['input_ids'], | ||||
| #     attention_mask=gather_batch['attention_mask'], | ||||
| #     decoder_input_ids=gather_batch['decoder_input_ids'], | ||||
| #     decoder_attention_mask=gather_batch['decoder_attention_mask'], | ||||
| # )[0]  # zero because output is some structure, where first is the logit | ||||
| #  | ||||
| # probs = nn.softmax(logits, axis=-1) | ||||
| # predicted = jnp.argmax(probs, axis=-1) | ||||
| # print(predicted[0]) | ||||
| 
 | ||||
| # %% | ||||
| main_model = custom_model.from_pretrained('t5-base') | ||||
| output_dir = save_path | ||||
| 
 | ||||
| # save checkpoint after each epoch and push checkpoint to the hub | ||||
| if jax.process_index() == 0: | ||||
|     params = jax.device_get(state.params) | ||||
|     params = jax.tree.map(lambda x: x.astype(jnp.float32), params) | ||||
|     main_model.save_pretrained(output_dir, params=params) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
|  | @ -27,7 +27,7 @@ os.environ.update({ | |||
|     "NCCL_LL128_BUFFSIZE": "-2", | ||||
|     "NCCL_LL_BUFFSIZE": "-2", | ||||
|     "NCCL_PROTO": "SIMPLE,LL,LL128", | ||||
|     "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90", | ||||
|     "XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.80", | ||||
|     # "XLA_PYTHON_CLIENT_PREALLOCATE" : "false" | ||||
|  }) | ||||
| 
 | ||||
|  | @ -89,35 +89,50 @@ jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) | |||
| jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| ## get platform type | ||||
| from jax.extend.backend import get_backend | ||||
| print(get_backend().platform) | ||||
| print(jax.devices()) | ||||
| 
 | ||||
| # %% | ||||
| # create mesh | ||||
| print("creating mesh") | ||||
| device_mesh = mesh_utils.create_device_mesh((4,1)) | ||||
| print(device_mesh) | ||||
| 
 | ||||
| mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) | ||||
| print(mesh) | ||||
| 
 | ||||
| def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: | ||||
|     return NamedSharding(mesh, pspec, memory_kind="device") | ||||
| 
 | ||||
| x_sharding = mesh_sharding(PartitionSpec('data', None))  # replicate across data axis | ||||
| model_sharding=mesh_sharding(PartitionSpec(None, 'model')) | ||||
| 
 | ||||
| 
 | ||||
|   | ||||
| # %% | ||||
| # config options | ||||
| file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/' | ||||
| save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple_test/' | ||||
| save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple/' | ||||
| # file_path = 'combined_data' | ||||
| split_datasets = load_from_disk(file_path) | ||||
| training_size = len(split_datasets['train']) | ||||
| # Store some constant | ||||
| seed = 117 | ||||
| num_epochs = 5 | ||||
| batch_size = 64 | ||||
| num_epochs = 40 | ||||
| batch_size = 128 | ||||
| num_train_epochs = num_epochs | ||||
| per_device_train_batch_size = batch_size | ||||
| train_batch_size = per_device_train_batch_size * jax.device_count() | ||||
| train_batch_size = per_device_train_batch_size * 2  | ||||
| per_device_eval_batch_size = batch_size | ||||
| eval_batch_size = per_device_eval_batch_size * jax.device_count() | ||||
| eval_batch_size = per_device_eval_batch_size * 2 | ||||
| steps_per_epoch = training_size // train_batch_size | ||||
| total_train_steps = steps_per_epoch * num_epochs | ||||
| 
 | ||||
| warmup_steps = 0 | ||||
| learning_rate = 5e-5 | ||||
| learning_rate = 2e-3 | ||||
| 
 | ||||
| weight_decay = 0.01 | ||||
| adam_beta1 = 0.9 | ||||
|  | @ -197,21 +212,6 @@ model = main_model.module | |||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # create mesh | ||||
| print("creating mesh") | ||||
| device_mesh = mesh_utils.create_device_mesh((2,2)) | ||||
| print(device_mesh) | ||||
| 
 | ||||
| mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) | ||||
| print(mesh) | ||||
| 
 | ||||
| def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: | ||||
|     return NamedSharding(mesh, pspec, memory_kind="device") | ||||
| 
 | ||||
| x_sharding = mesh_sharding(PartitionSpec('data', None))  # replicate across data axis | ||||
| model_sharding=mesh_sharding(PartitionSpec(None, 'model')) | ||||
| 
 | ||||
| 
 | ||||
| # %% | ||||
| # optimizers | ||||
|  | @ -246,7 +246,7 @@ linear_decay_lr_schedule_fn = create_learning_rate_fn( | |||
| def decay_mask_fn(params): | ||||
|     flat_params = traverse_util.flatten_dict(params) | ||||
|     # find out all LayerNorm parameters | ||||
|     layer_norm_candidates = ["layernorm", "layer_norm", "ln"] | ||||
|     layer_norm_candidates = ["final_layer_norm", "layer_norm"] | ||||
|     layer_norm_named_params = { | ||||
|         layer[-2:] | ||||
|         for layer_norm_name in layer_norm_candidates | ||||
|  | @ -322,9 +322,9 @@ def init_fn(params, model, optimizer): | |||
|     # do be careful with the model init | ||||
|     # imported models might have complicated init methods | ||||
| 
 | ||||
|     # mask = create_mask_for_layer_norm(params) | ||||
|     mask = create_mask_for_layer_norm(params) | ||||
|     # override params with bfloat version | ||||
|     # params= cast_floating_to(params, jnp.bfloat16, mask) | ||||
|     params= cast_floating_to(params, jnp.bfloat16, mask) | ||||
| 
 | ||||
|     state = train_state.TrainState.create(  # Create a `TrainState`. | ||||
|         apply_fn=model.apply, | ||||
|  | @ -449,8 +449,8 @@ def train_step(state, batch): | |||
| 
 | ||||
| # %% | ||||
| # explore data sharding | ||||
| sharded_batch = next(iter(train_loader)) | ||||
| sharded_batch = jax.device_put(sharded_batch, x_sharding) | ||||
| # sharded_batch = next(iter(train_loader)) | ||||
| # sharded_batch = jax.device_put(sharded_batch, x_sharding) | ||||
| # jax.debug.visualize_array_sharding(sharded_batch['input_ids']) | ||||
| # jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding']) | ||||
| 
 | ||||
|  | @ -477,7 +477,7 @@ print(f"  Total optimization steps = {total_train_steps}") | |||
| # jax.profiler.start_trace("./traces") | ||||
| 
 | ||||
| 
 | ||||
| print("*" * 10) | ||||
| print("*" * 50) | ||||
| print("training start") | ||||
| rng, input_rng = jax.random.split(rng) | ||||
| train_time = 0 | ||||
|  | @ -489,7 +489,7 @@ for epoch in epochs: | |||
|     # Create sampling rng | ||||
|     train_metrics = [] | ||||
|     steps_per_epoch = training_size // train_batch_size | ||||
|     train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True) | ||||
|     train_loader = dataprep.data_loader(rng, batch_size=train_batch_size, shuffle=True, drop_last=True) | ||||
|     # Generate an epoch by shuffling sampling indices from the train dataset | ||||
|     for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): | ||||
|         batch = next(train_loader) | ||||
|  | @ -514,21 +514,21 @@ for epoch in epochs: | |||
|     ) | ||||
| # jax.profiler.stop_trace() | ||||
| 
 | ||||
| # %% | ||||
| # try out | ||||
| gather_state = jax.device_get(state) | ||||
| gather_batch = jax.device_get(batch) | ||||
| logits = gather_state.apply_fn( | ||||
|     {'params': gather_state.params}, | ||||
|     input_ids=gather_batch['input_ids'], | ||||
|     attention_mask=gather_batch['attention_mask'], | ||||
|     decoder_input_ids=gather_batch['decoder_input_ids'], | ||||
|     decoder_attention_mask=gather_batch['decoder_attention_mask'], | ||||
| )[0]  # zero because output is some structure, where first is the logit | ||||
| 
 | ||||
| probs = nn.softmax(logits, axis=-1) | ||||
| predicted = jnp.argmax(probs, axis=-1) | ||||
| predicted[1] | ||||
| # # %% | ||||
| # # try out | ||||
| # gather_state = jax.device_get(state) | ||||
| # gather_batch = jax.device_get(batch) | ||||
| # logits = gather_state.apply_fn( | ||||
| #     {'params': gather_state.params}, | ||||
| #     input_ids=gather_batch['input_ids'], | ||||
| #     attention_mask=gather_batch['attention_mask'], | ||||
| #     decoder_input_ids=gather_batch['decoder_input_ids'], | ||||
| #     decoder_attention_mask=gather_batch['decoder_attention_mask'], | ||||
| # )[0]  # zero because output is some structure, where first is the logit | ||||
| #  | ||||
| # probs = nn.softmax(logits, axis=-1) | ||||
| # predicted = jnp.argmax(probs, axis=-1) | ||||
| # print(predicted[0]) | ||||
| 
 | ||||
| # %% | ||||
| main_model = custom_model.from_pretrained('t5-base') | ||||
|  | @ -537,6 +537,7 @@ output_dir = save_path | |||
| # save checkpoint after each epoch and push checkpoint to the hub | ||||
| if jax.process_index() == 0: | ||||
|     params = jax.device_get(state.params) | ||||
|     params = jax.tree.map(lambda x: x.astype(jnp.float32), params) | ||||
|     main_model.save_pretrained(output_dir, params=params) | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -0,0 +1 @@ | |||
| __pycache__ | ||||
|  | @ -0,0 +1,118 @@ | |||
| # coding=utf-8 | ||||
| # Copyright 2020, The T5 Authors and HuggingFace Inc. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| """T5 model configuration""" | ||||
| 
 | ||||
| from typing import Mapping | ||||
| 
 | ||||
| from transformers import PretrainedConfig | ||||
| from transformers import logging | ||||
| 
 | ||||
| from etils import edc | ||||
| 
 | ||||
| logger = logging.get_logger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class T5Config(PretrainedConfig): | ||||
|     model_type = "t5" | ||||
|     keys_to_ignore_at_inference = ["past_key_values"] | ||||
|     attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         vocab_size=32128, # vocab size here | ||||
|         d_model=512, | ||||
|         d_kv=64, | ||||
|         d_ff=2048, | ||||
|         num_layers=6, | ||||
|         num_decoder_layers=None, | ||||
|         num_heads=8, | ||||
|         relative_attention_num_buckets=32, | ||||
|         relative_attention_max_distance=128, | ||||
|         dropout_rate=0.1, | ||||
|         layer_norm_epsilon=1e-6, | ||||
|         initializer_factor=1.0, | ||||
|         feed_forward_proj="relu", | ||||
|         is_encoder_decoder=True, | ||||
|         use_cache=True, | ||||
|         pad_token_id=0, | ||||
|         eos_token_id=1, | ||||
|         classifier_dropout=0.0, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         self.vocab_size = vocab_size | ||||
|         self.d_model = d_model | ||||
|         self.d_kv = d_kv | ||||
|         self.d_ff = d_ff | ||||
|         self.num_layers = num_layers | ||||
|         self.num_decoder_layers = ( | ||||
|             num_decoder_layers if num_decoder_layers is not None else self.num_layers | ||||
|         )  # default = symmetry | ||||
|         self.num_heads = num_heads | ||||
|         self.relative_attention_num_buckets = relative_attention_num_buckets | ||||
|         self.relative_attention_max_distance = relative_attention_max_distance | ||||
|         self.dropout_rate = dropout_rate | ||||
|         self.classifier_dropout = classifier_dropout | ||||
|         self.layer_norm_epsilon = layer_norm_epsilon | ||||
|         self.initializer_factor = initializer_factor | ||||
|         self.feed_forward_proj = feed_forward_proj | ||||
|         self.use_cache = use_cache | ||||
|         self.use_bfloat16 = True | ||||
| 
 | ||||
|         act_info = self.feed_forward_proj.split("-") | ||||
|         self.dense_act_fn = act_info[-1] | ||||
|         self.is_gated_act = act_info[0] == "gated" | ||||
| 
 | ||||
|         if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: | ||||
|             raise ValueError( | ||||
|                 f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. " | ||||
|                 "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. " | ||||
|                 "'gated-gelu' or 'relu'" | ||||
|             ) | ||||
| 
 | ||||
|         # for backwards compatibility | ||||
|         if feed_forward_proj == "gated-gelu": | ||||
|             self.dense_act_fn = "gelu_new" | ||||
| 
 | ||||
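|         # Illustration (assumed values): feed_forward_proj="relu" gives | ||||
|         # dense_act_fn="relu" and is_gated_act=False; "gated-gelu" gives | ||||
|         # is_gated_act=True and, via the branch above, dense_act_fn="gelu_new". | ||||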
|         super().__init__( | ||||
|             pad_token_id=pad_token_id, | ||||
|             eos_token_id=eos_token_id, | ||||
|             is_encoder_decoder=is_encoder_decoder, | ||||
|             **kwargs, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| # class T5OnnxConfig(OnnxSeq2SeqConfigWithPast): | ||||
| #     @property | ||||
| #     def inputs(self) -> Mapping[str, Mapping[int, str]]: | ||||
| #         common_inputs = { | ||||
| #             "input_ids": {0: "batch", 1: "encoder_sequence"}, | ||||
| #             "attention_mask": {0: "batch", 1: "encoder_sequence"}, | ||||
| #         } | ||||
| #         if self.use_past: | ||||
| #             common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" | ||||
| #             common_inputs["decoder_input_ids"] = {0: "batch"} | ||||
| #             common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} | ||||
| #         else: | ||||
| #             common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} | ||||
| #             common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} | ||||
| #  | ||||
| #         if self.use_past: | ||||
| #             self.fill_with_past_key_values_(common_inputs, direction="inputs") | ||||
| #  | ||||
| #         return common_inputs | ||||
| #  | ||||
| #     @property | ||||
| #     def default_onnx_opset(self) -> int: | ||||
| #         return 13 | ||||
										
											
File diff suppressed because it is too large