opensora/models/prompt_refiner/README.md
## Data

We have open-sourced our dataset of 32,555 pairs, which includes Chinese data. The dataset is available [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner); details can be found in the [Report-v1.3.0](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.3.0.md#prompt-refiner).

The dataset is a JSON file with the following structure:

```
[
    {
        "instruction": "Refine the sentence: \"A newly married couple sharing a piece of there wedding cake.\" to contain subject description, action, scene description. (Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. Make sure it is a fluent sentence, not nonsense.",
        "input": "",
        "output": "The newlywed couple, dressed in elegant attire..."
    },
    ...
]
```
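For a quick sanity check, the file can be loaded with pandas, the same way `train.py` reads it. A minimal sketch (the path is a placeholder for wherever you saved the download):

```
import pandas as pd

# Placeholder path; point it at the downloaded JSON file.
df = pd.read_json("path/to/data.json")
print(len(df))                      # number of instruction/output pairs
print(df.iloc[0]["instruction"])    # refine instruction with the raw caption embedded
print(df.iloc[0]["output"])         # the refined caption
```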
## Train

- `--data_path` is the path to the prepared JSON file.
- `--model_path` is the directory containing the LLaMA 3.1 weights, including `config.json` and the weight files.
- `--lora_out_path` is the directory where the LoRA checkpoints will be saved.

```
cd opensora/models/prompt_refiner
CUDA_VISIBLE_DEVICES=0 python train.py \
    --data_path path/to/data.json \
    --model_path path/to/llama_model \
    --lora_out_path path/to/save/lora_model
```
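With `save_steps=20` in `train.py`, checkpoints are written as `checkpoint-*` subdirectories under `--lora_out_path`. If you want to try an adapter before merging, here is a minimal sketch using the same `peft` API that `merge.py` relies on (both paths are placeholders):

```
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Placeholder paths; substitute your base model and checkpoint directories.
base = AutoModelForCausalLM.from_pretrained(
    "path/to/llama_model", torch_dtype=torch.bfloat16, device_map="auto"
)
# Attaches the LoRA adapter without merging; the base weights stay untouched.
model = PeftModel.from_pretrained(base, "path/to/save/lora_model/checkpoint-20")
model.eval()
```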
## Merge

- `--base_path` is the directory containing the LLaMA 3.1 weights, including `config.json` and the weight files.
- `--lora_in_path` is the directory containing the trained LoRA checkpoint.
- `--lora_out_path` is the directory where the merged model will be saved.

```
cd opensora/models/prompt_refiner
CUDA_VISIBLE_DEVICES=0 python merge.py \
    --base_path path/to/llama_model \
    --lora_in_path path/to/save/lora_model \
    --lora_out_path path/to/save/merge_model
```
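Besides saving the merged weights and tokenizer, `merge.py` also runs a test prompt through both the base model and the merged model and prints the refinements before and after fine-tuning, which is a quick way to confirm the LoRA weights took effect.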
## Inference

- `--model_path` is the directory containing the weights (the original LLaMA 3.1 weights or the merged LoRA weights), including `config.json` and the weight files.
- `--prompt` is the text prompt to be refined.

```
cd opensora/models/prompt_refiner
CUDA_VISIBLE_DEVICES=0 python inference.py \
    --model_path path/to/merge_model \
    --prompt "a dog is running."
```
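The script wraps `--prompt` in the same refine-instruction template used to build the training data and prepends a "You are a caption refiner." system message, so a raw caption can be passed in directly.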
opensora/models/prompt_refiner/inference.py
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import argparse


def get_output(prompt):
    # Wrap the raw caption in the same refine-instruction template used for training.
    template = "Refine the sentence: \"{}\" to contain subject description, action, scene description. " \
        "(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \
        "Make sure it is a fluent sentence, not nonsense."
    prompt = template.format(prompt)
    messages = [
        {"role": "system", "content": "You are a caption refiner."},
        {"role": "user", "content": prompt}
    ]

    # Render the chat template to text, then tokenize.
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    # Stop at <|eot_id|>, the end-of-turn token of the Llama 3.1 chat template (as in merge.py).
    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=512,
        eos_token_id=tokenizer.get_vocab()["<|eot_id|>"]
    )
    # Drop the prompt tokens so only the newly generated tokens are decoded.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print('\nInput:\n', prompt)
    print('\nOutput:\n', response)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="llama3_8B_lora_merged_cn")
    parser.add_argument("--prompt", type=str, default='a dog is running.')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    device = torch.device('cuda')
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).to(device).eval()

    response = get_output(args.prompt)
opensora/models/prompt_refiner/merge.py
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import argparse


def get_lora_model(base_model_path, lora_model_input_path, lora_model_output_path):
    # Load the base model, attach the LoRA adapter, then fold the adapter weights in.
    model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True)
    model = PeftModel.from_pretrained(model, lora_model_input_path)
    merged_model = model.merge_and_unload()
    merged_model.save_pretrained(lora_model_output_path, safe_serialization=True)
    print("Merged LoRA into base model")

    # Save the tokenizer alongside the merged weights so the output directory is self-contained.
    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
    tokenizer.save_pretrained(lora_model_output_path)
    print("Saved tokenizer")


def get_model_result(base_model_path, finetune_model_path):
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    device = "cuda"

    finetune_model = AutoModelForCausalLM.from_pretrained(
        finetune_model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ).eval()

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ).eval()

    template = "Refine the sentence: \"{}\" to contain subject description, action, scene description. " \
        "(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \
        "Make sure it is a fluent sentence, not nonsense."

    # Mixed English/Chinese test prompt ("a dog and a cat") to check that the Chinese data took effect.
    prompt = "a dog和一只猫"
    prompt = template.format(prompt)
    messages = [
        {"role": "system", "content": "You are a caption refiner."},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    def get_result(model_inputs, model):
        # Stop at <|eot_id|>, the end-of-turn token of the Llama 3.1 chat template.
        generated_ids = model.generate(
            model_inputs.input_ids,
            max_new_tokens=512,
            eos_token_id=tokenizer.get_vocab()["<|eot_id|>"]
        )
        # Keep only the newly generated tokens.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response

    base_model_response = get_result(model_inputs, base_model)
    finetune_model_response = get_result(model_inputs, finetune_model)
    print("\nInput:\n", prompt)
    print("\nResult before fine-tuning:\n", base_model_response)
    print("\nResult after fine-tuning:\n", finetune_model_response)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_path", type=str, default="Meta-Llama-3___1-8B-Instruct")
    parser.add_argument("--lora_in_path", type=str, default="llama3_1_instruct_lora/checkpoint-1008")
    parser.add_argument("--lora_out_path", type=str, default="llama3_1_instruct_lora/llama3_8B_lora_merged_cn")
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    get_lora_model(args.base_path, args.lora_in_path, args.lora_out_path)
    get_model_result(args.base_path, args.lora_out_path)
opensora/models/prompt_refiner/train.py
from datasets import Dataset
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer
from peft import LoraConfig, TaskType, get_peft_model
import torch
import argparse

# Refine-instruction template (unused below; the JSON instructions already embed it).
ins = "Refine the sentence to contain subject description, action, scene description. " \
    "(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \
    "Make sure it is a fluent sentence, not nonsense."


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default='refine_32255.json')
    parser.add_argument("--model_path", type=str, default='Meta-Llama-3___1-8B-Instruct')
    parser.add_argument("--lora_out_path", type=str, default="llama3_1_instruct_lora")
    args = parser.parse_args()
    return args


args = parse_args()

# Load the instruction/output pairs into a Hugging Face Dataset.
df = pd.read_json(args.data_path)
ds = Dataset.from_pandas(df)
tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token


def process_func(example):
    MAX_LENGTH = 2048
    input_ids, attention_mask, labels = [], [], []
    # Build the Llama 3.1 chat prompt by hand; add_special_tokens=False avoids prepending extra special tokens.
    instruction = tokenizer(f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a caption refiner.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{example['instruction'] + example['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", add_special_tokens=False)
    response = tokenizer(f"{example['output']}<|eot_id|>", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]
    # Mask the prompt tokens with -100 so the loss is computed only on the response.
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }


tokenized_id = ds.map(process_func, remove_columns=ds.column_names)

model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16)
print(model)
model.enable_input_require_grads()  # required for gradient checkpointing with a frozen base model

# LoRA on all attention and MLP projections.
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False,
    r=64,
    lora_alpha=64,
    lora_dropout=0.1
)
print(config)

model = get_peft_model(model, config)
model.print_trainable_parameters()

training_args = TrainingArguments(
    output_dir=args.lora_out_path,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    logging_steps=1,
    num_train_epochs=1,
    save_steps=20,
    dataloader_num_workers=4,
    learning_rate=1.5e-4,
    warmup_ratio=0.1,
    save_on_each_node=True,
    gradient_checkpointing=True,
    report_to='wandb',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_id,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)

trainer.train()