# %%
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=4"  # Simulate 4 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    flags = (
        '--xla_gpu_enable_triton_softmax_fusion=true '
        '--xla_gpu_triton_gemm_any=True '
        # '--xla_gpu_enable_async_collectives=true '
        '--xla_gpu_enable_latency_hiding_scheduler=true '
        '--xla_gpu_enable_highest_priority_async_stream=true '
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

os.environ["XLA_FLAGS"] = flags
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.80",
    # "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
})

import pandas as pd
import matplotlib.pyplot as plt
from datasets import Dataset
import jax
import jax.numpy as jnp
import optax
import numpy as np
import functools
from typing import Callable, Optional
import math
from jax.sharding import Mesh, NamedSharding
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec

# set up the persistent compilation cache
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_enable_x64", False)

# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
from datasets import Dataset
import evaluate
from tqdm import tqdm

import nltk  # here to raise a nice missing-dependency error early on

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from ml_collections import ConfigDict
import time

from parallel.dataload import DataPrepare

# %%
# import data
# load training data
data_path = "/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv"
# Ensure 'ships_idx' is included in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# Load the dataset
train_df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)

# # load valid
# data_path = "/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/valid.csv"
# # Ensure 'ships_idx' is included in the fields list
# fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# # Load the dataset
# validation_df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)


def process_df(df):
    # concatenate the raw tag fields into one input string and the target
    # fields into one output string
    output_list = [{
        'input': f"{row['tag_name']}{row['tag_description']}",
        'output': f"{row['thing']}{row['property']}",
    } for _, row in df.iterrows()]
    return output_list


# takes about 1 minute to run without batching
train_dataset = Dataset.from_list(process_df(train_df))

print("preparing data")
data_config = ConfigDict(
    dict(
        max_length=128,
        pad_token_id=0,
        decoder_start_token_id=0,
    )
)

dataprep = DataPrepare(train_dataset, data_config)
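# %%
# quick sanity-check sketch: peek at one prepared batch before loading the
# model. This assumes DataPrepare.data_loader behaves as in the prediction
# loop below, i.e. it is an iterator over dicts that contain at least
# 'input_ids' and 'attention_mask' arrays of shape (batch, max_length).
peek_rng = jax.random.PRNGKey(0)
peek_loader = dataprep.data_loader(peek_rng, batch_size=2, shuffle=False, drop_last=False)
peek_batch = next(peek_loader)
for name, array in peek_batch.items():
    print(name, array.shape)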
# %%
# load model
model_name_or_path = "./model_checkpoints/simple"  # Replace with your specific model name
from transformers import FlaxT5ForConditionalGeneration

model = FlaxT5ForConditionalGeneration.from_pretrained(
    model_name_or_path
)
params = model.params

# %%
# load data
SEED = 117
batch_size = 256  # per-device batch size
# train_batch_size multiplies by the device count because we shard across 4 devices later
train_batch_size = batch_size * jax.device_count()
rng = jax.random.PRNGKey(SEED)

# %%
# set up sharding
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((4, 1))
print(device_mesh)

mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)


def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    return NamedSharding(mesh, pspec, memory_kind="device")


data_sharding = mesh_sharding(PartitionSpec('data'))  # shard the batch dimension across the 'data' axis
# model_sharding = mesh_sharding(PartitionSpec('model'))
replicate_sharding = mesh_sharding(PartitionSpec())  # fully replicated across all devices

# %%
# define function to get encodings
def get_encodings(batch, params):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    input_ids = jnp.reshape(input_ids, (input_ids.shape[-2], input_ids.shape[-1]))
    attention_mask = jnp.reshape(attention_mask, (input_ids.shape[-2], input_ids.shape[-1]))
    encoder_outputs = model.encode(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=False,
        output_hidden_states=False,
        train=False,
        params=params,
        dropout_rng=None
    )
    # encoder_outputs gives 'last_hidden_state' of shape (batch, seq_len, embed);
    # note that embed is the encoder's self-attention hidden size, not the full
    # vocabulary embedding size.
    # Mask out padding positions by multiplying the encoder outputs with the
    # attention mask: reshape (batch, seq_len) -> (batch, seq_len, 1) so it
    # broadcasts against encoder_outputs.
    expanded_attention_mask = jnp.expand_dims(attention_mask, 2)  # (batch, 128, 1)
    # element-wise multiply
    embeddings = encoder_outputs['last_hidden_state'] * expanded_attention_mask  # (batch, 128, 768)
    # summing embeddings over axis 1 gives a (batch, 768) sum of unmasked token
    # embeddings; summing the expanded attention mask over axis 1 gives the
    # unmasked token count per entry
    mean_embeddings = embeddings.sum(axis=1) / expanded_attention_mask.sum(axis=1)
    # mean_embeddings has shape (batch, embed), we are ready to return
    return mean_embeddings


get_encodings_jit = jax.jit(
    functools.partial(get_encodings, params=params),
    # only the batch is passed at call time; params are bound via partial
    in_shardings=data_sharding,
    out_shardings=replicate_sharding,
)

# # %%
# # test the get_encodings function
# # rng, input_rng = jax.random.split(rng)
# train_loader = dataprep.data_loader(input_rng, batch_size=train_batch_size, shuffle=False, drop_last=False)
# batch = next(train_loader)
# encodings = get_encodings(batch, params)
# # function works!
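# %%
# sanity-check sketch: confirm how a batch-shaped array lands on the
# ('data', 'model') mesh defined above. With data_sharding the leading (batch)
# dimension is split across the 4 devices on the 'data' axis, so each device
# should hold a (train_batch_size // 4, 128) shard, while replicate_sharding
# keeps a full copy everywhere.
example = jax.device_put(jnp.zeros((train_batch_size, 128)), data_sharding)
print(example.sharding)
for shard_piece in example.addressable_shards:
    print(shard_piece.device, shard_piece.data.shape)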
# %%
# perform actual prediction
encoding_list = []
# note: train_batch_size is batch_size * 4 since we have 4 devices
pred_steps = math.ceil(len(train_dataset) / train_batch_size)

print("***** Running prediction *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num steps = {pred_steps}")
print(f" Instantaneous batch size per device = {batch_size}")
print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")

rng, input_rng = jax.random.split(rng)
train_loader = dataprep.data_loader(input_rng, batch_size=train_batch_size, shuffle=False, drop_last=False)

for _ in tqdm(range(pred_steps), desc="Predicting..."):
    batch = next(train_loader)
    batch = jax.device_put(batch, data_sharding)
    encodings = get_encodings_jit(batch)
    encoding_list.extend(jax.device_get(encodings))

# %%
encoding_list = jnp.vstack(encoding_list)
# slice back to the original dataset length to drop the padding rows added by
# the final (partial) batch
encoding_list = encoding_list[:len(train_dataset)]
print(encoding_list.shape)

# %%
# getting top-k
def top_k_cosine_similarity(M, a, k):
    """
    Find the top-k rows in matrix M that are most cosine-similar to array a.

    Args:
        M (jnp.ndarray): Matrix of shape (n, d), where each row is a d-dimensional vector.
        a (jnp.ndarray): Array of shape (d,), the vector to compare to each row of M.
        k (int): Number of top cosine similarities to retrieve.

    Returns:
        values (jnp.ndarray): Top-k cosine similarity values.
        indices (jnp.ndarray): Indices of the top-k most similar rows in M.
    """
    # Step 1: Normalize both M and a
    M_norm = M / jnp.linalg.norm(M, axis=1, keepdims=True)  # Shape: (n, d)
    a_norm = a / jnp.linalg.norm(a)  # Shape: (d,)

    # Step 2: Compute cosine similarity via dot product
    cosine_similarities = jnp.dot(M_norm, a_norm)  # Shape: (n,)

    # Step 3: Get the top-k values and their indices using jax.lax.top_k
    values, indices = jax.lax.top_k(cosine_similarities, k)

    return values, indices


# Example usage:
M = jax.random.normal(jax.random.PRNGKey(0), (100, 128))  # Random matrix with 100 rows of 128 dimensions
a = jax.random.normal(jax.random.PRNGKey(1), (128,))  # Random query vector of 128 dimensions

# Find the top 5 most similar rows
top_k_values, top_k_indices = top_k_cosine_similarity(M, a, k=5)

print("Top-k cosine similarity values:", top_k_values)
print("Indices of top-k similar rows:", top_k_indices)

# %%
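# sketch: query the computed train-set embeddings with one of their own rows.
# This assumes the data_loader above (shuffle=False) preserved the original
# train_df row order, so row i of encoding_list corresponds to row i of
# train_df; the query row should then come back as its own top match with
# similarity ~1.0.
query = encoding_list[0]
values, indices = top_k_cosine_similarity(encoding_list, query, k=5)
print("Top-k cosine similarity values:", values)
print("Indices of top-k similar rows:", indices)
# map the indices back to the source dataframe for inspection
print(train_df.iloc[np.asarray(indices)][['tag_name', 'tag_description', 'thing', 'property']])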