learn_jax/make_context_data.py
# %%
import os
# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False
flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=4"  # simulate 4 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    flags = (
        '--xla_gpu_enable_triton_softmax_fusion=true '
        '--xla_gpu_triton_gemm_any=True '
        # '--xla_gpu_enable_async_collectives=true '
        '--xla_gpu_enable_latency_hiding_scheduler=true '
        '--xla_gpu_enable_highest_priority_async_stream=true '
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
os.environ["XLA_FLAGS"] = flags
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.80",
    # "XLA_PYTHON_CLIENT_PREALLOCATE": "false"
})
import pandas as pd
import matplotlib.pyplot as plt
from datasets import Dataset
import jax
import jax.numpy as jnp
import optax
import numpy as np
import functools
from typing import Callable, Optional
import math
from jax.sharding import Mesh, NamedSharding
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
# set cache
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "bfloat16")
jax.config.update("jax_enable_x64", False)
# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
import datasets
import evaluate
from tqdm import tqdm
import nltk # Here to have a nice missing dependency error message early on
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from ml_collections import ConfigDict
import time
from parallel.dataload import DataPrepare
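# %%
# sanity check (added sketch): confirm which devices JAX actually sees before
# building the (4, 1) mesh further down; the expectation of 4 devices is an
# assumption about this machine, adjust it to your hardware
print("JAX devices:", jax.devices())
print("device count:", jax.device_count())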
# %%
# import data
# load training
data_path = "/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv"
# Ensure 'ships_idx' is included in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# Load the dataset
train_df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)
# # load valid
# data_path = "/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/valid.csv"
# # Ensure to include 'ships_idx' in the fields list
# fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# # Load the dataset
# validation_df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)
def process_df(df):
    output_list = [{
        'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>",
        'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
    } for _, row in df.iterrows()]
    return output_list
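# illustrative output of process_df for one row (made-up values, not taken from
# the real dataset):
# {'input': '<NAME>GEN1_RPM<NAME><DESC>no.1 generator speed<DESC>',
#  'output': '<THING_START>generator<THING_END><PROPERTY_START>speed<PROPERTY_END>'}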
# takes 1 minute to run without batching
train_dataset = Dataset.from_list(process_df(train_df))
print("preparing data")
data_config = ConfigDict(
    dict(
        max_length=128,
        pad_token_id=0,
        decoder_start_token_id=0,
    )
)
dataprep = DataPrepare(train_dataset, data_config)
# %%
# load model
model_name_or_path = "./model_checkpoints/simple" # Replace with your specific model name
from transformers import FlaxT5ForConditionalGeneration
model = FlaxT5ForConditionalGeneration.from_pretrained(
    model_name_or_path
)
params = model.params
# %%
# load data
SEED = 117
batch_size = 256  # per-device batch size
# train_batch_size scales with the number of devices because each batch is
# sharded across the 'data' mesh axis below
train_batch_size = batch_size * jax.device_count()
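# e.g. with 4 visible devices: 256 * 4 = 1024 examples per step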
rng = jax.random.PRNGKey(SEED)
# %%
# setup sharding
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((4,1))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    return NamedSharding(mesh, pspec, memory_kind="device")

data_sharding = mesh_sharding(PartitionSpec('data'))  # shard batches across the 'data' axis
# model_sharding = mesh_sharding(PartitionSpec('model'))
replicate_sharding = mesh_sharding(PartitionSpec())  # fully replicated
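# %%
# visual check of the data layout (added sketch): place a dummy (batch, seq_len)
# int array with the data sharding and print how it is split across the mesh.
# jax.debug.visualize_array_sharding only handles 1D/2D arrays (hence the 2D
# dummy) and needs the optional `rich` package installed.
dummy = jax.device_put(
    jnp.zeros((train_batch_size, data_config.max_length), dtype=jnp.int32),
    data_sharding,
)
jax.debug.visualize_array_sharding(dummy)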
# %%
# define function to get encodings
def get_encodings(batch, params):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    input_ids = jnp.reshape(input_ids, (input_ids.shape[-2], input_ids.shape[-1]))
    attention_mask = jnp.reshape(attention_mask, (input_ids.shape[-2], input_ids.shape[-1]))
    encoder_outputs = model.encode(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=False,
        output_hidden_states=False,
        train=False,
        params=params,
        dropout_rng=None,
    )
    # encoder_outputs['last_hidden_state'] has shape (batch, seq_len, embed),
    # where embed is the encoder hidden size, not the vocabulary embedding size
    # expand the mask from (batch, seq_len) to (batch, seq_len, 1) so it
    # broadcasts against the hidden states
    expanded_attention_mask = jnp.expand_dims(attention_mask, 2)  # (batch, 128, 1)
    # element-wise multiply zeroes out embeddings at padded positions
    embeddings = encoder_outputs['last_hidden_state'] * expanded_attention_mask  # (batch, 128, 768)
    # mean-pool over the sequence axis: sum the masked embeddings and divide by
    # the number of unmasked tokens in each entry
    mean_embeddings = embeddings.sum(axis=1) / expanded_attention_mask.sum(axis=1)
    # mean_embeddings has shape (batch, embed)
    return mean_embeddings
get_encodings_jit = jax.jit(
    functools.partial(get_encodings, params=params),
    # the single positional argument is the batch
    in_shardings=data_sharding,
    out_shardings=replicate_sharding,
)
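# %%
# tiny worked example of the masked mean pooling inside get_encodings (added
# sketch with toy numbers instead of real encoder outputs): two tokens are real,
# one is padding, so the padded position must not pull the mean down
_toy_hidden = jnp.array([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])  # (1, 3, 2)
_toy_mask = jnp.array([[1, 1, 0]])                               # (1, 3)
_toy_masked = _toy_hidden * jnp.expand_dims(_toy_mask, 2)
print(_toy_masked.sum(axis=1) / jnp.expand_dims(_toy_mask, 2).sum(axis=1))  # [[2. 3.]]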
# # %%
# # test the get_encodings function
#
# rng, input_rng = jax.random.split(rng)
# train_loader = dataprep.data_loader(input_rng, batch_size=train_batch_size, shuffle=False, drop_last=False)
# batch = next(train_loader)
# encodings = get_encodings(batch, params)
# # function works!
# %%
# perform actual prediction
encoding_list = []
# note: train_batch_size is batch_size * jax.device_count() (4 devices here)
pred_steps = math.ceil(len(train_dataset) / train_batch_size)
print("***** Running prediction *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num steps = {pred_steps}")
print(f" Instantaneous batch size per device = {batch_size}")
print(f"  Total batch size (w. parallel & distributed) = {train_batch_size}")
rng, input_rng = jax.random.split(rng)
train_loader = dataprep.data_loader(input_rng, batch_size=train_batch_size, shuffle=False, drop_last=False)
for _ in tqdm(range(pred_steps), desc="Predicting..."):
    batch = next(train_loader)
    batch = jax.device_put(batch, data_sharding)
    encodings = get_encodings_jit(batch)
    encoding_list.extend(jax.device_get(encodings))
# %%
encoding_list = jnp.vstack(encoding_list)
# slice back to the original dataset length to drop the padding rows added by the loader
encoding_list = encoding_list[:len(train_dataset)]
print(encoding_list.shape)
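# %%
# persist the pooled encodings so the downstream retrieval step does not need to
# rerun the encoder (added sketch; the output path is an assumption, point it at
# wherever the next step expects the file)
np.save("train_context_encodings.npy", np.asarray(encoding_list))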
# %%
# getting top-k
def top_k_cosine_similarity(M, a, k):
    """
    Find the top-k rows in matrix M that are most cosine-similar to vector a.

    Args:
        M (jnp.ndarray): Matrix of shape (n, d), where each row is a d-dimensional vector.
        a (jnp.ndarray): Array of shape (d,), the vector to compare to each row of M.
        k (int): Number of top cosine similarities to retrieve.

    Returns:
        values (jnp.ndarray): Top-k cosine similarity values.
        indices (jnp.ndarray): Indices of the top-k most similar rows in M.
    """
    # Step 1: Normalize both M and a
    M_norm = M / jnp.linalg.norm(M, axis=1, keepdims=True)  # Shape: (n, d)
    a_norm = a / jnp.linalg.norm(a)  # Shape: (d,)
    # Step 2: Compute cosine similarity via dot product
    cosine_similarities = jnp.dot(M_norm, a_norm)  # Shape: (n,)
    # Step 3: Get the top-k values and their indices using jax.lax.top_k
    values, indices = jax.lax.top_k(cosine_similarities, k)
    return values, indices
# Example usage:
M = jax.random.normal(jax.random.PRNGKey(0), (100, 128)) # Random matrix with 100 rows of 128 dimensions
a = jax.random.normal(jax.random.PRNGKey(1), (128,)) # Random query vector of 128 dimensions
# Find top 5 most similar rows
top_k_values, top_k_indices = top_k_cosine_similarity(M, a, k=5)
print("Top-k cosine similarity values:", top_k_values)
print("Indices of top-k similar rows:", top_k_indices)
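# %%
# same lookup against the actual pooled train encodings (added sketch): the first
# row is used as the query, so its own index should come back with similarity
# ~1.0; in practice the query would be the encoding of a new, unseen entry
values, indices = top_k_cosine_similarity(encoding_list, encoding_list[0], k=5)
print("nearest train rows:", indices)
print("cosine similarities:", values)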
# %%