Feat: t5_jax_simple_parallel implements a working example of fsdp
parent 429e1742ab
commit aca80720c8
@@ -1,5 +1,6 @@
 *.ipynb
 t5_*/
+model_checkpoints/
 exports/
 modified_t5_model/
 traces/
@@ -1 +1,2 @@
 __pycache__
+gpt-neo-125m/
@@ -17,17 +17,6 @@ file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
 # training_size = len(split_datasets['train'])

 from transformers import T5TokenizerFast
-tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
-# Define additional special tokens
-additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
-# Add the additional special tokens to the tokenizer
-tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
-
-model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
-
-model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
-shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")  # noqa: B009
-

 # class takes in a dataset
 class DataPrepare():
@@ -37,6 +26,19 @@ class DataPrepare():
         self.train_dataset: Optional[Dataset] = None
         self.size: int = len(raw_dataset)
         self.config: ConfigDict = config
+        self.tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=False)
+        # Define additional special tokens
+        # additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
+        additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
+        # Add the additional special tokens to the tokenizer
+        self.tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
+
+        model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
+
+        model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
+        self.shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")  # noqa: B009
+
+
         self.make_dataset()

@@ -52,31 +54,37 @@ class DataPrepare():
         # text_target sets the corresponding label to inputs
         # there is no need to create a separate 'labels'
         # produce input_ids and decoder_input_ids
-        model_inputs = tokenizer(
+        model_inputs = self.tokenizer(
             inputs,
             max_length=self.config.max_length,
             padding="max_length",
             truncation=True,
             return_tensors="np"
         )
-        labels = tokenizer(
+        # we separate it out because we need the attention mask
+        labels = self.tokenizer(
             text_target=targets,
             max_length=self.config.max_length,
             padding="max_length",
             truncation=True,
             return_tensors="np"
         )
+        model_inputs['input_ids'] = np.asarray(model_inputs['input_ids'])
+        model_inputs['attention_mask'] = np.asarray(model_inputs['attention_mask'])
         # for loss computation
         model_inputs["labels"] = labels["input_ids"]
         # make decoder input ids
-        decoder_input_ids = shift_tokens_right_fn(
+        # this is actually "model output" shifted right
+        decoder_input_ids = self.shift_tokens_right_fn(
             labels["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id
         )
         # require by model
         model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
-        # We need decoder_attention_mask so we can ignore pad tokens from loss
-        model_inputs["decoder_attention_mask"] = labels["attention_mask"]
+        # decoder_attention_mask = shift_tokens_right_fn(
+        #     labels["attention_mask"], self.config.pad_token_id, self.config.decoder_start_token_id
+        # )
+        # We need decoder_attention_mask so we can ignore pad tokens in loss
+        model_inputs["decoder_attention_mask"] = np.asarray(labels["attention_mask"])

         return model_inputs

@@ -89,13 +97,13 @@ class DataPrepare():
             remove_columns=self.raw_dataset.column_names,)

         # set to numpy
-        train_dataset.set_format(
-            type='numpy',
-            columns=[
-                'input_ids', 'attention_mask',
-                'labels', 'decoder_input_ids',
-                'decoder_attention_mask']
-        )
+        # train_dataset.set_format(
+        #     type='numpy',
+        #     columns=[
+        #         'input_ids', 'attention_mask', 'labels',
+        #         'decoder_input_ids',
+        #         'decoder_attention_mask']
+        # )

         # check that data fits
         # for name in ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'decoder_attention_mask']:
@@ -140,15 +148,15 @@ class DataPrepare():

         for idx in batch_idx:
             batch = dataset[idx]
-            batch = {k: v for k, v in batch.items()}
+            batch = {k: np.array(v) for k, v in batch.items()}

             yield batch


-# testing out the class
-# %%
-# init object
-# e.g. Config
+# # testing out the class
+# # %%
+# # init object
+# # e.g. Config
 # data_config = ConfigDict(
 #     dict(
 #         max_length=86,
@@ -172,3 +180,13 @@ class DataPrepare():
 # batch = next(iter(train_loader))
 # batch['input_ids'].shape
 # # %%
+#
+# sentence = "<THING_START><THING_END><PROPERTY_START><PROPERTY_END><NAME><DESC><DESC><UNIT>"
+# tokens = tokenizer.tokenize(sentence)
+# print("Tokens:", tokens)
+# # Get the IDs (integer indices) of specific tokens
+# token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
+# print("Token IDs:", token_ids)
+#
+#
+# # %%
@@ -242,7 +242,8 @@ def train_step(state, x):
 # with mesh: # not strictly necessary in this case
 # with mesh block is useful for explicit scope for device sharding
 # but mesh management is automatic via jit sharding annotations
-new_state = train_step(initialized_state, x)
+with mesh:
+    new_state = train_step(initialized_state, x)

 print(f'Sharding of Weight 1:')
 jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
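The comment in this hunk notes that the explicit `with mesh:` scope is optional because the jit sharding annotations already carry the mesh. A minimal, self-contained sketch of that behaviour (toy function and toy array, not this repo's model; it only assumes at least one visible JAX device):

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over whatever devices are visible
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("data",))

# place a toy batch with rows sharded over the 'data' axis
x = jax.device_put(jnp.ones((jax.device_count() * 2, 4)), NamedSharding(mesh, P("data", None)))

@jax.jit
def step(x):
    return x * 2.0  # output sharding is propagated from the input's NamedSharding

y = step(x)  # runs sharded even without an explicit `with mesh:` scope
print(y.sharding)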
@@ -0,0 +1,854 @@ gpt-neo-125m.json (parameter-shape dump; transformer blocks "0" through "11" all have the shapes shown for block "0")
{
  "transformer": {
    "h": {
      "0": {
        "attn": {
          "attention": {
            "k_proj": { "kernel": [768, 768] },
            "out_proj": { "bias": [768], "kernel": [768, 768] },
            "q_proj": { "kernel": [768, 768] },
            "v_proj": { "kernel": [768, 768] }
          }
        },
        "ln_1": { "bias": [768], "scale": [768] },
        "ln_2": { "bias": [768], "scale": [768] },
        "mlp": {
          "c_fc": { "bias": [3072], "kernel": [768, 3072] },
          "c_proj": { "bias": [768], "kernel": [3072, 768] }
        }
      },
      "1": { ...same shapes as block "0"... },
      ...
      "11": { ...same shapes as block "0"... }
    },
    "ln_f": { "bias": [768], "scale": [768] },
    "wpe": { "embedding": [2048, 768] },
    "wte": { "embedding": [50257, 768] }
  }
}
@@ -0,0 +1,24 @@
# %%
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import freeze, unfreeze
from partitions import set_partitions

from transformers import FlaxAutoModelForCausalLM
# this inits the model directly
model = FlaxAutoModelForCausalLM.from_pretrained(
    "gpt-neo-125m",
)
params = model.params

# %%
import json
shape_dict = jax.tree.map(jnp.shape, params)
# print(json.dumps(shape_dict, sort_keys=True, indent=4))
with open('gpt-neo-125m.json', 'w') as f:
    json.dump(shape_dict, fp=f, sort_keys=True, indent=2)

# %%
param_spec = set_partitions(unfreeze(params))

# %%
@@ -0,0 +1,108 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The Google Research Authors and The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for constructing PyTrees of PartitionSpecs."""

# utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py

import re

from flax.core.frozen_dict import freeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.sharding import PartitionSpec as P


# Sentinels
_unmatched = object()

# For specifying empty leaf dict `{}`
empty_dict = object()


def _match(qs, ks):
    """Return True if regexes in qs match any window of strings in tuple ks."""
    # compile regexes and force complete match
    qts = tuple((re.compile(x + "$") for x in qs))
    for i in range(len(ks) - len(qs) + 1):
        matches = [x.match(y) for x, y in zip(qts, ks[i:])]
        if matches and all(matches):
            return True
    return False


def _replacement_rules(rules):
    def replace(key, val):
        for rule, replacement in rules:
            if _match(rule, key):
                return replacement
        return val

    return replace


# PartitionSpec for GPTNeo
# replicate the hidden dim and shard feed-forward and head dim
# def _get_partition_rules():
#     return [
#         # embeddings
#         (("transformer", "wpe", "embedding"), P("mp", None)),
#         (("transformer", "wte", "embedding"), P("mp", None)),
#         # atention
#         (("attention", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
#         (("attention", "out_proj", "kernel"), P("mp", None)),
#         (("attention", "out_proj", "bias"), None),
#         # mlp
#         (("mlp", "c_fc", "kernel"), P(None, "mp")),
#         (("mlp", "c_fc", "bias"), P("mp")),
#         (("mlp", "c_proj", "kernel"), P("mp", None)),
#         (("mlp", "c_proj", "bias"), None),
#         # layer norms
#         ((r"ln_\d+", "bias"), None),
#         ((r"\d+", r"ln_\d+", "scale"), None),
#         (("ln_f", "bias"), None),
#         (("ln_f", "scale"), None),
#     ]

def _get_partition_rules():
    return [
        # embedding
        (("shared", "embedding"), P("model", None)),
        # SelfAttention
        (("SelfAttention", "(q|k|v)", "kernel"), P(None, "model")),
        (("SelfAttention", "o", "kernel"), P("model", None)),
        (("SelfAttention", "relative_attention_bias", "embedding"), P(None)),
        # EncDecAttention
        (("EncDecAttention", "(q|k|v)", "kernel"), P(None, "model")),
        (("EncDecAttention", "o", "kernel"), P("model", None)),
        # DenseReluDense
        (("DenseReluDense", "wi", "kernel"), P(None, "model")),
        (("DenseReluDense", "wo", "kernel"), P("model", None)),
        # layer norms
        (("final_layer_norm", "weight"), P(None)),
        (("layer_norm", "weight"), P(None)),
    ]


def set_partitions(in_dict):
    rules = _get_partition_rules()
    replace = _replacement_rules(rules)
    initd = {k: _unmatched for k in flatten_dict(in_dict)}
    result = {k: replace(k, v) for k, v in initd.items()}
    assert _unmatched not in result.values(), "Incomplete partition spec."
    # for item in result.values():
    #     if item:
    #         print(item)

    return freeze(unflatten_dict(result))
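A minimal sketch of how set_partitions is meant to be consumed. The toy parameter tree below is an assumption for illustration only: its key names follow the Flax T5 layout that the rules above target, and the leaf values are placeholder shape tuples, which set_partitions never inspects.

from partitions import set_partitions

# every leaf path below is covered by one of the T5 rules, so the
# "Incomplete partition spec" assert will not trigger
toy_params = {
    "shared": {"embedding": (32128, 768)},
    "encoder": {
        "block": {
            "0": {
                "layer": {
                    "0": {
                        "SelfAttention": {
                            "q": {"kernel": (768, 768)},
                            "o": {"kernel": (768, 768)},
                        },
                        "layer_norm": {"weight": (768,)},
                    },
                    "1": {
                        "DenseReluDense": {
                            "wi": {"kernel": (768, 3072)},
                            "wo": {"kernel": (3072, 768)},
                        },
                        "layer_norm": {"weight": (768,)},
                    },
                },
            },
        },
        "final_layer_norm": {"weight": (768,)},
    },
}

specs = set_partitions(toy_params)
# e.g. specs['encoder']['block']['0']['layer']['0']['SelfAttention']['q']['kernel']
# -> PartitionSpec(None, 'model'): shard the projection's output dim over the 'model' axis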
File diff suppressed because it is too large
@@ -0,0 +1,651 @@
# MARK: START
# %%
# let's make 8-device simulator
import sys
sys.dont_write_bytecode = True
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=4"  # Simulate 8 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    # GPU flags
    flags += (
        "--xla_gpu_enable_triton_softmax_fusion=true "
        "--xla_gpu_triton_gemm_any=false "
        "--xla_gpu_enable_async_collectives=true "
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        "--xla_gpu_enable_highest_priority_async_stream=true "
    )
os.environ["XLA_FLAGS"] = flags

import functools
from functools import partial
from pprint import pprint
from typing import Any, Dict, Tuple, Callable, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
# from jax.experimental.pjit import pjit  # superseded by jax.jit
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec
from ml_collections import ConfigDict
import optax
import logging
import time
from datasets import Dataset, load_from_disk

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.core.frozen_dict import freeze, unfreeze
import flax.core

from partitions import set_partitions

from tqdm import tqdm

from dataload import DataPrepare


PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

if USE_CPU_ONLY:
    jax.config.update('jax_platform_name', 'cpu')
else:
    jax.config.update("jax_default_matmul_precision", "bfloat16")


# %%
# get platform type
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

# %%
# config options
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
save_path = 't5_80_1_bf16'
# file_path = 'combined_data'
split_datasets = load_from_disk(file_path)
training_size = len(split_datasets['train'])
# Store some constant
seed = 117
num_epochs = 5
batch_size = 2  # 384 is the best
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = training_size // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

warmup_steps = 0
learning_rate = 2e-5

weight_decay = 0.01
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_epsilon = 1e-8
label_smoothing_factor = 0.0

num_beams = 1
val_max_target_length = 128

predict_with_generate = True


# %%
# prepare data
# init object
# e.g. Config
data_config = ConfigDict(
    dict(
        max_length=86,
        pad_token_id=0,
        decoder_start_token_id=0
    )
)

dataprep = DataPrepare(split_datasets['train'], data_config)
# # example usage
# # %%
seed = 117
rng = jax.random.PRNGKey(seed)
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
batch = next(iter(train_loader))
# batch

# %%
# model


from t5_model.pure_t5 import FlaxT5ForConditionalGenerationModule as model_init
# from t5_model.pure_t5 import FlaxT5DenseActDense as model_init
from t5_model.pure_t5 import make_config
config = make_config()
model = model_init(config)

# %%
from transformers import FlaxT5ForConditionalGeneration
from transformers import T5Config
model, params = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=False)



# useful for transformer model
# model.enable_gradient_checkpointing()

# enable bf16 except for layer_norm
# from flax import traverse_util
# flat_params = traverse_util.flatten_dict(model.params)
# mask = {
#     path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
# }
# mask = traverse_util.unflatten_dict(mask)
# model.params = model.to_bf16(model.params, mask)


##################################################################
# set partition on model

# %%
# # let's output the model parameters to a json file for study
# import json
# shape_dict = jax.tree.map(jnp.shape, params)
# # print(json.dumps(shape_dict, sort_keys=True, indent=4))
# with open('t5.json', 'w') as f:
#     json.dump(shape_dict, fp=f, sort_keys=True, indent=2)

# MARK: setup mesh
# %%
device_mesh = mesh_utils.create_device_mesh((2,2))
print(device_mesh)

mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)

def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    return NamedSharding(mesh, pspec)


##################################################
# optimizers
# %%

def create_learning_rate_fn(
        train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear_decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn


# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    training_size,
    train_batch_size,
    num_train_epochs,
    warmup_steps,
    learning_rate,
)

# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)

# create adam optimizer
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)


# %%
# specify sharding

# shard data
x_sharding = mesh_sharding(PartitionSpec('data', None))  # replicate across data axis
batch = {key: jax.device_put(jnp.array(value), x_sharding) for key, value in batch.items()}
# Defining the required dimensions for the self-attention layer input
# batch_size = 2
# seq_length = 768
# n_heads = 12
# head_dim = 768
# %%

# Create a large array with the shape (batch_size, seq_length, n_heads, head_dim)
# large_input = np.random.rand(2,768,768)
# batch = jax.device_put(large_input, x_sharding)

# %%
# jax.debug.visualize_array_sharding(batch['input_ids'])

# %%
# shard output
# we will shard state by tracking its output upon jax.eval_shape after init
# define an init function to return a TrainState
# def init_fn(rng, batch, model, optimizer):
#     # do be careful with the model init
#     # imported models might have complicated init methods
#     variables = model.init(rng,
#         input_ids=batch['input_ids'],
#         attention_mask=batch['attention_mask'],
#         decoder_input_ids=batch['decoder_attention_mask'],
#         decoder_attention_mask=batch['decoder_attention_mask']
#     )
#     state = train_state.TrainState.create(  # Create a `TrainState`.
#         apply_fn=model.apply,
#         params=variables['params'],
#         tx=optimizer)
#     return state


def init_fn(rng, batch, model, optimizer):
    # do be careful with the model init
    # imported models might have complicated init methods
    variables = model.init(
        rng,
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        decoder_input_ids=batch['decoder_attention_mask'],
        decoder_attention_mask=batch['decoder_attention_mask']
    )
    state = train_state.TrainState.create(  # Create a `TrainState`.
        apply_fn=model.apply,
        params=variables['params'],
        tx=optimizer)
    return state

# %%
# alternative
# def init_fn(rng, batch, model, optimizer):
#     # do be careful with the model init
#     # imported models might have complicated init methods
#     variables = model.init(
#         rng, batch
#     )
#     state = train_state.TrainState.create(  # Create a `TrainState`.
#         apply_fn=model.apply,
#         params=variables['params'],
#         tx=optimizer)
#     return state


# %%
# Create an abstract closure to wrap the function before feeding it in
# because `jax.eval_shape` only takes pytrees as arguments.
# eval_shape(fn, rng_key, x)
# used to perform shape inference
# returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
rng, init_rng = jax.random.split(rng)
abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=model, optimizer=adamw), init_rng, batch)


# %%
# This `state_sharding` has the same pytree structure as `state`, the output
# of the `init_fn`.
# flan.linen.get_sharding
# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh
# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
print(state_sharding)

# warning: do not have singleton None in your nn.partition definitions, it will screw with your sanity


# %%
jit_init_fn = jax.jit(
    init_fn,
    static_argnames=('model', 'optimizer'),  # skip model and optimizer
    in_shardings=(mesh_sharding(()), x_sharding),  # for PRNG key and data
    out_shardings=state_sharding
)


rng, init_rng = jax.random.split(rng)
initialized_state = jit_init_fn(rng, batch, model, adamw)

# %%
# we can analyze the params structure
# for weight, partitioned in initialized_state.params['decoder'].items():
#     print(f'Sharding of {weight}: {partitioned}')
# jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
# jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
jax.tree.map(jnp.shape, initialized_state.params['decoder'])


# %%
print(initialized_state.params['decoder']['block']['0']['layer']['0']['SelfAttention']['k']['kernel'].value.sharding)
print(initialized_state.step)
print(initialized_state.step.sharding)


# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """
    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = loss - normalizing_constant

    # ignore padded tokens from loss
    loss = loss * padding_mask
    loss = loss.sum()
    num_labels = padding_mask.sum()
    return loss, num_labels

# %%

# single device code annotated with jax.jit
@functools.partial(
    jax.jit,
    # in_shardings=(state_sharding, x_sharding),
    out_shardings=state_sharding
)
def train_step(state, batch):
    label_smoothing_factor=0.0
    # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    def compute_loss(params):
        labels = batch.pop("labels")
        logits = state.apply_fn(
            {'params': params},
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            decoder_input_ids=batch['decoder_attention_mask'],
            decoder_attention_mask=batch['decoder_attention_mask'],
        )[0]
        loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
        return loss, num_labels

    # compute gradients through computational graph
    # allow values to pass through
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, num_labels), grad = grad_fn(state.params)
    # num_labels = jax.lax.psum(num_labels, "batch")


    # true grad = total grad / total samples
    # grad = jax.lax.psum(grad, "batch")
    # grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
    new_state = state.apply_gradients(grads=grad)

    # metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    return new_state



# %%

# variables = model.init(
#     rng,
#     input_ids=batch['input_ids'],
#     attention_mask=batch['attention_mask'],
#     decoder_input_ids=batch['decoder_attention_mask'],
#     decoder_attention_mask=batch['decoder_attention_mask']
# )
# x_sharding = mesh_sharding(PartitionSpec('data', None))  # replicate across data axis
# batch = {key: jax.device_put(jnp.array(value), x_sharding) for key, value in batch.items()}

with mesh:
    new_state = train_step(initialized_state, batch)



# %%




# # %%
# #############################################################
# # we cannot integrate our model pspec with train_state
# # we just shard separately
# # update: we also cannot use the method of modifying a partitionspec tree
# # we have to do it the RIGHT way, following flax_pjit_tutorial to the letter
#
# # %%
# def get_optim_initial_state(params):
#     params = params
#     state = adamw.init(params)
#     return tuple((state)), params
#
# # %%
# # create partitions for model
# from partitions import set_partitions
# # set_partitions freezes the params on return
# model_param_spec = set_partitions(unfreeze(params))
#
# # %%
# params_shapes = jax.tree.map(lambda x: x.shape, params)
# # actually tuple
# optim_state_shapes = jax.eval_shape(get_optim_initial_state, params_shapes)
#
# # %%
# # get pspec for opt_state
# def get_opt_spec(x):
#     if isinstance(x, dict):
#         return unfreeze(model_param_spec)
#     return PartitionSpec()
#
# # this function replaces the empty model params spec with the 'model_param_spec'
# opt_state_spec, param_spec = jax.tree.map(
#     get_opt_spec, optim_state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
# )
#
# # %%
#
# model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=True)
# # store on cpu
# model.params = jax.tree_util.tree_map(lambda x: np.asarray(x), model.params)
#
# # %%
# device_mesh = mesh_utils.create_device_mesh((2,2))
# print(device_mesh)
#
# mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
# print(mesh)
#
# def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
#     return NamedSharding(mesh, pspec)
#
#
# # opt_state_sharding = mesh_sharding(opt_state_spec)
# # param_sharding = mesh_sharding(param_spec)
#
#
# # %%
# opt_state_sharding = nn.get_sharding(opt_state_spec, mesh)
# param_sharding = nn.get_sharding(param_spec, mesh)
#
# # %%
# # jit the get_initial_state function to shard params and init optimizer state in
# # a sharded way
# from jax.experimental.pjit import pjit
#
# with mesh:
#     p_get_initial_state = pjit(
#         get_optim_initial_state,
#         in_shardings=None,
#         out_shardings=(opt_state_spec, param_spec),
#     )
#
# # Convert your PartitionSpec to NamedSharding for model params
# param_sharding = NamedSharding(mesh, freeze(param_spec))
# # Use device_put with sharding to move params onto the mesh
# sharded_params = jax.device_put(freeze(params), param_sharding)
#
# with mesh:
#     # params is already frozen
#     sharded_opt_state, sharded_params = p_get_initial_state(unfreeze(sharded_params))
#
# # %%
#
# # give up this section
# #############################################################
# # create train state
#
# # %%
# # Initialize random key and input for initialization
# rng = jax.random.PRNGKey(seed)
# loader_rng, rng = jax.random.split(rng)
# train_loader = dataprep.data_loader(rng, batch_size=2)
# batch = next(iter(train_loader))
#
# # use the T5 base model to do this
# from transformers import FlaxAutoModel
# model, params = FlaxAutoModel.from_pretrained(
#     't5-base',
#     _do_init=False
# )
# t5_module = model.module
#
# # %%
# init_rng, rng = jax.random.split(rng)
# variables = t5_module.init(init_rng,
#     input_ids=batch['input_ids'],
#     attention_mask=batch['attention_mask'],
#     decoder_input_ids=batch['decoder_attention_mask'],
#     decoder_attention_mask=batch['decoder_attention_mask']
# )
# params = variables['params']
#
# # create an init function
# # %%
# # we will shard state by tracking its output upon jax.eval_shape after init
# # define an init function to return a TrainState
# def init_fn(rng: jax.random.PRNGKey, batch=batch, model=t5_module, optimizer=adamw) -> train_state.TrainState:
#     init_rng, rng = jax.random.split(rng)
#     variables = model.init(
#         init_rng,
#         input_ids=batch['input_ids'],
#         attention_mask=batch['attention_mask'],
#         decoder_input_ids=batch['decoder_attention_mask'],
#         decoder_attention_mask=batch['decoder_attention_mask']
#     )
#     params = variables.pop("params")
#     state = train_state.TrainState.create(
#         apply_fn=model.__call__,
#         params=params,
#         tx=optimizer,
#     )
#     return state
#
# # model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base", _do_init=True)
# # Create an abstract closure to wrap the function before feeding it in
# # because `jax.eval_shape` only takes pytrees as arguments.
# # eval_shape(fn, rng_key, x)
# # used to perform shape inference
# # returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
# init_rng, rng = jax.random.split(rng)
# abstract_variables = jax.eval_shape(
#     functools.partial(init_fn, model=t5_module, optimizer=adamw),
#     init_rng,
#     batch
# )
#
# # %%
# # let's make our mesh
#
# device_mesh = mesh_utils.create_device_mesh((2,2))
# print(device_mesh)
#
# mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
# print(mesh)
#
# def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
#     return NamedSharding(mesh, pspec)
#
# # %%
# # making jax compatible batch
#
# # %%
# x_sharding = mesh_sharding(PartitionSpec('data', None))  # replicate across data axis
# # batch = jax.device_put(batch), x_sharding)
# # jax.debug.visualize_array_sharding(batch)
#
# # %%
# state_sharding = nn.get_sharding(abstract_variables, mesh)
# print(state_sharding)
#
# # %%
# # integrate model_param_specs and state_out_specs
#
# # %%
# # i want to make a Sharding object
# # model_sharding = mesh_sharding(model_param_spec)
#
# # %%
# jit_init_fn = jax.jit(
#     init_fn,  # rng, batch, model, optimizer
#     static_argnames=('model', 'optimizer'),  # skip model and optimizer
#     in_shardings=(mesh_sharding(()), x_sharding),  # mesh_sharding(()), mesh_sharding(())), # for PRNG key and data
#     out_shardings=state_sharding
# )
#
# # %%
#
# init_rng, rng = jax.random.split(rng)
# initialized_state = jit_init_fn(
#     init_rng,
#     batch,
#     t5_module,
#     adamw)
#
# # jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
# # jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
#
#
# # %%
#
# # %%
#
# %%
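For reference, the sharded-initialization pattern that the new t5_jax_simple_parallel script builds up above (jax.eval_shape on an init closure, nn.get_sharding to recover a sharding tree, then jax.jit with in_shardings/out_shardings) can be exercised end to end on a toy Flax module. Everything below is an illustrative assumption: a stand-in DotDense model and a 2x2 mesh that needs four visible devices (for example via --xla_force_host_platform_device_count=4), not the repo's T5 setup.

import functools
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((2, 2)), axis_names=("data", "model"))

class DotDense(nn.Module):
    @nn.compact
    def __call__(self, x):
        # annotate the kernel so nn.get_sharding can later recover a NamedSharding for it
        w = self.param("w",
                       nn.with_partitioning(nn.initializers.lecun_normal(), (None, "model")),
                       (8, 16))
        return x @ w

def init_fn(rng, x, model, optimizer):
    variables = model.init(rng, x)
    return train_state.TrainState.create(apply_fn=model.apply,
                                          params=variables["params"],
                                          tx=optimizer)

model = DotDense()
tx = optax.adamw(1e-3)
x_sharding = NamedSharding(mesh, P("data", None))
x = jax.device_put(jnp.ones((4, 8)), x_sharding)

# shape-only init, then turn the Partitioned metadata into a sharding tree
abstract_state = jax.eval_shape(
    functools.partial(init_fn, model=model, optimizer=tx), jax.random.PRNGKey(0), x)
state_sharding = nn.get_sharding(abstract_state, mesh)

jit_init_fn = jax.jit(init_fn,
                      static_argnames=("model", "optimizer"),
                      in_shardings=(NamedSharding(mesh, P()), x_sharding),
                      out_shardings=state_sharding)
state = jit_init_fn(jax.random.PRNGKey(0), x, model, tx)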
t5_jax.py
@@ -46,11 +46,13 @@ from datasets import load_from_disk


 import nltk  # Here to have a nice missing dependency error message early on
+from typing import Dict, Any, Union

 from flax import jax_utils, traverse_util
 from flax.jax_utils import pad_shard_unpad, unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+from flax.core.frozen_dict import FrozenDict, unfreeze
 import flax.core

@@ -60,9 +62,20 @@ import time

 # %%
 import os
-os.environ['XLA_FLAGS'] = (
-    '--xla_gpu_triton_gemm_any=true --xla_gpu_enable_custom_fusions=true --xla_gpu_enable_address_computation_fusion=true'
+# os.environ['XLA_FLAGS'] = (
+#     '--xla_gpu_triton_gemm_any=true '
+#     '--xla_gpu_enable_custom_fusions=true '
+#     '--xla_gpu_enable_address_computation_fusion=true'
+# )
+flags = (
+    '--xla_gpu_enable_triton_softmax_fusion=true '
+    '--xla_gpu_triton_gemm_any=True '
+    # '--xla_gpu_enable_async_collectives=true '
+    '--xla_gpu_enable_latency_hiding_scheduler=true '
+    '--xla_gpu_enable_highest_priority_async_stream=true '
 )
+os.environ["XLA_FLAGS"] = flags
+

 os.environ.update({
     "TOKENIZERS_PARALLELISM" : "false",
@ -76,8 +89,8 @@ os.environ.update({
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# get platform type
|
# get platform type
|
||||||
from jax.lib import xla_bridge
|
from jax.extend.backend import get_backend
|
||||||
print(xla_bridge.get_backend().platform)
|
print(get_backend().platform)
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
@ -90,14 +103,14 @@ except (LookupError, OSError):
|
||||||
# %%
|
# %%
|
||||||
# config options
|
# config options
|
||||||
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
|
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval'
|
||||||
save_path = 't5_80_1_bf16'
|
save_path = 't5_5e_1_pmap'
|
||||||
# file_path = 'combined_data'
|
# file_path = 'combined_data'
|
||||||
split_datasets = load_from_disk(file_path)
|
split_datasets = load_from_disk(file_path)
|
||||||
training_size = len(split_datasets['train'])
|
training_size = len(split_datasets['train'])
|
||||||
# Store some constant
|
# Store some constant
|
||||||
seed = 117
|
seed = 117
|
||||||
num_epochs = 5
|
num_epochs = 5
|
||||||
batch_size = 384 # 384 is the best
|
batch_size = 32 # 384 is the best
|
||||||
num_train_epochs = num_epochs
|
num_train_epochs = num_epochs
|
||||||
per_device_train_batch_size = batch_size
|
per_device_train_batch_size = batch_size
|
||||||
train_batch_size = per_device_train_batch_size * jax.device_count()
|
train_batch_size = per_device_train_batch_size * jax.device_count()
|
||||||
|
@ -116,7 +129,7 @@ adam_epsilon = 1e-8
|
||||||
label_smoothing_factor = 0.0
|
label_smoothing_factor = 0.0
|
||||||
|
|
||||||
num_beams = 1
|
num_beams = 1
|
||||||
val_max_target_length = 128
|
val_max_target_length = 86
|
||||||
|
|
||||||
predict_with_generate = True
|
predict_with_generate = True
|
||||||
|
|
||||||
|
@ -126,7 +139,7 @@ predict_with_generate = True
|
||||||
from transformers import T5TokenizerFast
|
from transformers import T5TokenizerFast
|
||||||
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
|
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
|
||||||
# Define additional special tokens
|
# Define additional special tokens
|
||||||
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
|
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
|
||||||
# Add the additional special tokens to the tokenizer
|
# Add the additional special tokens to the tokenizer
|
||||||
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
|
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
|
||||||
|
|
||||||
|
@ -172,15 +185,43 @@ from flax import traverse_util
|
||||||
|
|
||||||
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
|
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
|
||||||
# useful for transformer model
|
# useful for transformer model
|
||||||
model.enable_gradient_checkpointing()
|
# model.enable_gradient_checkpointing()
|
||||||
|
|
||||||
# enable bf16 except for layer_norm
|
# enable bf16 except for layer_norm
|
||||||
flat_params = traverse_util.flatten_dict(model.params)
|
# flat_params = traverse_util.flatten_dict(model.params)
|
||||||
mask = {
|
# mask = {
|
||||||
path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
|
# path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
|
||||||
}
|
# }
|
||||||
mask = traverse_util.unflatten_dict(mask)
|
# mask = traverse_util.unflatten_dict(mask)
|
||||||
model.params = model.to_bf16(model.params, mask)
|
# # borrowed from transformers modeling_flax_utils
|
||||||
|
# def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
|
||||||
|
# """
|
||||||
|
# Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
|
||||||
|
# """
|
||||||
|
#
|
||||||
|
# # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
|
||||||
|
# def conditional_cast(param):
|
||||||
|
# if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
|
||||||
|
# param = param.astype(dtype)
|
||||||
|
# return param
|
||||||
|
#
|
||||||
|
# if mask is None:
|
||||||
|
# return jax.tree_util.tree_map(conditional_cast, params)
|
||||||
|
#
|
||||||
|
# flat_params = traverse_util.flatten_dict(params)
|
||||||
|
# flat_mask, _ = jax.tree_util.tree_flatten(mask)
|
||||||
|
#
|
||||||
|
# for masked, key in zip(flat_mask, sorted(flat_params.keys())):
|
||||||
|
# if masked:
|
||||||
|
# flat_params[key] = conditional_cast(flat_params[key])
|
||||||
|
#
|
||||||
|
# return traverse_util.unflatten_dict(flat_params)
|
||||||
|
#
|
||||||
|
# # Cast parameters to bfloat16 if desired
|
||||||
|
# # params = jax.tree.tree_map(lambda x: x.astype(jnp.bfloat16), params)
|
||||||
|
# # instead of casting the whole thing, we cast only certain parts of the tree
|
||||||
|
# params = cast_floating_to(model.params, jnp.bfloat16, mask)
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# # Function to extract shape and dtype without showing values
|
# # Function to extract shape and dtype without showing values
|
||||||
|
@ -527,10 +568,12 @@ for epoch in epochs:
|
||||||
# jax.profiler.stop_trace()
|
# jax.profiler.stop_trace()
|
||||||
# %%
|
# %%
|
||||||
|
|
||||||
# output_dir = save_path
|
output_dir = save_path
|
||||||
# # save checkpoint after each epoch and push checkpoint to the hub
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
# if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
# params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
|
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
|
||||||
# params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
|
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
|
||||||
# model.save_pretrained(output_dir, params=params)
|
model.save_pretrained(output_dir, params=params)
|
||||||
# tokenizer.save_pretrained(output_dir)
|
tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
|
@@ -0,0 +1,697 @@
|
||||||
|
# %%
|
||||||
|
import os
|
||||||
|
# Set this to True to run the model on CPU only.
|
||||||
|
USE_CPU_ONLY = False
|
||||||
|
|
||||||
|
flags = os.environ.get("XLA_FLAGS", "")
|
||||||
|
if USE_CPU_ONLY:
|
||||||
|
flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices
|
||||||
|
# Enforce CPU-only execution
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||||
|
os.environ["JAX_PLATFORMS"] = "cpu"
|
||||||
|
else:
|
||||||
|
# GPU flags
|
||||||
|
flags = (
|
||||||
|
'--xla_gpu_enable_triton_softmax_fusion=true '
|
||||||
|
'--xla_gpu_triton_gemm_any=True '
|
||||||
|
# '--xla_gpu_enable_async_collectives=true '
|
||||||
|
'--xla_gpu_enable_latency_hiding_scheduler=true '
|
||||||
|
'--xla_gpu_enable_highest_priority_async_stream=true '
|
||||||
|
)
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||||
|
|
||||||
|
os.environ["XLA_FLAGS"] = flags
|
||||||
|
os.environ.update({
|
||||||
|
"TOKENIZERS_PARALLELISM" : "false",
|
||||||
|
"CUDA_DEVICE_MAX_CONNECTIONS" : "1",
|
||||||
|
"NCCL_LL128_BUFFSIZE": "-2",
|
||||||
|
"NCCL_LL_BUFFSIZE": "-2",
|
||||||
|
"NCCL_PROTO": "SIMPLE,LL,LL128",
|
||||||
|
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90",
|
||||||
|
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from functools import partial
|
||||||
|
from pprint import pprint
|
||||||
|
from typing import Any, Dict, Tuple, Callable, Sequence, Union
|
||||||
|
|
||||||
|
import flax.linen as nn
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
|
from jax.sharding import Mesh, NamedSharding
|
||||||
|
# from jax.experimental.pjit import pjit # superseded by jax.jit
|
||||||
|
from jax.experimental import mesh_utils
|
||||||
|
from jax.sharding import PartitionSpec
|
||||||
|
from ml_collections import ConfigDict
|
||||||
|
import optax
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datasets import Dataset, load_from_disk
|
||||||
|
|
||||||
|
from flax import jax_utils, traverse_util
|
||||||
|
from flax.jax_utils import pad_shard_unpad, unreplicate
|
||||||
|
from flax.training import train_state
|
||||||
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||||
|
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
|
||||||
|
import flax.core
|
||||||
|
|
||||||
|
# model checkpointing and saving utilities
|
||||||
|
from flax import linen as nn
|
||||||
|
from flax.training import checkpoints, train_state
|
||||||
|
from flax import struct, serialization
|
||||||
|
import orbax.checkpoint as ocp
|
||||||
|
from flax.training import orbax_utils
|
||||||
|
|
||||||
|
from parallel.partitions import set_partitions
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from parallel.dataload import DataPrepare
|
||||||
|
|
||||||
|
|
||||||
|
PyTree = Any
|
||||||
|
Metrics = Dict[str, Tuple[jax.Array, ...]]
|
||||||
|
|
||||||
|
if USE_CPU_ONLY:
|
||||||
|
jax.config.update('jax_platform_name', 'cpu')
|
||||||
|
else:
|
||||||
|
jax.config.update("jax_default_matmul_precision", "bfloat16")
|
||||||
|
|
||||||
|
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
|
||||||
|
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
|
||||||
|
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
## get platform type
|
||||||
|
from jax.extend.backend import get_backend
|
||||||
|
print(get_backend().platform)
|
||||||
|
print(jax.devices())
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# config options
|
||||||
|
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
|
||||||
|
save_path = '/home/richard/Projects/06_research/jax_models/t5_80e_fp32_parallel/'
|
||||||
|
# file_path = 'combined_data'
|
||||||
|
split_datasets = load_from_disk(file_path)
|
||||||
|
training_size = len(split_datasets['train'])
|
||||||
|
# Store some constant
|
||||||
|
seed = 117
|
||||||
|
num_epochs = 5
|
||||||
|
batch_size = 32
|
||||||
|
num_train_epochs = num_epochs
|
||||||
|
per_device_train_batch_size = batch_size
|
||||||
|
train_batch_size = per_device_train_batch_size * jax.device_count()
|
||||||
|
per_device_eval_batch_size = batch_size
|
||||||
|
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
||||||
|
steps_per_epoch = training_size // train_batch_size
|
||||||
|
total_train_steps = steps_per_epoch * num_epochs
|
||||||
|
|
||||||
|
warmup_steps = 0
|
||||||
|
learning_rate = 5e-5
|
||||||
|
|
||||||
|
weight_decay = 0.01
|
||||||
|
adam_beta1 = 0.9
|
||||||
|
adam_beta2 = 0.999
|
||||||
|
adam_epsilon = 1e-8
|
||||||
|
label_smoothing_factor = 0.0
|
||||||
|
num_beams = 1
|
||||||
|
val_max_target_length = 128
|
||||||
|
predict_with_generate = True
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# prepare data
|
||||||
|
# init object
|
||||||
|
# e.g. Config
|
||||||
|
print("preparing data")
|
||||||
|
data_config = ConfigDict(
|
||||||
|
dict(
|
||||||
|
max_length=128,
|
||||||
|
pad_token_id=0,
|
||||||
|
decoder_start_token_id=0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
dataprep = DataPrepare(split_datasets['train'], data_config)
|
||||||
|
# # example usage
|
||||||
|
# # %%
|
||||||
|
seed = 117
|
||||||
|
rng = jax.random.PRNGKey(seed)
|
||||||
|
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
|
||||||
|
batch = next(iter(train_loader))
|
||||||
|
# batch
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# model
|
||||||
|
|
||||||
|
# working
|
||||||
|
# from parallel.t5_model.pure_t5 import FlaxT5ForConditionalGenerationModule as model_init
|
||||||
|
# # from t5_model.pure_t5 import FlaxT5DenseActDense as model_init
|
||||||
|
# from parallel.t5_model.pure_t5 import make_config
|
||||||
|
# config = make_config()
|
||||||
|
# model = model_init(config=config, dtype=jnp.bfloat16, gradient_checkpointing=True)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# from transformers import FlaxT5ForConditionalGeneration, T5Config
|
||||||
|
# model = FlaxT5ForConditionalGeneration.from_pretrained(
|
||||||
|
# "t5-base",
|
||||||
|
# dtype=jnp.bfloat16,
|
||||||
|
# )
|
||||||
|
# # pretrained_params = model.params
|
||||||
|
# model = model.module
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom
|
||||||
|
from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model
|
||||||
|
main_model = custom_model.from_pretrained(
|
||||||
|
"t5-base",
|
||||||
|
dtype=jnp.float32,
|
||||||
|
# gradient_checkpointing=True,
|
||||||
|
)
|
||||||
|
params = main_model.params
|
||||||
|
# pretrained_params = model.params
|
||||||
|
model = main_model.module
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# # testing config hashability
|
||||||
|
# # some explanation:
|
||||||
|
# # The PreTrainedModel class loads a T5Config model that is not hashable because
|
||||||
|
# # it is a complicated class that pretends to be a dataclass.
|
||||||
|
# # The solution is to extract a dict from it, then make a ConfigDict from
|
||||||
|
# # ml_collections library so that we can get values via the "." operator.
|
||||||
|
# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to
|
||||||
|
# # modify the config before passing to the next layer
|
||||||
|
# from transformers import T5Config
|
||||||
|
# from t5_model.configuration_t5 import FrozenT5Config
|
||||||
|
# from ml_collections import ConfigDict, FrozenConfigDict
|
||||||
|
#
|
||||||
|
# config = T5Config.from_pretrained("t5-base").to_dict()
|
||||||
|
# config.pop('architectures')
|
||||||
|
# config.pop('id2label')
|
||||||
|
# # test if it works
|
||||||
|
# frozen_config = FrozenConfigDict(config)
|
||||||
|
# # test hash
|
||||||
|
# hash(frozen_config)
|
||||||
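# A minimal, self-contained sketch of the hashability workaround described in the comments
# above (illustrative only; it mirrors the commented-out snippet rather than adding new logic):
# T5Config is not hashable, so convert it to a plain dict, drop the awkward fields, and wrap
# it in an ml_collections FrozenConfigDict, which is hashable and supports "." access.
from transformers import T5Config
from ml_collections import FrozenConfigDict

_cfg_dict = T5Config.from_pretrained("t5-base").to_dict()
_cfg_dict.pop('architectures')
_cfg_dict.pop('id2label')
frozen_config = FrozenConfigDict(_cfg_dict)
hash(frozen_config)  # works, so it can be passed around as a static/hashable config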
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# # print model
|
||||||
|
# rng, input_rng = jax.random.split(rng)
|
||||||
|
# model.tabulate(
|
||||||
|
# input_rng,
|
||||||
|
# input_ids=batch['input_ids'],
|
||||||
|
# attention_mask=batch['attention_mask'],
|
||||||
|
# decoder_input_ids=batch['decoder_input_ids'],
|
||||||
|
# decoder_attention_mask=batch['decoder_attention_mask'],
|
||||||
|
# console_kwargs={"force_jupyter": True}
|
||||||
|
# )
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# print model datatype to verify
|
||||||
|
# rng, input_rng = jax.random.split(rng)
|
||||||
|
# variables = model.init(
|
||||||
|
# input_rng,
|
||||||
|
# input_ids=batch['input_ids'],
|
||||||
|
# attention_mask=batch['attention_mask'],
|
||||||
|
# decoder_input_ids=batch['decoder_input_ids'],
|
||||||
|
# decoder_attention_mask=batch['decoder_attention_mask']
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# create mesh
|
||||||
|
print("creating mesh")
|
||||||
|
device_mesh = mesh_utils.create_device_mesh((1,1))
|
||||||
|
print(device_mesh)
|
||||||
|
|
||||||
|
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
|
||||||
|
print(mesh)
|
||||||
|
|
||||||
|
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
|
||||||
|
return NamedSharding(mesh, pspec, memory_kind="device")
|
||||||
|
|
||||||
|
x_sharding = mesh_sharding(PartitionSpec('data', None)) # shard the batch dimension across the 'data' axis, replicate across 'model'
|
||||||
|
model_sharding=mesh_sharding(PartitionSpec(None, 'model'))
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# optimizers
|
||||||
|
|
||||||
|
def create_learning_rate_fn(
|
||||||
|
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
|
||||||
|
) -> Callable[[int], jnp.ndarray]:
|
||||||
|
"""Returns a linear warmup, linear_decay learning rate function."""
|
||||||
|
steps_per_epoch = train_ds_size // train_batch_size
|
||||||
|
num_train_steps = steps_per_epoch * num_train_epochs
|
||||||
|
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
|
||||||
|
decay_fn = optax.linear_schedule(
|
||||||
|
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
|
||||||
|
)
|
||||||
|
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
|
||||||
|
return schedule_fn
|
||||||
|
|
||||||
|
|
||||||
|
# Create learning rate schedule
|
||||||
|
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
||||||
|
training_size,
|
||||||
|
train_batch_size,
|
||||||
|
num_train_epochs,
|
||||||
|
warmup_steps,
|
||||||
|
learning_rate,
|
||||||
|
)
|
||||||
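# Illustrative check of the schedule defined above: with warmup_steps = 0 it starts at
# `learning_rate` and decays linearly to 0 by the last optimization step.
print("lr at step 0:", linear_decay_lr_schedule_fn(0))
print("lr at last step:", linear_decay_lr_schedule_fn(total_train_steps))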
|
|
||||||
|
# We use Optax's "masking" functionality to not apply weight decay
|
||||||
|
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
||||||
|
# mask boolean with the same structure as the parameters.
|
||||||
|
# The mask is True for parameters that should be decayed.
|
||||||
|
def decay_mask_fn(params):
|
||||||
|
flat_params = traverse_util.flatten_dict(params)
|
||||||
|
# find out all LayerNorm parameters
|
||||||
|
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
|
||||||
|
layer_norm_named_params = {
|
||||||
|
layer[-2:]
|
||||||
|
for layer_norm_name in layer_norm_candidates
|
||||||
|
for layer in flat_params.keys()
|
||||||
|
if layer_norm_name in "".join(layer).lower()
|
||||||
|
}
|
||||||
|
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
|
||||||
|
return traverse_util.unflatten_dict(flat_mask)
|
||||||
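# Quick illustrative check of decay_mask_fn on a toy parameter tree (hypothetical names and
# shapes, not the real T5 params): kernels get weight decay, biases and the layer_norm scale do not.
_toy_params = {
    "dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))},
    "layer_norm": {"weight": jnp.ones((2,))},
}
# expected: {'dense': {'kernel': True, 'bias': False}, 'layer_norm': {'weight': False}}
print(decay_mask_fn(_toy_params))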
|
|
||||||
|
# create adam optimizer
|
||||||
|
adamw = optax.adamw(
|
||||||
|
learning_rate=linear_decay_lr_schedule_fn,
|
||||||
|
b1=adam_beta1,
|
||||||
|
b2=adam_beta2,
|
||||||
|
eps=adam_epsilon,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
mask=decay_mask_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
print("compile")
|
||||||
|
|
||||||
|
|
||||||
|
# enable bf16 except for layer_norm
|
||||||
|
def create_mask_for_layer_norm(params):
|
||||||
|
flat_params = traverse_util.flatten_dict(params)
|
||||||
|
mask = {
|
||||||
|
path: not (path[-2] == "layer_norm" and path[-1] == "weight") for path in flat_params
|
||||||
|
}
|
||||||
|
mask = traverse_util.unflatten_dict(mask)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
# borrowed from transformers modeling_flax_utils
|
||||||
|
def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
|
||||||
|
"""
|
||||||
|
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
|
||||||
|
def conditional_cast(param):
|
||||||
|
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
|
||||||
|
param = param.astype(dtype)
|
||||||
|
return param
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
|
return jax.tree_util.tree_map(conditional_cast, params)
|
||||||
|
|
||||||
|
flat_params = traverse_util.flatten_dict(params)
|
||||||
|
flat_mask, _ = jax.tree_util.tree_flatten(mask)
|
||||||
|
|
||||||
|
for masked, key in zip(flat_mask, sorted(flat_params.keys())):
|
||||||
|
if masked:
|
||||||
|
flat_params[key] = conditional_cast(flat_params[key])
|
||||||
|
|
||||||
|
return traverse_util.unflatten_dict(flat_params)
|
||||||
|
|
||||||
|
# Cast all parameters to bfloat16 if desired
|
||||||
|
# params = jax.tree.tree_map(lambda x: x.astype(jnp.bfloat16), params)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
def init_fn(params, model, optimizer):
|
||||||
|
# be careful with the model init
|
||||||
|
# imported models might have complicated init methods
|
||||||
|
# mask = create_mask_for_layer_norm(params)
|
||||||
|
# override params with bfloat version
|
||||||
|
# params= cast_floating_to(params, jnp.bfloat16, mask)
|
||||||
|
|
||||||
|
state = train_state.TrainState.create( # Create a `TrainState`.
|
||||||
|
apply_fn=model.apply,
|
||||||
|
params=params,
|
||||||
|
tx=optimizer)
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
# def init_fn(rng, batch, model, optimizer):
|
||||||
|
# # do be careful with the model init
|
||||||
|
# # imported models might have complicated init methods
|
||||||
|
# variables = model.init(
|
||||||
|
# rng,
|
||||||
|
# input_ids=batch['input_ids'],
|
||||||
|
# attention_mask=batch['attention_mask'],
|
||||||
|
# decoder_input_ids=batch['decoder_input_ids'],
|
||||||
|
# decoder_attention_mask=batch['decoder_attention_mask']
|
||||||
|
# )
|
||||||
|
# params = variables['params']
|
||||||
|
# mask = create_mask_for_layer_norm(params)
|
||||||
|
# # override params with bfloat version
|
||||||
|
# params= cast_floating_to(params, jnp.bfloat16, mask)
|
||||||
|
#
|
||||||
|
# state = train_state.TrainState.create( # Create a `TrainState`.
|
||||||
|
# apply_fn=model.apply,
|
||||||
|
# params=params,
|
||||||
|
# tx=optimizer)
|
||||||
|
# return state
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Create an abstract closure to wrap the function before feeding it in
|
||||||
|
# because `jax.eval_shape` only takes pytrees as arguments.
|
||||||
|
# eval_shape(fn, rng_key, x)
|
||||||
|
# used to perform shape inference
|
||||||
|
# returns a nested PyTree containing jax.ShapeDtypeStruct objects as leaves
|
||||||
|
# rng, init_rng = jax.random.split(rng)
|
||||||
|
abstract_variables = jax.eval_shape(
|
||||||
|
functools.partial(init_fn, model=model, optimizer=adamw), params)
|
||||||
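# Illustrative: eval_shape never materializes parameters or optimizer state; every leaf of
# abstract_variables is a jax.ShapeDtypeStruct, so this is cheap even for the full T5 tree.
print(type(jax.tree_util.tree_leaves(abstract_variables)[0]))  # jax.ShapeDtypeStruct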
|
|
||||||
|
# rng, init_rng = jax.random.split(rng)
|
||||||
|
# abstract_variables = jax.eval_shape(
|
||||||
|
# functools.partial(init_fn, model=model, optimizer=adamw), init_rng, batch)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# This `state_sharding` has the same pytree structure as `state`, the output
|
||||||
|
# of the `init_fn`.
|
||||||
|
# flax.linen.get_sharding
|
||||||
|
# extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh
|
||||||
|
# jax.sharding: describes how a jax.Array is laid out across devices
|
||||||
|
state_sharding = nn.get_sharding(abstract_variables, mesh)
|
||||||
|
# print(state_sharding)
|
||||||
|
|
||||||
|
# warning: do not have singleton None in your nn.partition definitions, it will screw with your sanity
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# # %%
|
||||||
|
# # replace the params tree with the new modified tree
|
||||||
|
# # create partitions for model
|
||||||
|
# from parallel.partitions import set_partitions
|
||||||
|
# # set_partitions freezes the params on return
|
||||||
|
# model_part_spec = set_partitions(unfreeze(params))
|
||||||
|
# # p is already a partition spec
|
||||||
|
# model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec)
|
||||||
|
#
|
||||||
|
# # %%
|
||||||
|
# # get_shapes = jax.tree.map(jnp.shape, params)
|
||||||
|
# # actually tuple
|
||||||
|
# # state_shapes = jax.eval_shape(state_sharding, get_shapes)
|
||||||
|
#
|
||||||
|
# # %%
|
||||||
|
# # get pspec for opt_state
|
||||||
|
# def get_opt_spec(x):
|
||||||
|
# if isinstance(x, dict):
|
||||||
|
# return unfreeze(model_named_sharding)
|
||||||
|
# # return an empty partspec
|
||||||
|
# return mesh_sharding((PartitionSpec()))
|
||||||
|
#
|
||||||
|
# # this function replaces the empty model params spec with the 'model_named_shard'
|
||||||
|
# state_sharding = jax.tree.map(
|
||||||
|
# get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
jit_init_fn = jax.jit(
|
||||||
|
init_fn,
|
||||||
|
static_argnames=('model', 'optimizer'), # skip model and optimizer
|
||||||
|
in_shardings=mesh_sharding(PartitionSpec(())), # we don't shard params explicitly
|
||||||
|
out_shardings=state_sharding # but returned initialized_state is sharded
|
||||||
|
)
|
||||||
|
initialized_state = jit_init_fn(params, model, adamw)
|
||||||
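# Illustrative sanity check: the returned TrainState should carry the sharding requested via
# out_shardings, so each parameter leaf reports a NamedSharding on the mesh created above.
print(jax.tree_util.tree_leaves(initialized_state.params)[0].sharding)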
|
|
||||||
|
|
||||||
|
# jit_init_fn = jax.jit(
|
||||||
|
# init_fn,
|
||||||
|
# static_argnames=('model', 'optimizer'), # skip model and optimizer
|
||||||
|
# in_shardings=(mesh_sharding(()), x_sharding), # for PRNG key and data
|
||||||
|
# out_shardings=state_sharding
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# rng, init_rng = jax.random.split(rng)
|
||||||
|
# initialized_state = jit_init_fn(rng, batch, model, adamw)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# train step
|
||||||
|
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
|
||||||
|
"""
|
||||||
|
The label smoothing implementation is adapted from Flax's official example:
|
||||||
|
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
|
||||||
|
"""
|
||||||
|
vocab_size = logits.shape[-1]
|
||||||
|
confidence = 1.0 - label_smoothing_factor
|
||||||
|
low_confidence = (1.0 - confidence) / (vocab_size - 1)
|
||||||
|
normalizing_constant = -(
|
||||||
|
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
|
||||||
|
)
|
||||||
|
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
|
||||||
|
|
||||||
|
loss = optax.softmax_cross_entropy(logits, soft_labels)
|
||||||
|
loss = loss - normalizing_constant
|
||||||
|
|
||||||
|
# ignore padded tokens from loss
|
||||||
|
loss = loss * padding_mask
|
||||||
|
loss = loss.sum()
|
||||||
|
num_labels = padding_mask.sum()
|
||||||
|
return loss, num_labels
|
||||||
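# Illustrative shape check on dummy data (hypothetical sizes): with the default
# label_smoothing_factor of 0.0 this is plain token-level cross entropy, summed over
# non-padding positions, returned together with the count of those positions.
_dummy_logits = jnp.zeros((2, 3, 5))                # (batch, seq_len, vocab)
_dummy_labels = jnp.zeros((2, 3), dtype=jnp.int32)  # token ids
_dummy_mask = jnp.ones((2, 3))                      # no padding
_l, _n = loss_fn(_dummy_logits, _dummy_labels, _dummy_mask)
print("dummy loss:", _l, "num labels:", _n)         # loss ~ 6 * ln(5), num labels = 6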
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# sharded_loss_fn = jax.jit(
|
||||||
|
# loss_fn,
|
||||||
|
# in_shardings=(mesh_sharding('model'), x_sharding), # params partitioned across 'model' axis
|
||||||
|
# out_shardings=(mesh_sharding('model')), # Loss should be aggregated across 'model'
|
||||||
|
# )
|
||||||
|
|
||||||
|
def gather_and_sum(
|
||||||
|
sharded_values,
|
||||||
|
in_shardings
|
||||||
|
):
|
||||||
|
with mesh:
|
||||||
|
# Gather sharded values into a single device
|
||||||
|
gathered_values = jax.jit(
|
||||||
|
lambda x: x, in_shardings=in_shardings, out_shardings=None
|
||||||
|
)(sharded_values)
|
||||||
|
|
||||||
|
# Compute the sum of gathered values
|
||||||
|
summed_value = jax.tree.map(lambda x: jnp.sum(x), gathered_values)
|
||||||
|
return summed_value
|
||||||
|
|
||||||
|
|
||||||
|
# single device code annotated with jax.jit
|
||||||
|
@functools.partial(
|
||||||
|
jax.jit,
|
||||||
|
# state is state_sharding initialized from init_fn
|
||||||
|
# x_sharding is data sharded explicitly later
|
||||||
|
in_shardings=(state_sharding, x_sharding),
|
||||||
|
# return state as state_sharding
|
||||||
|
# we do not shard the metrics
|
||||||
|
out_shardings=(state_sharding, mesh_sharding(PartitionSpec())),
|
||||||
|
donate_argnames=('state'),
|
||||||
|
)
|
||||||
|
def train_step(state, batch):
|
||||||
|
|
||||||
|
label_smoothing_factor=0.0
|
||||||
|
# dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
||||||
|
def compute_loss(params, batch):
|
||||||
|
# check constraints
|
||||||
|
# frozen dict not allowed as sharding object
|
||||||
|
# params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding))
|
||||||
|
# batch = jax.lax.with_sharding_constraint(batch, x_sharding)
|
||||||
|
# labels = batch.pop("decoder_input_ids")
|
||||||
|
# no use of labels here
|
||||||
|
logits = state.apply_fn(
|
||||||
|
{'params': params},
|
||||||
|
input_ids=batch['input_ids'],
|
||||||
|
attention_mask=batch['attention_mask'],
|
||||||
|
decoder_input_ids=batch['decoder_input_ids'],
|
||||||
|
decoder_attention_mask=batch['decoder_attention_mask'],
|
||||||
|
)[0] # index 0 because the output is a structure whose first element is the logits
|
||||||
|
# use labels here
|
||||||
|
loss, num_labels = loss_fn(
|
||||||
|
logits,
|
||||||
|
batch["labels"],
|
||||||
|
batch["decoder_attention_mask"],
|
||||||
|
label_smoothing_factor)
|
||||||
|
return loss, num_labels
|
||||||
|
|
||||||
|
# compute gradients through computational graph
|
||||||
|
# allow values to pass through
|
||||||
|
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
|
||||||
|
(loss, num_labels), grad = grad_fn(state.params, batch)
|
||||||
|
# num_labels = jax.lax.psum(num_labels, "batch")
|
||||||
|
|
||||||
|
|
||||||
|
# true grad = total grad / total samples
|
||||||
|
# needs to be in a singleton tuple for some reason
|
||||||
|
# gathered_grad = gather_and_sum(grad, (unfreeze(model_named_sharding),))
|
||||||
|
|
||||||
|
# gathered_num_labels = gather_and_sum(num_labels, mesh_sharding(PartitionSpec()))
|
||||||
|
|
||||||
|
# summed_gradients = jax.tree.map(lambda x: jnp.sum(x)/gathered_num_labels, gathered_grad)
|
||||||
|
# grad = jax.lax.psum(grad, "batch")
|
||||||
|
# grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
||||||
|
|
||||||
|
new_state = state.apply_gradients(grads=grad)
|
||||||
|
with jax.named_scope("sync_metrics"):
|
||||||
|
step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
||||||
|
# step_metrics = jax.tree.map(
|
||||||
|
# # previously needed lax.psum
|
||||||
|
# # now just write single device code, let compiler handle
|
||||||
|
# lambda x: jnp.mean(x), step_metrics
|
||||||
|
# )
|
||||||
|
|
||||||
|
# if metrics is None:
|
||||||
|
# metrics = step_metrics
|
||||||
|
# else:
|
||||||
|
# # combine all the synced metrics
|
||||||
|
# metrics = jax.tree.map(jnp.mean, metrics, step_metrics)
|
||||||
|
|
||||||
|
|
||||||
|
return new_state, step_metrics
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# prep 1 step
|
||||||
|
print("1 step for jit-ting")
|
||||||
|
|
||||||
|
|
||||||
|
with mesh:
|
||||||
|
state, metrics = train_step(initialized_state, batch)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# tr
|
||||||
|
print("***** Running training *****")
|
||||||
|
print(f" Num examples = {training_size}")
|
||||||
|
print(f" Num Epochs = {num_epochs}")
|
||||||
|
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
|
||||||
|
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
|
||||||
|
print(f" Total optimization steps = {total_train_steps}")
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# jax.profiler.start_trace("./traces")
|
||||||
|
|
||||||
|
# function to shard a batch by treating it as a pytree
|
||||||
|
def shard_batch(batch):
|
||||||
|
# Shard each element in the dictionary (i.e., each key-value pair)
|
||||||
|
return jax.tree_util.tree_map(
|
||||||
|
lambda x: jax.device_put(x, x_sharding),
|
||||||
|
batch
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
print("*" * 10)
|
||||||
|
print("training start")
|
||||||
|
rng, input_rng = jax.random.split(rng)
|
||||||
|
train_time = 0
|
||||||
|
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
||||||
|
for epoch in epochs:
|
||||||
|
train_start = time.time()
|
||||||
|
|
||||||
|
# Create sampling rng
|
||||||
|
train_metrics = []
|
||||||
|
steps_per_epoch = training_size // train_batch_size
|
||||||
|
train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True)
|
||||||
|
# Generate an epoch by shuffling sampling indices from the train dataset
|
||||||
|
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
||||||
|
batch = next(train_loader)
|
||||||
|
# send to device
|
||||||
|
# batch = {key: jax.device_put(jnp.array(value, dtype=jnp.uint16), x_sharding) for key, value in batch.items()}
|
||||||
|
# batch['input_ids']=jax.device_put(jnp.array(batch['input_ids'], dtype=jnp.int32), x_sharding)
|
||||||
|
# batch['attention_mask']=jax.device_put(jnp.array(batch['attention_mask'], dtype=jnp.int32), x_sharding)
|
||||||
|
# batch['decoder_input_ids']=jax.device_put(jnp.array(batch['decoder_input_ids'], dtype=jnp.int32), x_sharding)
|
||||||
|
# batch['decoder_attention_mask']=jax.device_put(jnp.array(batch['decoder_attention_mask'], dtype=jnp.int32), x_sharding)
|
||||||
|
sharded_batch = shard_batch(batch)
|
||||||
|
with mesh:
|
||||||
|
state, train_metric = train_step(state, sharded_batch)
|
||||||
|
|
||||||
|
# train_metrics.append(train_metric)
|
||||||
|
|
||||||
|
|
||||||
|
# this is for more accurate time stats, but slows down training
|
||||||
|
# train_metric['loss'].block_until_ready()
|
||||||
|
train_time = time.time() - train_start
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
epochs.write(
|
||||||
|
f"Epoch... ({epoch + 1}/{num_epochs} | "
|
||||||
|
f"Loss: {train_metric['loss']}, "
|
||||||
|
f"Learning Rate:{train_metric['learning_rate']}, "
|
||||||
|
f"Last train time: {train_time})"
|
||||||
|
)
|
||||||
|
# jax.profiler.stop_trace()
|
||||||
|
# %%
|
||||||
|
# with mesh:
|
||||||
|
# gathered_params = jax.jit(
|
||||||
|
# lambda x: x,
|
||||||
|
# in_shardings=(unfreeze(model_named_sharding),),
|
||||||
|
# out_shardings=mesh_sharding(PartitionSpec())
|
||||||
|
# )(state.params)
|
||||||
|
|
||||||
|
main_model = custom_model.from_pretrained('t5-base')
|
||||||
|
output_dir = save_path
|
||||||
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
|
if jax.process_index() == 0:
|
||||||
|
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
|
||||||
|
main_model.save_pretrained(output_dir, params=params)
|
||||||
|
|
||||||
|
# # stick to defaults
|
||||||
|
# options = ocp.CheckpointManagerOptions()
|
||||||
|
# with ocp.CheckpointManager(
|
||||||
|
# ocp.test_utils.erase_and_create_empty(save_path),
|
||||||
|
# options=options,
|
||||||
|
# ) as mngr:
|
||||||
|
#
|
||||||
|
# mngr.save(0, args=ocp.args.StandardSave(state))
|
||||||
|
# mngr.wait_until_finished()
|
||||||
|
|
||||||
|
# After providing `args` during an initial `save` or `restore` call, the
|
||||||
|
# `CheckpointManager` instance records the type so that you do not need to
|
||||||
|
# specify it again. If the `CheckpointManager` instance is not provided with a
|
||||||
|
# `ocp.args.CheckpointArgs` instance for a particular item on a previous
|
||||||
|
# occasion it cannot be restored without specifying the argument at restore
|
||||||
|
# time.
|
||||||
|
|
||||||
|
# # In many cases, you can restore exactly as saved without specifying additional
|
||||||
|
# # arguments.
|
||||||
|
# mngr.restore(0)
|
||||||
|
# # If customization of properties like sharding or dtype is desired, just provide
|
||||||
|
# # the abstract target PyTree, the properties of which will be used to set
|
||||||
|
# # the properties of the restored arrays.
|
||||||
|
# mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))
|
||||||
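# A minimal sketch of the orbax round trip described above (kept inert here, since running
# it would write checkpoints under save_path):
# with ocp.CheckpointManager(save_path, options=ocp.CheckpointManagerOptions()) as mngr:
#     mngr.save(0, args=ocp.args.StandardSave(state))
#     mngr.wait_until_finished()
#     restored_state = mngr.restore(0)  # the arg type is remembered from the earlier save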
|
|
||||||
|
# %%
|
|
@@ -13,9 +13,12 @@
|
||||||
# # prediction code
|
# # prediction code
|
||||||
# ## import and process test data
|
# ## import and process test data
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# import libraries
|
# import libraries
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
@@ -35,7 +38,7 @@ jax.config.update("jax_default_matmul_precision", "bfloat16")
|
||||||
jax.config.update("jax_enable_x64", False)
|
jax.config.update("jax_enable_x64", False)
|
||||||
|
|
||||||
|
|
||||||
from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
|
# from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig
|
||||||
|
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
|
@@ -51,9 +54,13 @@ from flax.jax_utils import pad_shard_unpad, unreplicate
|
||||||
from flax.training import train_state
|
from flax.training import train_state
|
||||||
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||||
|
|
||||||
|
from ml_collections import ConfigDict
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from parallel.dataload import DataPrepare
|
||||||
|
import orbax.checkpoint as ocp
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
|
||||||
|
@@ -86,121 +93,22 @@ def process_df(df):
|
||||||
# takes 1 minute to run without batching
|
# takes 1 minute to run without batching
|
||||||
test_dataset = Dataset.from_list(process_df(df))
|
test_dataset = Dataset.from_list(process_df(df))
|
||||||
|
|
||||||
|
|
||||||
# %% [markdown]
|
|
||||||
# ## Load model for attributes
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# load model
|
# from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration
|
||||||
model_name_or_path = "./t5_80_1_bf16" # Replace with your specific model name
|
from transformers import FlaxT5ForConditionalGeneration
|
||||||
|
# model_name_or_path = "./t5_80_1" # Replace with your specific model name
|
||||||
# Load configuration
|
model_name_or_path = "./model_checkpoints/simple_test" # Replace with your specific model name
|
||||||
config = AutoConfig.from_pretrained(model_name_or_path)
|
model = FlaxT5ForConditionalGeneration.from_pretrained(model_name_or_path)
|
||||||
|
params = model.params
|
||||||
# Load model
|
|
||||||
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
|
|
||||||
pretrained_model_name_or_path=model_name_or_path
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# %% [markdown]
|
|
||||||
# ## Tokenizer
|
|
||||||
|
|
||||||
# %%
|
|
||||||
# prepare tokenizer
|
|
||||||
from transformers import T5TokenizerFast
|
|
||||||
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
|
|
||||||
# Define additional special tokens
|
|
||||||
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
|
|
||||||
# Add the additional special tokens to the tokenizer
|
|
||||||
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
|
|
||||||
|
|
||||||
max_length = 86
|
|
||||||
|
|
||||||
model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
|
|
||||||
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
|
|
||||||
|
|
||||||
# given a dataset entry, run it through the tokenizer
|
|
||||||
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
|
||||||
def preprocess_function(example):
|
|
||||||
inputs = example['input']
|
|
||||||
targets = example['output']
|
|
||||||
# text_target sets the corresponding label to inputs
|
|
||||||
# there is no need to create a separate 'labels'
|
|
||||||
model_inputs = tokenizer(
|
|
||||||
inputs,
|
|
||||||
max_length=max_length,
|
|
||||||
padding="max_length",
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="np"
|
|
||||||
)
|
|
||||||
labels = tokenizer(
|
|
||||||
text_target=targets,
|
|
||||||
max_length=max_length,
|
|
||||||
padding="max_length",
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="np"
|
|
||||||
)
|
|
||||||
|
|
||||||
model_inputs["labels"] = labels["input_ids"]
|
|
||||||
decoder_input_ids = shift_tokens_right_fn(
|
|
||||||
labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
|
|
||||||
)
|
|
||||||
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
|
||||||
|
|
||||||
# We need decoder_attention_mask so we can ignore pad tokens from loss
|
|
||||||
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
|
|
||||||
|
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
# map maps function to each "row" in the dataset
|
|
||||||
# aka the data in the immediate nesting
|
|
||||||
test_dataset = test_dataset.map(
|
|
||||||
preprocess_function,
|
|
||||||
batched=True,
|
|
||||||
num_proc=1,
|
|
||||||
remove_columns=test_dataset.column_names,
|
|
||||||
)
|
|
||||||
|
|
||||||
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
|
|
||||||
"""
|
|
||||||
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
|
|
||||||
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
|
|
||||||
"""
|
|
||||||
if shuffle:
|
|
||||||
batch_idx = jax.random.permutation(rng, len(dataset))
|
|
||||||
batch_idx = np.asarray(batch_idx)
|
|
||||||
else:
|
|
||||||
batch_idx = np.arange(len(dataset))
|
|
||||||
|
|
||||||
if drop_last:
|
|
||||||
steps_per_epoch = len(dataset) // batch_size
|
|
||||||
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
|
||||||
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
|
||||||
else:
|
|
||||||
steps_per_epoch = math.ceil(len(dataset) / batch_size)
|
|
||||||
batch_idx = np.array_split(batch_idx, steps_per_epoch)
|
|
||||||
|
|
||||||
for idx in batch_idx:
|
|
||||||
batch = dataset[idx]
|
|
||||||
batch = {k: np.array(v) for k, v in batch.items()}
|
|
||||||
|
|
||||||
yield batch
|
|
||||||
|
|
||||||
# %% [markdown]
|
|
||||||
# # model generation
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
seed = 117
|
seed = 117
|
||||||
num_epochs = 80
|
batch_size = 128
|
||||||
batch_size = 96
|
|
||||||
num_train_epochs = num_epochs
|
|
||||||
per_device_train_batch_size = batch_size
|
per_device_train_batch_size = batch_size
|
||||||
train_batch_size = per_device_train_batch_size * jax.device_count()
|
train_batch_size = per_device_train_batch_size * jax.device_count()
|
||||||
per_device_eval_batch_size = batch_size
|
per_device_eval_batch_size = batch_size
|
||||||
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
||||||
steps_per_epoch = len(test_dataset) // train_batch_size
|
steps_per_epoch = len(test_dataset) // train_batch_size
|
||||||
total_train_steps = steps_per_epoch * num_epochs
|
|
||||||
|
|
||||||
num_beams = 1
|
num_beams = 1
|
||||||
val_max_target_length = 128
|
val_max_target_length = 128
|
||||||
|
@@ -208,33 +116,31 @@ val_max_target_length = 128
|
||||||
predict_with_generate = True
|
predict_with_generate = True
|
||||||
|
|
||||||
|
|
||||||
# Initialize our training
|
# Initialize our prediction
|
||||||
rng = jax.random.PRNGKey(seed)
|
rng = jax.random.PRNGKey(seed)
|
||||||
rng, dropout_rng = jax.random.split(rng)
|
rng, dropout_rng = jax.random.split(rng)
|
||||||
|
|
||||||
|
print("preparing data")
|
||||||
# %%
|
data_config = ConfigDict(
|
||||||
|
dict(
|
||||||
# reload model to prevent leakage of variables
|
max_length=128,
|
||||||
# load model
|
pad_token_id=0,
|
||||||
model_name_or_path = "t5_80_1" # Replace with your specific model name
|
decoder_start_token_id=0
|
||||||
|
)
|
||||||
# Load configuration
|
|
||||||
config = AutoConfig.from_pretrained(model_name_or_path)
|
|
||||||
|
|
||||||
# Load model
|
|
||||||
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
|
|
||||||
model_name_or_path
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
dataprep = DataPrepare(test_dataset, data_config)
|
||||||
|
# # example usage
|
||||||
|
# # %%
|
||||||
|
seed = 117
|
||||||
|
rng = jax.random.PRNGKey(seed)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
# Ensure model.params is properly initialized (this is just an example)
|
# Ensure model.params is properly initialized (this is just an example)
|
||||||
# Normally you would get this from a model initialization call with dummy input
|
# Normally you would get this from a model initialization call with dummy input
|
||||||
params = model.params
|
|
||||||
# ensure full size floats
|
|
||||||
params_f16 = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
|
|
||||||
# we need to replicate model over devices
|
# we need to replicate model over devices
|
||||||
replicated_params = jax.device_put_replicated(params_f16, jax.devices())
|
replicated_params = jax.device_put_replicated(params, jax.devices())
|
||||||
|
|
||||||
|
|
||||||
# Define generation function
|
# Define generation function
|
||||||
|
@@ -245,25 +151,29 @@ num_beams = num_beams if num_beams is not None else model.config.num_beams
|
||||||
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
||||||
|
|
||||||
def generate_step(params, batch):
|
def generate_step(params, batch):
|
||||||
output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], params=params, **gen_kwargs)
|
output_ids = model.generate(
|
||||||
|
batch["input_ids"],
|
||||||
|
attention_mask=batch["attention_mask"],
|
||||||
|
params=params,
|
||||||
|
**gen_kwargs)
|
||||||
return output_ids.sequences
|
return output_ids.sequences
|
||||||
|
|
||||||
# Create parallel version of the train and eval step
|
# Create parallel version of the train and eval step
|
||||||
p_generate_step = jax.pmap(generate_step, "batch")
|
p_generate_step = jax.pmap(generate_step, "batch")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
pred_generations = []
|
pred_generations = []
|
||||||
pred_labels = []
|
pred_labels = []
|
||||||
|
decoder_input_list = []
|
||||||
|
|
||||||
rng, input_rng = jax.random.split(rng)
|
rng, input_rng = jax.random.split(rng)
|
||||||
|
|
||||||
pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False)
|
|
||||||
|
pred_loader = dataprep.data_loader(input_rng, batch_size=batch_size, shuffle=False, drop_last=False)
|
||||||
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)
|
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)
|
||||||
|
|
||||||
print("***** Running training *****")
|
print("***** Running training *****")
|
||||||
print(f" Num examples = {len(test_dataset)}")
|
print(f" Num examples = {len(test_dataset)}")
|
||||||
print(f" Num steps = {num_epochs}")
|
# print(f" Num steps = {num_epochs}")
|
||||||
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
|
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
|
||||||
print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")
|
print(f" Total test batch size (w. parallel & distributed) = {train_batch_size}")
|
||||||
|
|
||||||
|
@@ -272,26 +182,38 @@ for _ in tqdm(range(pred_steps), desc="Predicting..."):
|
||||||
# Model forward
|
# Model forward
|
||||||
batch = next(pred_loader)
|
batch = next(pred_loader)
|
||||||
labels = batch["labels"]
|
labels = batch["labels"]
|
||||||
|
decoder_input = batch["decoder_input_ids"]
|
||||||
|
|
||||||
|
|
||||||
# generation
|
# generation
|
||||||
generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch)
|
generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch)
|
||||||
pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
||||||
pred_labels.extend(labels)
|
pred_labels.extend(labels)
|
||||||
|
decoder_input_list.extend(decoder_input)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
# %% [markdown]
|
# %% [markdown]
|
||||||
# # process predictions
|
# # process predictions
|
||||||
|
from transformers import T5TokenizerFast
|
||||||
|
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=False)
|
||||||
|
# Define additional special tokens
|
||||||
|
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "<SIG>", "<UNIT>", "<DATA_TYPE>"]
|
||||||
|
# Add the additional special tokens to the tokenizer
|
||||||
|
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# code to get special token ids
|
# code to get special token ids
|
||||||
# sentence = "<THING_START><THING_END><PROPERTY_START><PROPERTY_END><NAME><DESC><DESC><UNIT>"
|
sentence = "<THING_START><THING_END><PROPERTY_START><PROPERTY_END><NAME><DESC><DESC><UNIT><SIG><UNIT><DATA_TYPE>"
|
||||||
# tokens = tokenizer.tokenize(sentence)
|
tokens = tokenizer.tokenize(sentence)
|
||||||
# print("Tokens:", tokens)
|
print("Tokens:", tokens)
|
||||||
# # Get the IDs (integer indices) of specific tokens
|
# Get the IDs (integer indices) of specific tokens
|
||||||
# token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
|
token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
|
||||||
# print("Token IDs:", token_ids)
|
print("Token IDs:", token_ids)
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
@@ -304,7 +226,13 @@ def extract_seq(tokens, start_value, end_value):
|
||||||
|
|
||||||
return tokens[start_id+1:end_id]
|
return tokens[start_id+1:end_id]
|
||||||
|
|
||||||
|
# %%
|
||||||
|
i = 2
|
||||||
|
print(pred_generations[i])
|
||||||
|
print(extract_seq(pred_generations[i], 32100, 32101))
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
def process_tensor_output(tokens):
|
def process_tensor_output(tokens):
|
||||||
thing_seq = extract_seq(tokens, 32100, 32101) # 32100 = <THING_START>, 32101 = <THING_END>
|
thing_seq = extract_seq(tokens, 32100, 32101) # 32100 = <THING_START>, 32101 = <THING_END>
|
||||||
property_seq = extract_seq(tokens, 32102, 32103) # 32102 = <PROPERTY_START>, 32103 = <PROPERTY_END>
|
property_seq = extract_seq(tokens, 32102, 32103) # 32102 = <PROPERTY_START>, 32103 = <PROPERTY_END>
|
||||||
|
|
|
@@ -0,0 +1,543 @@
|
||||||
|
# %%
|
||||||
|
import os
|
||||||
|
# Set this to True to run the model on CPU only.
|
||||||
|
USE_CPU_ONLY = False
|
||||||
|
|
||||||
|
flags = os.environ.get("XLA_FLAGS", "")
|
||||||
|
if USE_CPU_ONLY:
|
||||||
|
flags += " --xla_force_host_platform_device_count=4" # Simulate 8 devices
|
||||||
|
# Enforce CPU-only execution
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||||
|
os.environ["JAX_PLATFORMS"] = "cpu"
|
||||||
|
else:
|
||||||
|
# GPU flags
|
||||||
|
flags = (
|
||||||
|
'--xla_gpu_enable_triton_softmax_fusion=true '
|
||||||
|
'--xla_gpu_triton_gemm_any=True '
|
||||||
|
# '--xla_gpu_enable_async_collectives=true '
|
||||||
|
'--xla_gpu_enable_latency_hiding_scheduler=true '
|
||||||
|
'--xla_gpu_enable_highest_priority_async_stream=true '
|
||||||
|
)
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
|
||||||
|
|
||||||
|
os.environ["XLA_FLAGS"] = flags
|
||||||
|
os.environ.update({
|
||||||
|
"TOKENIZERS_PARALLELISM" : "false",
|
||||||
|
"CUDA_DEVICE_MAX_CONNECTIONS" : "1",
|
||||||
|
"NCCL_LL128_BUFFSIZE": "-2",
|
||||||
|
"NCCL_LL_BUFFSIZE": "-2",
|
||||||
|
"NCCL_PROTO": "SIMPLE,LL,LL128",
|
||||||
|
"XLA_PYTHON_CLIENT_MEM_FRACTION" : "0.90",
|
||||||
|
# "XLA_PYTHON_CLIENT_PREALLOCATE" : "false"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from functools import partial
|
||||||
|
from pprint import pprint
|
||||||
|
from typing import Any, Dict, Tuple, Callable, Sequence, Union
|
||||||
|
|
||||||
|
import flax.linen as nn
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
|
from jax.sharding import Mesh, NamedSharding
|
||||||
|
# from jax.experimental.pjit import pjit # superseded by jax.jit
|
||||||
|
from jax.experimental import mesh_utils
|
||||||
|
from jax.sharding import PartitionSpec
|
||||||
|
from ml_collections import ConfigDict
|
||||||
|
import optax
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datasets import Dataset, load_from_disk
|
||||||
|
|
||||||
|
from flax import jax_utils, traverse_util
|
||||||
|
from flax.training import train_state
|
||||||
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||||
|
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
|
||||||
|
import flax.core
|
||||||
|
|
||||||
|
# model checkpointing and saving utilities
|
||||||
|
from flax import linen as nn
|
||||||
|
from flax.training import checkpoints, train_state
|
||||||
|
from flax import struct, serialization
|
||||||
|
|
||||||
|
from parallel.partitions import set_partitions
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from parallel.dataload import DataPrepare
|
||||||
|
|
||||||
|
# for memory tracking
|
||||||
|
# from jax_smi import initialise_tracking
|
||||||
|
# initialise_tracking()
|
||||||
|
|
||||||
|
|
||||||
|
PyTree = Any
|
||||||
|
Metrics = Dict[str, Tuple[jax.Array, ...]]
|
||||||
|
|
||||||
|
if USE_CPU_ONLY:
|
||||||
|
jax.config.update('jax_platform_name', 'cpu')
|
||||||
|
else:
|
||||||
|
jax.config.update("jax_default_matmul_precision", "bfloat16")
|
||||||
|
|
||||||
|
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
|
||||||
|
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
|
||||||
|
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
## get platform type
|
||||||
|
from jax.extend.backend import get_backend
|
||||||
|
print(get_backend().platform)
|
||||||
|
print(jax.devices())
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# config options
|
||||||
|
file_path = '/home/richard/Projects/learn_t5/simple_model/combined_data_t5_retrieval/'
|
||||||
|
save_path = '/home/richard/Projects/06_research/jax_models/model_checkpoints/simple_test/'
|
||||||
|
# file_path = 'combined_data'
|
||||||
|
split_datasets = load_from_disk(file_path)
|
||||||
|
training_size = len(split_datasets['train'])
|
||||||
|
# Store some constant
|
||||||
|
seed = 117
|
||||||
|
num_epochs = 5
|
||||||
|
batch_size = 64
|
||||||
|
num_train_epochs = num_epochs
|
||||||
|
per_device_train_batch_size = batch_size
|
||||||
|
train_batch_size = per_device_train_batch_size * jax.device_count()
|
||||||
|
per_device_eval_batch_size = batch_size
|
||||||
|
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
||||||
|
steps_per_epoch = training_size // train_batch_size
|
||||||
|
total_train_steps = steps_per_epoch * num_epochs
|
||||||
|
|
||||||
|
warmup_steps = 0
|
||||||
|
learning_rate = 5e-5
|
||||||
|
|
||||||
|
weight_decay = 0.01
|
||||||
|
adam_beta1 = 0.9
|
||||||
|
adam_beta2 = 0.999
|
||||||
|
adam_epsilon = 1e-8
|
||||||
|
label_smoothing_factor = 0.0
|
||||||
|
num_beams = 1
|
||||||
|
val_max_target_length = 128
|
||||||
|
predict_with_generate = True
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# prepare data
|
||||||
|
# init object
|
||||||
|
# e.g. Config
|
||||||
|
print("preparing data")
|
||||||
|
data_config = ConfigDict(
|
||||||
|
dict(
|
||||||
|
max_length=128,
|
||||||
|
pad_token_id=0,
|
||||||
|
decoder_start_token_id=0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
dataprep = DataPrepare(split_datasets['train'], data_config)
|
||||||
|
# # example usage
|
||||||
|
# # %%
|
||||||
|
seed = 117
|
||||||
|
rng = jax.random.PRNGKey(seed)
|
||||||
|
train_loader = dataprep.data_loader(rng, batch_size=batch_size)
|
||||||
|
batch = next(iter(train_loader))
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# from t5_model.configuration_t5 import FrozenT5Config as T5ConfigCustom
|
||||||
|
from t5_model.modeling_t5_flax import FlaxT5ForConditionalGeneration as custom_model
|
||||||
|
main_model = custom_model.from_pretrained(
|
||||||
|
"t5-base",
|
||||||
|
dtype=jnp.bfloat16,
|
||||||
|
gradient_checkpointing=True,
|
||||||
|
)
|
||||||
|
params = main_model.params
|
||||||
|
# pretrained_params = model.params
|
||||||
|
model = main_model.module
|
||||||
|
|
||||||
|
# %%
# # testing config hashability
# # some explanation:
# # The PreTrainedModel class loads a T5Config model that is not hashable because
# # it is a complicated class that pretends to be a dataclass.
# # The solution is to extract a dict from it, then make a ConfigDict from
# # ml_collections library so that we can get values via the "." operator.
# # also, we can switch between FrozenConfigDict and ConfigDict, allowing us to
# # modify the config before passing to the next layer
# from transformers import T5Config
# from t5_model.configuration_t5 import FrozenT5Config
# from ml_collections import ConfigDict, FrozenConfigDict
#
# config = T5Config.from_pretrained("t5-base").to_dict()
# config.pop('architectures')
# config.pop('id2label')
# # test if it works
# frozen_config = FrozenConfigDict(config)
# # test hash
# hash(frozen_config)

# %%
# # print model
# rng, input_rng = jax.random.split(rng)
# model.tabulate(
#     input_rng,
#     input_ids=batch['input_ids'],
#     attention_mask=batch['attention_mask'],
#     decoder_input_ids=batch['decoder_input_ids'],
#     decoder_attention_mask=batch['decoder_attention_mask'],
#     console_kwargs={"force_jupyter": True}
# )


# %%
# create mesh
print("creating mesh")
device_mesh = mesh_utils.create_device_mesh((2, 2))
print(device_mesh)

mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)


def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
    return NamedSharding(mesh, pspec, memory_kind="device")


# shard the batch dimension across the 'data' axis, replicate the rest
x_sharding = mesh_sharding(PartitionSpec('data', None))
# shard the second dimension across the 'model' axis
model_sharding = mesh_sharding(PartitionSpec(None, 'model'))
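
# --- illustration (added example; not part of the original script) ---
# Quick check of what x_sharding means on the 2x2 mesh above: the first
# (batch) dimension is split across the 'data' axis and the second dimension
# is replicated. The (8, 16) shape is arbitrary and only used here.
example_array = jax.device_put(jnp.zeros((8, 16)), x_sharding)
print(example_array.sharding)
# jax.debug.visualize_array_sharding(example_array)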


# %%
# optimizers

def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn


# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    training_size,
    train_batch_size,
    num_train_epochs,
    warmup_steps,
    learning_rate,
)
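
# --- illustration (added example; not part of the original script) ---
# The joined schedule is piecewise linear: warmup from 0 to learning_rate over
# warmup_steps, then a straight decay to 0. With warmup_steps = 0 it is pure decay.
for _step in (0, total_train_steps // 2, total_train_steps - 1):
    print(f"lr at step {_step}: {linear_decay_lr_schedule_fn(_step)}")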


# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)
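
# --- illustration (added example; not part of the original script) ---
# Toy parameter tree (made up for demonstration) showing what decay_mask_fn
# produces: kernels get True (decayed), biases and layer-norm scales get False.
_toy_params = {"block": {"layer_norm": {"weight": 1.0}, "wi": {"kernel": 1.0, "bias": 1.0}}}
print(decay_mask_fn(_toy_params))
# {'block': {'layer_norm': {'weight': False}, 'wi': {'kernel': True, 'bias': False}}}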


# create adam optimizer
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=adam_beta1,
    b2=adam_beta2,
    eps=adam_epsilon,
    weight_decay=weight_decay,
    mask=decay_mask_fn,
)


# %%

print("compile")

# enable bf16
# enable only for dense, some transformer sections, and shared
def create_mask_for_layer_norm(params):
    flat_params = traverse_util.flatten_dict(params)
    mask = {
        # path: not (
        #     (path[-2] == "layer_norm" and path[-1] == "weight") or
        #     (path[-2] == "final_layer_norm" and path[-1] == "weight") or
        #     (path[-2] == "o" and path[-1] == "kernel")
        # )
        # for path in flat_params
        path: (
            (path[-2] == "wi" and path[-1] == "weight") or
            (path[-2] == "wo" and path[-1] == "weight") or
            (path[-2] == "k" and path[-1] == "kernel") or
            (path[-2] == "q" and path[-1] == "kernel") or
            (path[-2] == "v" and path[-1] == "kernel") or
            (path[-2] == "shared" and path[-1] == "embedding")
        ) for path in flat_params
    }
    mask = traverse_util.unflatten_dict(mask)
    return mask


# borrowed from transformers modeling_flax_utils
def cast_floating_to(params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
    """
    Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
    """

    # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
    def conditional_cast(param):
        if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
            param = param.astype(dtype)
        return param

    if mask is None:
        return jax.tree_util.tree_map(conditional_cast, params)

    flat_params = traverse_util.flatten_dict(params)
    flat_mask, _ = jax.tree_util.tree_flatten(mask)

    for masked, key in zip(flat_mask, sorted(flat_params.keys())):
        if masked:
            flat_params[key] = conditional_cast(flat_params[key])

    return traverse_util.unflatten_dict(flat_params)
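
# --- illustration (added example; not part of the original script) ---
# Tiny made-up tree showing the masked cast: only leaves whose mask entry is
# True are converted, the rest keep their original floating dtype.
_toy_tree = {"a": {"w": jnp.ones(2, dtype=jnp.float32), "b": jnp.ones(2, dtype=jnp.float32)}}
_toy_mask = {"a": {"w": True, "b": False}}
_cast_tree = cast_floating_to(_toy_tree, jnp.bfloat16, _toy_mask)
print(jax.tree_util.tree_map(lambda x: x.dtype, _cast_tree))
# {'a': {'w': dtype(bfloat16), 'b': dtype('float32')}}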


# create init_fn to produce sharded state
def init_fn(params, model, optimizer):
    # do be careful with the model init
    # imported models might have complicated init methods

    # mask = create_mask_for_layer_norm(params)
    # override params with bfloat version
    # params = cast_floating_to(params, jnp.bfloat16, mask)

    state = train_state.TrainState.create(  # Create a `TrainState`.
        apply_fn=model.apply,
        params=params,
        tx=optimizer)
    return state


abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=model, optimizer=adamw), params)

# jax.sharding: describes how a jax.Array is laid out across devices
state_sharding = nn.get_sharding(abstract_variables, mesh)
# print(state_sharding)

# %%

# replace the params tree with the new modified tree
# create partitions for model
from parallel.partitions import set_partitions
# set_partitions freezes the params on return
model_part_spec = set_partitions(unfreeze(params))
# p is already a partition spec
model_named_sharding = jax.tree.map(lambda p: mesh_sharding(p), model_part_spec)

# get pspec for opt_state
def get_opt_spec(x):
    if isinstance(x, dict):
        return unfreeze(model_named_sharding)
    # return an empty partspec
    return mesh_sharding(PartitionSpec())

# this replaces the empty model params spec with 'model_named_sharding'
state_sharding = jax.tree.map(
    get_opt_spec, state_sharding, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
)


jit_init_fn = jax.jit(
    init_fn,
    static_argnames=('model', 'optimizer'),  # skip model and optimizer
    in_shardings=mesh_sharding(PartitionSpec()),  # we don't shard params explicitly
    out_shardings=state_sharding  # but returned initialized_state is sharded
)
initialized_state = jit_init_fn(params, model, adamw)
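
# --- illustration (added sketch; not part of the original script) ---
# One way to confirm that the returned state is actually sharded is to look at
# the sharding of a single parameter leaf, e.g. the shared embedding table
# (the same leaf the commented-out check further below uses).
# print(initialized_state.params['shared']['embedding'].sharding)
# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding'])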

# %%
# train step
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
    """
    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
    # compute the loss in float32 for numerical stability
    logits = jnp.asarray(logits, dtype=jnp.float32)
    soft_labels = soft_labels.astype(jnp.float32)
    loss = optax.softmax_cross_entropy(logits, soft_labels)
    loss = loss - normalizing_constant

    # ignore padded tokens from loss
    loss = loss * padding_mask
    loss = loss.mean()
    # num_labels = padding_mask.mean()
    return loss  # , num_labels
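
# --- illustration (added example; not part of the original script) ---
# Tiny made-up batch: vocab of 2, one real token and one padded position.
# With label_smoothing_factor=0 the per-token loss is -log softmax of the true
# class; the padded position is zeroed out, but the mean still divides by both
# positions, so the result is roughly 0.127 / 2.
_logits = jnp.array([[[2.0, 0.0], [0.0, 2.0]]])   # (batch=1, seq=2, vocab=2)
_labels = jnp.array([[0, 1]])
_padding = jnp.array([[1.0, 0.0]])                # second position is padding
print(loss_fn(_logits, _labels, _padding))        # ~0.0635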


# %%

# single device code annotated with jax.jit
@functools.partial(
    jax.jit,
    # state is state_sharding initialized from init_fn
    # x_sharding is data sharded explicitly later
    in_shardings=(state_sharding, x_sharding),
    out_shardings=(state_sharding, mesh_sharding(PartitionSpec())),
    donate_argnames=('state',),
)
def train_step(state, batch):

    label_smoothing_factor = 0.0
    # dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    def compute_loss(params, batch):
        # check constraints
        # frozen dict not allowed as sharding object
        params = jax.lax.with_sharding_constraint(params, unfreeze(model_named_sharding))
        batch = jax.lax.with_sharding_constraint(batch, x_sharding)
        logits = state.apply_fn(
            {'params': params},
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            decoder_input_ids=batch['decoder_input_ids'],
            decoder_attention_mask=batch['decoder_attention_mask'],
        )[0]  # index zero because the output is a structure whose first entry is the logits
        # use labels here
        # loss, num_labels = loss_fn(
        loss = loss_fn(
            logits,
            batch["labels"],
            batch["decoder_attention_mask"],
            label_smoothing_factor)
        return loss  # , num_labels

    # compute gradients through computational graph
    # allow values to pass through
    grad_fn = jax.value_and_grad(compute_loss, has_aux=False)
    loss, grad = grad_fn(state.params, batch)
    # num_labels = jax.lax.psum(num_labels, "batch")

    new_state = state.apply_gradients(grads=grad)
    with jax.named_scope("sync_metrics"):
        step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}

    return new_state, step_metrics


# %%
# explore data sharding
sharded_batch = next(iter(train_loader))
sharded_batch = jax.device_put(sharded_batch, x_sharding)
# jax.debug.visualize_array_sharding(sharded_batch['input_ids'])
# jax.debug.visualize_array_sharding(initialized_state.params['shared']['embedding'])


# %%
# # prep 1 step
# print("1 step for jit-ting")
# with mesh:
#     state, metrics = train_step(initialized_state, sharded_batch)

# %%

# %%
# training
print("***** Running training *****")
print(f" Num examples = {training_size}")
print(f" Num Epochs = {num_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
print(f" Total optimization steps = {total_train_steps}")


# %%
# jax.profiler.start_trace("./traces")


print("*" * 10)
print("training start")
rng, input_rng = jax.random.split(rng)
train_time = 0
state = initialized_state
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
    train_start = time.time()

    # Create sampling rng
    train_metrics = []
    steps_per_epoch = training_size // train_batch_size
    train_loader = dataprep.data_loader(rng, batch_size=batch_size, shuffle=True, drop_last=True)
    # Generate an epoch by shuffling sampling indices from the train dataset
    for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
        batch = next(train_loader)
        batch = jax.device_put(batch, x_sharding)
        with mesh:
            state, train_metric = train_step(state, batch)

        # train_metrics.append(train_metric)

    # this is for more accurate time stats, but slows down training
    # train_metric['loss'].block_until_ready()
    train_time = time.time() - train_start

    epochs.write(
        f"Epoch... ({epoch + 1}/{num_epochs} | "
        f"Loss: {train_metric['loss']}, "
        f"Learning Rate: {train_metric['learning_rate']}, "
        f"Last train time: {train_time})"
    )
# jax.profiler.stop_trace()

# %%
# try out
gather_state = jax.device_get(state)
gather_batch = jax.device_get(batch)
logits = gather_state.apply_fn(
    {'params': gather_state.params},
    input_ids=gather_batch['input_ids'],
    attention_mask=gather_batch['attention_mask'],
    decoder_input_ids=gather_batch['decoder_input_ids'],
    decoder_attention_mask=gather_batch['decoder_attention_mask'],
)[0]  # index zero because the output is a structure whose first entry is the logits

probs = nn.softmax(logits, axis=-1)
predicted = jnp.argmax(probs, axis=-1)
predicted[1]

# %%
main_model = custom_model.from_pretrained('t5-base')
output_dir = save_path

# save the trained checkpoint on the primary process only
if jax.process_index() == 0:
    params = jax.device_get(state.params)
    main_model.save_pretrained(output_dir, params=params)


# %%
@ -0,0 +1,360 @@
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.16.4
# ---

# %% [markdown]
# # prediction code
# ## import and process test data


# %%
# import libraries
import pandas as pd
import matplotlib.pyplot as plt

from datasets import Dataset, DatasetDict

import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
from typing import Callable, Optional
import math

# jax.config.update("jax_default_matmul_precision", "tensorfloat32")
jax.config.update("jax_default_matmul_precision", "high")
jax.config.update("jax_enable_x64", False)


from transformers import FlaxAutoModelForSeq2SeqLM, AutoConfig

import datasets
from datasets import Dataset
import evaluate
from tqdm import tqdm

import nltk  # Here to have a nice missing dependency error message early on

from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key

import time

# %%
# data_path = f"../make_data/select_db/data_mapping_filtered.csv"
# data_path = f"../make_data_2/select_db/dataset/1/train_all.csv"
data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/test.csv'
# data_path = f'/home/richard/Projects/06_research/hipom_data_mapping/data_preprocess/dataset/1/train_all.csv'

# Ensure to include 'ships_idx' in the fields list
fields = ['ships_idx', 'tag_name', 'tag_description', 'thing', 'property', 'unit']

# Load the dataset
df = pd.read_csv(data_path, skipinitialspace=True, usecols=fields)


def process_df(df):
    output_list = [{
        'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC>",
        # 'input': f"<DESC>{row['tag_description']}<DESC>",
        # 'input': f"<NAME>{row['tag_name']}<NAME><DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
        # 'input': f"<DESC>{row['tag_description']}<DESC><UNIT>{row['unit']}<UNIT>",
        'output': f"<THING_START>{row['thing']}<THING_END><PROPERTY_START>{row['property']}<PROPERTY_END>",
        # 'answer': f"{row['thing']} {row['property']}",
        # 'answer_thing': row['thing'],
        # 'answer_property': row['property'],
    } for _, row in df.iterrows()]

    return output_list


# takes 1 minute to run without batching
test_dataset = Dataset.from_list(process_df(df))


# %% [markdown]
# ## Load model for attributes

# %%
# load model
model_name_or_path = "./t5_80_1"  # Replace with your specific model name

# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)

# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
    pretrained_model_name_or_path=model_name_or_path
)


# %% [markdown]
# ## Tokenizer

# %%
# prepare tokenizer
from transformers import T5TokenizerFast
tokenizer = T5TokenizerFast.from_pretrained("t5-base", return_tensors="np", clean_up_tokenization_spaces=True)
# Define additional special tokens
additional_special_tokens = ["<THING_START>", "<THING_END>", "<PROPERTY_START>", "<PROPERTY_END>", "<NAME>", "<DESC>", "SIG", "UNIT", "DATA_TYPE"]
# Add the additional special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})

max_length = 86

model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")


# given a dataset entry, run it through the tokenizer
# Setting padding="max_length" as we need fixed length inputs for jitted functions
def preprocess_function(example):
    inputs = example['input']
    targets = example['output']
    # text_target sets the corresponding label to inputs
    # there is no need to create a separate 'labels'
    model_inputs = tokenizer(
        inputs,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )
    labels = tokenizer(
        text_target=targets,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )

    model_inputs["labels"] = labels["input_ids"]
    decoder_input_ids = shift_tokens_right_fn(
        labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
    )
    model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

    # We need decoder_attention_mask so we can ignore pad tokens from loss
    model_inputs["decoder_attention_mask"] = labels["attention_mask"]

    return model_inputs
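
# --- illustration (added example; not part of the original pipeline) ---
# What one preprocessed record looks like; the input/output strings below are
# made up and only meant to show the produced keys and the fixed lengths.
_example = preprocess_function({
    'input': "<NAME>FO_PUMP_1<NAME><DESC>fuel oil pump no.1<DESC>",
    'output': "<THING_START>pump<THING_END><PROPERTY_START>pressure<PROPERTY_END>",
})
print({k: np.asarray(v).shape for k, v in _example.items()})
# every entry comes out padded/truncated to shape (1, max_length)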


# map applies the function to each "row" in the dataset
# aka the data in the immediate nesting
test_dataset = test_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=test_dataset.column_names,
)


def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
    """
    Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
    and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
    """
    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
        batch_idx = np.asarray(batch_idx)
    else:
        batch_idx = np.arange(len(dataset))

    if drop_last:
        steps_per_epoch = len(dataset) // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        steps_per_epoch = math.ceil(len(dataset) / batch_size)
        batch_idx = np.array_split(batch_idx, steps_per_epoch)

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: np.array(v) for k, v in batch.items()}

        yield batch
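
# --- illustration (added sketch; not part of the original pipeline) ---
# Minimal usage of data_loader: pull a couple of batches and inspect shapes.
# _loader = data_loader(jax.random.PRNGKey(0), test_dataset, batch_size=4, drop_last=False)
# for _i, _b in zip(range(2), _loader):
#     print(_i, {k: v.shape for k, v in _b.items()})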


# %% [markdown]
# # model generation

# %%
seed = 117
num_epochs = 80
batch_size = 96
num_train_epochs = num_epochs
per_device_train_batch_size = batch_size
train_batch_size = per_device_train_batch_size * jax.device_count()
per_device_eval_batch_size = batch_size
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(test_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs

num_beams = 1
val_max_target_length = 128

predict_with_generate = True


# Initialize our RNG
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)

# %%
# reload model to prevent leakage of variables
# load model
model_name_or_path = "t5_80_1_bf16"  # Replace with your specific model name

# Load configuration
config = AutoConfig.from_pretrained(model_name_or_path)

# Load model
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
    model_name_or_path
)


# Ensure model.params is properly initialized (this is just an example)
# Normally you would get this from a model initialization call with dummy input
params = model.params
# cast the bf16 checkpoint back to full-size float32 for generation
params_f32 = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), params)
# we need to replicate the model params over devices for pmap
replicated_params = jax.device_put_replicated(params_f32, jax.devices())


# Define generation function
max_length = (
    val_max_target_length if val_max_target_length is not None else model.config.max_length
)
num_beams = num_beams if num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


def generate_step(params, batch):
    output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], params=params, **gen_kwargs)
    return output_ids.sequences


# Create a parallel version of the generate step
p_generate_step = jax.pmap(generate_step, "batch")


pred_generations = []
pred_labels = []

rng, input_rng = jax.random.split(rng)

pred_loader = data_loader(input_rng, test_dataset, eval_batch_size, drop_last=False)
pred_steps = math.ceil(len(test_dataset) / eval_batch_size)

print("***** Running prediction *****")
print(f" Num examples = {len(test_dataset)}")
print(f" Num steps = {pred_steps}")
print(f" Instantaneous batch size per device = {per_device_eval_batch_size}")
print(f" Total test batch size (w. parallel & distributed) = {eval_batch_size}")


for _ in tqdm(range(pred_steps), desc="Predicting..."):
    # Model forward
    batch = next(pred_loader)
    labels = batch["labels"]

    # generation
    generated_ids = pad_shard_unpad(p_generate_step)(replicated_params, batch)
    pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
    pred_labels.extend(labels)


# %% [markdown]
# # process predictions


# %%
# code to get special token ids
# sentence = "<THING_START><THING_END><PROPERTY_START><PROPERTY_END><NAME><DESC><DESC><UNIT>"
# tokens = tokenizer.tokenize(sentence)
# print("Tokens:", tokens)
# # Get the IDs (integer indices) of specific tokens
# token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
# print("Token IDs:", token_ids)


# %%
# extract sequence and decode
def extract_seq(tokens, start_value, end_value):
    if start_value not in tokens or end_value not in tokens:
        return None  # Or handle this case according to your requirements
    start_id = np.where(tokens == start_value)[0][0]
    end_id = np.where(tokens == end_value)[0][0]

    return tokens[start_id + 1:end_id]


def process_tensor_output(tokens):
    thing_seq = extract_seq(tokens, 32100, 32101)  # 32100 = <THING_START>, 32101 = <THING_END>
    property_seq = extract_seq(tokens, 32102, 32103)  # 32102 = <PROPERTY_START>, 32103 = <PROPERTY_END>
    p_thing = None
    p_property = None
    if (thing_seq is not None):
        p_thing = tokenizer.decode(thing_seq, skip_special_tokens=False)  # retain <COLLIDE>
    if (property_seq is not None):
        p_property = tokenizer.decode(property_seq, skip_special_tokens=False)  # retain <COLLIDE>
    return p_thing, p_property
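
# --- illustration (added example; not part of the original pipeline) ---
# Made-up id sequence showing how the two spans are pulled out: the ids between
# the markers are ordinary vocabulary ids and decode to some subword text.
_fake_ids = np.array([32100, 7142, 32101, 32102, 1293, 32103, 0, 0])
print(process_tensor_output(_fake_ids))
# -> (text decoded from ids between <THING_START>/<THING_END>,
#     text decoded from ids between <PROPERTY_START>/<PROPERTY_END>)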


# %%
# decode prediction labels
def decode_preds(tokens_list):
    thing_prediction_list = []
    property_prediction_list = []
    for tokens in tokens_list:
        p_thing, p_property = process_tensor_output(tokens)
        thing_prediction_list.append(p_thing)
        property_prediction_list.append(p_property)
    return thing_prediction_list, property_prediction_list


thing_prediction_list, property_prediction_list = decode_preds(pred_generations)


# %%
# add labels too
thing_actual_list, property_actual_list = decode_preds(pred_labels)

# Convert the list to a Pandas DataFrame
df = pd.DataFrame({'p_thing': thing_prediction_list,
                   'p_property': property_prediction_list,
                   'thing': thing_actual_list,
                   'property': property_actual_list})

df['p_thing_correct'] = df['p_thing'] == df['thing']
df['p_property_correct'] = df['p_property'] == df['property']


# %%
print("thing accuracy", sum(df['p_thing_correct'])/len(df))
print("property accuracy", sum(df['p_property_correct'])/len(df))
print("total accuracy", sum(df['p_property_correct'] & df['p_thing_correct'])/len(df))

# %%
df[~df["p_property_correct"]]

# %%
df['p_thing']

# %%
# Save the DataFrame as a Parquet file (using pyarrow or fastparquet)
# df.to_parquet("exports/output_file.parquet", engine="pyarrow")  # or use engine="fastparquet"