# %%
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = False

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=4"  # Simulate 4 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    flags = (
        '--xla_gpu_enable_triton_softmax_fusion=true '
        '--xla_gpu_triton_gemm_any=True '
        # '--xla_gpu_enable_async_collectives=true '
        '--xla_gpu_enable_latency_hiding_scheduler=true '
        '--xla_gpu_enable_highest_priority_async_stream=true '
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

os.environ["XLA_FLAGS"] = flags
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.80",
    # "XLA_PYTHON_CLIENT_PREALLOCATE": "false"
})
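# Note: XLA_FLAGS, JAX_PLATFORMS and CUDA_VISIBLE_DEVICES are picked up when JAX
# initializes its backends, which is why this cell runs before `import jax` below.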

import pandas as pd
import matplotlib.pyplot as plt

from datasets import Dataset

import jax
import jax.numpy as jnp
import optax
import numpy as np
import functools
from typing import Callable, Optional
import math
from jax.sharding import Mesh, NamedSharding
from jax.experimental import mesh_utils

from jax.sharding import PartitionSpec

# set cache
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "bfloat16")

jax.config.update("jax_enable_x64", False)


# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig


import datasets
from datasets import Dataset
import evaluate
from tqdm import tqdm

import nltk  # Here to have a nice missing dependency error message early on

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key

from ml_collections import ConfigDict

import time

from parallel.dataload import DataPrepare
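# Optional sanity check (a sketch, not in the original script): the (4, 1) device
# mesh created further below expects 4 visible devices.
# print(jax.devices())
# assert jax.device_count() == 4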

# %%
# import data

# load training
data_path = "/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv"
# Ensure to include 'ships_idx' in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# Load the dataset
train_df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)

# # load valid
# data_path = "/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/valid.csv"
# # Ensure to include 'ships_idx' in the fields list
# fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']
# # Load the dataset
# validation_df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)


def process_df(df):
    output_list = [{
        'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>",
        'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
    } for _, row in df.iterrows()]
    return output_list
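# Illustration (hypothetical row, not taken from the dataset): a record such as
#   {'tag_name': 'GE1_RPM', 'tag_description': 'GENERATOR ENGINE RPM',
#    'thing': 'GE1', 'property': 'RPM'}
# is serialized by process_df into
#   input:  "<NAME>GE1_RPM<NAME><DESC>GENERATOR ENGINE RPM<DESC>"
#   output: "<THING_START>GE1<THING_END><PROPERTY_START>RPM<PROPERTY_END>"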

# takes 1 minute to run without batching
train_dataset = Dataset.from_list(process_df(train_df))

print("preparing data")
data_config = ConfigDict(
    dict(
        max_length=128,
        pad_token_id=0,
        decoder_start_token_id=0
    )
)

dataprep = DataPrepare(train_dataset, data_config)
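# Note (assumption; DataPrepare comes from this project's parallel.dataload
# module): it is expected to tokenize and pad each example to max_length=128 and
# to expose the data_loader(...) iterator used in the prediction loop below.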

# %%
# load model
model_name_or_path = "./model_checkpoints/simple"  # Replace with your specific model name
from transformers import FlaxT5ForConditionalGeneration

model = FlaxT5ForConditionalGeneration.from_pretrained(
    model_name_or_path
)
params = model.params

# %%
# load data
SEED = 117
batch_size = 256  # per-device batch size
# train_batch_size is scaled by the device count because batches are sharded
# across the 4 devices below
train_batch_size = batch_size * jax.device_count()
rng = jax.random.PRNGKey(SEED)

# %%
# setup sharding
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((4, 1))
print(device_mesh)

mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)


def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    return NamedSharding(mesh, pspec, memory_kind="device")


data_sharding = mesh_sharding(PartitionSpec('data'))  # shard batches along the 'data' axis
# model_sharding = mesh_sharding(PartitionSpec('model'))
replicate_sharding = mesh_sharding(PartitionSpec())
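# Optional visualization (a sketch, not in the original script): put a dummy
# batch-shaped array on the mesh with data_sharding and show how it is split
# across the 'data' axis.
# dummy = jnp.zeros((train_batch_size, 128), dtype=jnp.int32)
# dummy = jax.device_put(dummy, data_sharding)
# jax.debug.visualize_array_sharding(dummy)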


# %%
# define function to get encodings

def get_encodings(batch, params):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    # reshape to 2D (batch, seq_len) in case a leading singleton dimension is present
    input_ids = jnp.reshape(input_ids, (input_ids.shape[-2], input_ids.shape[-1]))
    attention_mask = jnp.reshape(attention_mask, (input_ids.shape[-2], input_ids.shape[-1]))
    encoder_outputs = model.encode(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=False,
        output_hidden_states=False,
        train=False,
        params=params,
        dropout_rng=None
    )
    # encoder_outputs gives 'last_hidden_state' of shape (batch, seq_len, embed),
    # where embed is the encoder's hidden size, not the vocabulary embedding size

    # mask out padded positions by multiplying the encoder outputs with the
    # attention mask; expanding (batch, seq_len) -> (batch, seq_len, 1) lets the
    # mask broadcast against encoder_outputs
    expanded_attention_mask = jnp.expand_dims(attention_mask, 2)  # (batch, 128, 1)
    # element-wise multiply
    embeddings = encoder_outputs['last_hidden_state'] * expanded_attention_mask  # (batch, 128, 768)
    # summing embeddings over axis 1 collapses the sequence dimension into (batch, 768);
    # summing the attention mask over axis 1 gives the number of unmasked tokens
    # per entry, so the result is a masked mean over tokens
    mean_embeddings = embeddings.sum(axis=1) / expanded_attention_mask.sum(axis=1)
    # mean_embeddings has shape (batch, embed)
    return mean_embeddings


get_encodings_jit = jax.jit(
    functools.partial(get_encodings, params=params),
    # positional argument: batch
    in_shardings=(data_sharding),
    out_shardings=replicate_sharding,
)
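# Note: with in_shardings=data_sharding the batch leaves are split along the
# 'data' mesh axis, while out_shardings=replicate_sharding returns the pooled
# embeddings fully replicated, so jax.device_get below yields one complete array.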
# # %%
# # test the get_encodings function
#
# rng, input_rng = jax.random.split(rng)
# train_loader = dataprep.data_loader(input_rng, batch_size=train_batch_size, shuffle=False, drop_last=False)
# batch = next(train_loader)
# encodings = get_encodings(batch, params)
# # function works!


# %%
# perform actual prediction
encoding_list = []
# note: train_batch_size is batch_size * jax.device_count() (4 devices here)
pred_steps = math.ceil(len(train_dataset) / train_batch_size)
print("***** Running prediction *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num steps = {pred_steps}")
print(f" Instantaneous batch size per device = {batch_size}")
print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")

rng, input_rng = jax.random.split(rng)
train_loader = dataprep.data_loader(input_rng, batch_size=train_batch_size, shuffle=False, drop_last=False)

for _ in tqdm(range(pred_steps), desc="Predicting..."):
    batch = next(train_loader)
    batch = jax.device_put(batch, data_sharding)
    encodings = get_encodings_jit(batch)
    encoding_list.extend(jax.device_get(encodings))


# %%
encoding_list = jnp.vstack(encoding_list)
# slice back to the dataset length to remove padding rows from the final batch
encoding_list = encoding_list[:len(train_dataset)]
print(encoding_list.shape)
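# encoding_list is now a single (len(train_dataset), embed) array, e.g. width 768
# per the shape comments above, and can serve as the matrix M in the top-k
# cosine-similarity helper defined below.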


# %%
# getting top-k
def top_k_cosine_similarity(M, a, k):
    """
    Find the top-k rows in matrix M that are most cosine similar to array a.

    Args:
        M (jnp.ndarray): Matrix of shape (n, d), where each row is a d-dimensional vector.
        a (jnp.ndarray): Array of shape (d,), the vector to compare to each row of M.
        k (int): Number of top cosine similarities to retrieve.

    Returns:
        values (jnp.ndarray): Top-k cosine similarity values.
        indices (jnp.ndarray): Indices of the top-k most similar rows in M.
    """
    # Step 1: Normalize both M and a
    M_norm = M / jnp.linalg.norm(M, axis=1, keepdims=True)  # Shape: (n, d)
    a_norm = a / jnp.linalg.norm(a)  # Shape: (d,)

    # Step 2: Compute cosine similarity via dot product
    cosine_similarities = jnp.dot(M_norm, a_norm)  # Shape: (n,)

    # Step 3: Get the top-k values and their indices using jax.lax.top_k
    values, indices = jax.lax.top_k(cosine_similarities, k)

    return values, indices


# Example usage:
M = jax.random.normal(jax.random.PRNGKey(0), (100, 128))  # Random matrix with 100 rows of 128 dimensions
a = jax.random.normal(jax.random.PRNGKey(1), (128,))  # Random query vector of 128 dimensions

# Find top 5 most similar rows
top_k_values, top_k_indices = top_k_cosine_similarity(M, a, k=5)

print("Top-k cosine similarity values:", top_k_values)
print("Indices of top-k similar rows:", top_k_indices)


# %%