# %%
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# %%
# Load model and tokenizer
# model_name = "bigscience/bloom-7b1"  # Replace with your model
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Automatically map model layers to available GPUs
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    device_map="auto",   # automatically split layers across available GPUs
    torch_dtype="auto",  # load weights in the checkpoint's stored dtype (e.g. FP16/BF16)
)
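
# %%
# Optional sanity check (an added sketch, not part of the original notebook):
# when device_map="auto" is used, from_pretrained records each submodule's
# placement in model.hf_device_map, so we can print the layer-to-GPU mapping.
if hasattr(model, "hf_device_map"):
    for module_name, device in model.hf_device_map.items():
        print(f"{module_name} -> {device}")
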
# %%
# Prepare input
text = "The quick brown fox jumps over the lazy dog."
inputs = tokenizer(text, return_tensors="pt")
inputs = inputs.to(model.device)  # move inputs to the model's device; also works on CPU

# Generate output (pass attention_mask along with input_ids via **inputs)
outputs = model.generate(**inputs, max_length=50)

# Decode and print result
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
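
# %%
# Optional variation (an added sketch): beam search often yields more fluent
# seq2seq output than greedy decoding; num_beams is a standard argument of
# transformers' generate().
beam_outputs = model.generate(**inputs, max_length=50, num_beams=4)
print(tokenizer.decode(beam_outputs[0], skip_special_tokens=True))
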
# %%
# Note: despite its name, this function prompts the model for a list of
# historical product names related to `text`, not an acronym.
def generate_acronym(text):
    # Define prompt
    # prompt = f"Imagine you are a diverse database. Given the following: '{text}', please suggest to me 5 possible variations. Give 5."
    prompt = f"Give me a list of 10 historical product names related to: '{text}'. Format the output as a list, like this: 1. Item, 2. Item, 3. ..."

    # Tokenize the prompt, move it to the model's device, and sample a completion
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = inputs.to(model.device)
    outputs = model.generate(
        **inputs,
        max_length=200,
        do_sample=True,
        top_k=50,
        temperature=0.8,
        # no_repeat_ngram_size=3,  # optionally suppress repeated trigrams
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
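
# %%
# Added parsing sketch: this assumes the model actually followed the
# "1. Item, 2. Item" format requested in the prompt; small models often
# don't, in which case the result may be a single unsplit string or empty.
import re

def parse_numbered_list(generation):
    # Split on "N." markers, then drop empty fragments and stray punctuation.
    return [part.strip(" ,.") for part in re.split(r"\d+\.\s*", generation) if part.strip()]
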
# %%
# Example usage
# text = "Advanced Data Analytics Platform"
text = "windows desktop"

acronym = generate_acronym(text)
print(f"Generation: {acronym}")
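
# Hypothetical follow-up using the parse_numbered_list sketch above:
# turn the raw generation into a Python list of item strings.
items = parse_numbered_list(acronym)
print(items)
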
# %%