Commit

release prompt refiner
LinB203 committed Oct 22, 2024
1 parent 390de3f commit c41050a
Showing 6 changed files with 275 additions and 3 deletions.
8 changes: 6 additions & 2 deletions README.md
@@ -80,7 +80,7 @@ Coming soon...

| Version | Architecture | Diffusion Model | CausalVideoVAE | Data | Prompt Refiner |
|:---|:---|:---|:---|:---|:---|
| v1.3.0 | 3D | [Anysize in 93x640x640](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/any93x640x640)[3], more checkpoints are coming soon | [Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/vae)| - | [checkpoint](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner)|
| v1.3.0 | 3D | [Anysize in 93x640x640](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/any93x640x640)[3], more checkpoints are coming soon | [Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/vae)| [prompt_refiner](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner) | [checkpoint](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner)|
| v1.2.0 | 3D | [93x720p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x720p), [29x720p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/29x720p)[1], [93x480p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x480p)[1,2], [29x480p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/29x480p), [1x480p](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/1x480p), [93x480p_i2v](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/93x480p_i2v) | [Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main/vae)| [Annotations](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0) | - |
| v1.1.0 | 2+1D | [221x512x512](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/221x512x512), [65x512x512](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/65x512x512) |[Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main/vae) |[Data and Annotations](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0)| - |
| v1.0.0 | 2+1D | [65x512x512](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x512x512), [65x256x256](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x256x256), [17x256x256](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/17x256x256) | [Anysize](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/vae) | [Data and Annotations](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0)| - |
@@ -122,12 +122,16 @@ pip install -e .
pip install -e '.[dev]'
```

# 🗝️ Training & Validating
# 🗝️ Training & Inference

## 🗜️ CausalVideoVAE

Instructions for data preparation, training, inference, and evaluation can be found [here](docs/VAE.md).

## 📖 Prompt Refiner

Instructions for data preparation, training, and inference can be found [here](docs/Prompt_Refiner.md).

## 📜 Text-to-Video

Instructions for data preparation, training, and inference can be found [here](docs/T2V.md)
56 changes: 56 additions & 0 deletions docs/Prompt_Refiner.md
@@ -0,0 +1,56 @@
## Data

We have open-sourced our dataset of 32,555 pairs, which includes Chinese data. The dataset is available [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner). The details can be found [here](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.3.0.md#prompt-refiner).

The dataset is a single JSON file with the following structure:

```
[
    {
        "instruction": "Refine the sentence: \"A newly married couple sharing a piece of there wedding cake.\" to contain subject description, action, scene description. (Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. Make sure it is a fluent sentence, not nonsense.",
        "input": "",
        "output": "The newlywed couple, dressed in elegant attire..."
    },
    ...
]
```
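
For a quick sanity check, the snippet below is a minimal sketch of how one might load and inspect the dataset with pandas (the same loader used by `train.py`); the local filename `refine_32255.json` is an assumption taken from the training script's default and should be replaced with wherever you downloaded the file.

```
import pandas as pd

# Assumption: the dataset JSON has been downloaded locally; the filename
# mirrors the default used by train.py in this repo.
df = pd.read_json("refine_32255.json")

print(len(df), "instruction/output pairs")
print(df.columns.tolist())          # ['instruction', 'input', 'output']
print(df.iloc[0]["instruction"])    # refinement instruction wrapping the raw caption
print(df.iloc[0]["output"])         # refined caption
```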

## Train

- `--data_path` is the path to the prepared JSON file.
- `--model_path` is the directory containing the LLaMA 3.1 weights, including `config.json` and the weight files.
- `--lora_out_path` is the directory where the trained LoRA adapter will be saved.

```
cd opensora/models/prompt_refiner
CUDA_VISIBLE_DEVICES=0 python train.py \
    --data_path path/to/data.json \
    --model_path path/to/llama_model \
    --lora_out_path path/to/save/lora_model
```

## Merge

- `--base_path` is the directory containing the LLaMA 3.1 weights, including `config.json` and the weight files.
- `--lora_in_path` is the directory containing the trained LoRA adapter.
- `--lora_out_path` is the directory where the merged model will be saved.

```
cd opensora/models/prompt_refiner
CUDA_VISIBLE_DEVICES=0 python merge.py \
    --base_path path/to/llama_model \
    --lora_in_path path/to/save/lora_model \
    --lora_out_path path/to/save/merge_model
```

## Inference

- `--model_path` is the directory containing the weights (LLaMA 3.1 or the merged LoRA model), including `config.json` and the weight files.
- `--prompt` is the text prompt you want to refine.

```
cd opensora/models/prompt_refiner
CUDA_VISIBLE_DEVICES=0 python inference.py \
    --model_path path/to/save/merge_model \
    --prompt "a dog is running."
```
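
To call the refiner from Python instead of the CLI, the sketch below mirrors what `inference.py` in this commit does; the merged-model path is a placeholder.

```
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_path = "path/to/save/merge_model"  # placeholder: the merged model from the step above
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda").eval()

template = ("Refine the sentence: \"{}\" to contain subject description, action, scene description. "
            "(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions "
            "to make the sentence more dynamic. Make sure it is a fluent sentence, not nonsense.")
messages = [
    {"role": "system", "content": "You are a caption refiner."},
    {"role": "user", "content": template.format("a dog is running.")},
]

text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([text], return_tensors="pt").to(model.device)
output_ids = model.generate(inputs.input_ids, max_new_tokens=512)
# Decode only the newly generated tokens (everything after the prompt).
print(tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))
```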
3 changes: 2 additions & 1 deletion docs/Report-v1.3.0.md
@@ -18,6 +18,7 @@ In version 1.3.0, Open-Sora-Plan introduced the following five key features:
We open-source Open-Sora-Plan to facilitate future development of video generation in the community. Code, data, and models will be made publicly available.
- Code: All training scripts and sample scripts.
- Model: Both the Diffusion Model and CausalVideoVAE [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0).
- Data: The prompt refiner training data is available [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner).

## Gallery

@@ -158,7 +159,7 @@ conceive some additional actions to make the sentence more dynamic,
make sure it is a fluent sentence, not nonsense.
```

Finally, we performed LoRA fine-tuning using [LLaMa 3.1](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct), completing the training in just 30 minutes with a single H100. We fine-tuned for only 1 epoch, using a batch size of 32 and a LoRA rank of 64. The log can be found [here](https://api.wandb.ai/links/1471742727-Huawei/p5xmkft5).
Finally, we performed LoRA fine-tuning using [LLaMa 3.1](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct), completing the training in just 30 minutes with a single H100. We fine-tuned for only 1 epoch, using a batch size of 32 and a LoRA rank of 64. The log can be found [here](https://api.wandb.ai/links/1471742727-Huawei/p5xmkft5). We open-sourced the data [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner).
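
For reference, these settings correspond to a PEFT configuration along the following lines (a sketch mirroring the `train.py` released with this version; the alpha and dropout values are taken from that script):

```
from peft import LoraConfig, TaskType

# LoRA rank 64 applied to the attention and MLP projections, as in the released train.py.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    r=64,
    lora_alpha=64,
    lora_dropout=0.1,
)
```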

### Data Construction

40 changes: 40 additions & 0 deletions opensora/models/prompt_refiner/inference.py
@@ -0,0 +1,40 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import argparse


def get_output(prompt):
    # Wrap the raw caption in the refinement instruction used during training.
    template = "Refine the sentence: \"{}\" to contain subject description, action, scene description. " \
        "(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \
        "Make sure it is a fluent sentence, not nonsense."
    prompt = template.format(prompt)
    messages = [
        {"role": "system", "content": "You are a caption refiner."},
        {"role": "user", "content": prompt}
    ]

    # Build the chat-formatted prompt, tokenize it, and generate the refined caption.
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512)
    # Keep only the newly generated tokens (everything after the prompt).
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print('\nInput:\n', prompt)
    print('\nOutput:\n', response)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="llama3_8B_lora_merged_cn")
    parser.add_argument("--prompt", type=str, default='a dog is running.')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    device = torch.device('cuda')
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).to(device).eval()

    response = get_output(args.prompt)
83 changes: 83 additions & 0 deletions opensora/models/prompt_refiner/merge.py
@@ -0,0 +1,83 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import argparse


def get_lora_model(base_model_path, lora_model_input_path, lora_model_output_path):
    # Load the base model, attach the trained LoRA adapter, and merge the weights.
    model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True)
    model = PeftModel.from_pretrained(model, lora_model_input_path)
    merged_model = model.merge_and_unload()
    merged_model.save_pretrained(lora_model_output_path, safe_serialization=True)
    print("Merged LoRA weights into the base model")

    # Save the tokenizer alongside the merged weights so the output directory is self-contained.
    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
    tokenizer.save_pretrained(lora_model_output_path)
    print("Saved tokenizer")


def get_model_result(base_model_path, finetune_model_path):
    # Compare the base model and the merged fine-tuned model on the same prompt.
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    device = "cuda"

    finetune_model = AutoModelForCausalLM.from_pretrained(
        finetune_model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ).eval()

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ).eval()

    template = "Refine the sentence: \"{}\" to contain subject description, action, scene description. " \
        "(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \
        "Make sure it is a fluent sentence, not nonsense."

    # Mixed Chinese/English test prompt ("a dog and a cat") to check bilingual refinement.
    prompt = "a dog和一只猫"
    prompt = template.format(prompt)
    messages = [
        {"role": "system", "content": "You are a caption refiner."},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    def get_result(model_inputs, model):
        generated_ids = model.generate(
            model_inputs.input_ids,
            max_new_tokens=512,
            eos_token_id=tokenizer.get_vocab()["<|eot_id|>"]
        )
        # Keep only the newly generated tokens.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response

    base_model_response = get_result(model_inputs, base_model)
    finetune_model_response = get_result(model_inputs, finetune_model)
    print("\nInput\n", prompt)
    print("\nResult before fine-tune:\n", base_model_response)
    print("\nResult after fine-tune:\n", finetune_model_response)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_path", type=str, default="Meta-Llama-3___1-8B-Instruct")
    parser.add_argument("--lora_in_path", type=str, default="llama3_1_instruct_lora/checkpoint-1008")
    parser.add_argument("--lora_out_path", type=str, default="llama3_1_instruct_lora/llama3_8B_lora_merged_cn")
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    # First merge the LoRA adapter into the base model, then compare outputs.
    get_lora_model(args.base_path, args.lora_in_path, args.lora_out_path)
    get_model_result(args.base_path, args.lora_out_path)
88 changes: 88 additions & 0 deletions opensora/models/prompt_refiner/train.py
@@ -0,0 +1,88 @@
from datasets import Dataset
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer
from peft import LoraConfig, TaskType, get_peft_model
import torch
import argparse

# Refinement instruction template (matches the "instruction" field in the dataset; kept here for reference).
ins = "Refine the sentence to contain subject description, action, scene description. " \
    "(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \
    "Make sure it is a fluent sentence, not nonsense."


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default='refine_32255.json')
    parser.add_argument("--model_path", type=str, default='Meta-Llama-3___1-8B-Instruct')
    parser.add_argument("--lora_out_path", type=str, default="llama3_1_instruct_lora")
    args = parser.parse_args()
    return args


args = parse_args()

# Load the instruction/output pairs and wrap them in a Hugging Face Dataset.
df = pd.read_json(args.data_path)
ds = Dataset.from_pandas(df)
tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token


def process_func(example):
    # Tokenize one example into the Llama 3.1 chat format; labels mask out the prompt tokens with -100.
    MAX_LENGTH = 2048
    instruction = tokenizer(f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a caption refiner.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{example['instruction'] + example['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", add_special_tokens=False)  # add_special_tokens=False: do not prepend special tokens (the chat template already contains them)
    response = tokenizer(f"{example['output']}<|eot_id|>", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }


tokenized_id = ds.map(process_func, remove_columns=ds.column_names)


model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16)
print(model)
model.enable_input_require_grads()  # required when gradient checkpointing is enabled

# LoRA configuration: rank 64 applied to all attention and MLP projection layers.
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False,
    r=64,
    lora_alpha=64,
    lora_dropout=0.1
)
print(config)

model = get_peft_model(model, config)
model.print_trainable_parameters()

training_args = TrainingArguments(
    output_dir=args.lora_out_path,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    logging_steps=1,
    num_train_epochs=1,
    save_steps=20,
    dataloader_num_workers=4,
    learning_rate=1.5e-4,
    warmup_ratio=0.1,
    save_on_each_node=True,
    gradient_checkpointing=True,
    report_to='wandb',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_id,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)

trainer.train()
