
Stable Diffusion

Implementation of Stable Diffusion in PyTorch

This repo implements the text-to-image model Stable Diffusion in PyTorch. The code uses the pretrained CLIPTextModel and CLIPTokenizer from Hugging Face, with the rest of the models trained from scratch on the poloclub/diffusiondb 2m_first_10k dataset.
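
For reference, the text-conditioning part can be loaded from the transformers library roughly as below. This is a minimal sketch: the checkpoint name openai/clip-vit-large-patch14 is the text encoder used by the original Stable Diffusion and is an assumption here, not necessarily the default in the training args.

    from transformers import CLIPTextModel, CLIPTokenizer

    # Assumed checkpoint: the CLIP text encoder used by the original Stable Diffusion
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    tokens = tokenizer(
        ["a cute corgi wearing sunglasses"],
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    # (batch, 77, 768) embeddings used to condition the UNet via cross-attention
    text_embeddings = text_encoder(tokens.input_ids).last_hidden_state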

Setup and running

To set up the environment, simply create a conda environment (Python > 3.7 recommended), clone the repo, and install the required packages:

git clone https://github.com/lwb2099/stable_diffusion_pytorch.git
cd stable_diffusion_pytorch
pip install -r requirements.txt

After that, you can launch training with:

accelerate launch --config_file place_for_your_yaml_file train_unet.py --train_args

I pushed my VS Code launch.json so that you can modify the command line arguments more easily.

Structure

The structure of the code draws insight from a few awesome repositories, fairseq and transformers, and looks like this:

    |- data
    |-- dataset
    |-- pretrained
    |- model
    |-- unet
    |-- autoencoder
    |- scripts
    |- stable_diffusion
    |-- config
    |-- models
    |-- modules
    |- test
    |- utils

data stores the downloaded dataset in data/dataset and the pretrained CLIP model from Hugging Face in data/pretrained.
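
As a rough sketch, downloading the dataset into that layout with the datasets library looks like the snippet below; the cache_dir value is an assumption that mirrors the directory tree, not necessarily the repo's actual default.

    from datasets import load_dataset

    # Cache the 10k-image DiffusionDB subset under data/dataset
    dataset = load_dataset("poloclub/diffusiondb", "2m_first_10k",
                           split="train", cache_dir="data/dataset")

    # The pretrained CLIP models can be cached analogously by passing
    # cache_dir="data/pretrained" to from_pretrained(...).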

model stores training checkpoints, named "checkpoint-{step}" if "--checkpointing-steps step" is passed, or "epoch-{epoch}" if "--checkpointing-steps epoch" is passed.
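
A minimal sketch of how such step-based checkpoints are typically written with accelerate (the variable names here are hypothetical, not the exact ones in train_unet.py):

    from accelerate import Accelerator

    accelerator = Accelerator()
    # model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    checkpointing_steps = 500   # hypothetical value passed via --checkpointing-steps
    global_step = 1000          # current step inside the training loop
    if global_step % checkpointing_steps == 0:
        # save_state writes model, optimizer and RNG state under the given directory
        accelerator.save_state(f"model/checkpoint-{global_step}")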

scripts holds code such as txt2img.py, which samples an image given a prompt.
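
For intuition, such a sampling script roughly follows the flow sketched below. It assumes diffusers-style interfaces (a UNet and VAE returning .sample, a scheduler with set_timesteps/step) rather than the repo's own classes, so treat it as an illustration only.

    import torch

    @torch.no_grad()
    def txt2img(prompt, tokenizer, text_encoder, unet, vae, scheduler,
                num_steps=50, device="cuda"):
        # Encode the prompt into CLIP text embeddings
        tokens = tokenizer([prompt], padding="max_length",
                           max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt").input_ids.to(device)
        context = text_encoder(tokens).last_hidden_state

        # Start from pure Gaussian noise in latent space (4x64x64 for 512x512 images)
        latents = torch.randn(1, 4, 64, 64, device=device)

        # Iteratively denoise with the UNet
        scheduler.set_timesteps(num_steps)
        for t in scheduler.timesteps:
            noise_pred = unet(latents, t, encoder_hidden_states=context).sample
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # Decode latents back to pixel space (0.18215 is SD's latent scaling factor)
        return vae.decode(latents / 0.18215).sample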

test contains scripts for testing the code; currently only argument parsing is covered, because package structure and imports still confuse me.

utils has helper scripts needed for a successful run, including checkpoint handling, model and data loading, and argument parsing.

stable_diffusion is the main package that stores everything needed to build a model. config stores the yaml files created by the "accelerate config" command, models stores the assembled models, and modules contains the necessary blocks to build them.

Problems remaining

Although the code can run successfully, several problems remain to be solved (or are simply things I have not figured out yet), and any guidance is appreciated:

  • Python package dependencies and import rules are still confusing; I think it is best to learn them in practice.
  • The structure of this repo combines transformers and fairseq, but I'm looking for a better structure for smaller projects.
  • Although I've used dataclasses, building a model from a JSON config file is clearly a better approach.
  • Autoencoder training does not quite work yet, so the code uses the diffusers pretrained autoencoder instead (see the sketch below).
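
Loading that pretrained autoencoder from diffusers is a one-liner; the exact checkpoint the training script points at is an assumption in this sketch.

    from diffusers import AutoencoderKL

    # Assumed checkpoint: the VAE shipped with Stable Diffusion v1.4
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
    vae.requires_grad_(False)  # kept frozen since only the UNet is trained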

References

Thanks to the following amazing repositories that helped me build this code:

Model implementation:

the original Stable Diffusion GitHub repository

labmlai annotated deep learning paper implementation

Training script:

modified from huggingface/diffusers

Arg parsing:

script modified from facebookresearch/fairseq

Diffusion Code Tutorial:

dome272/Diffusion-Models-pytorch

Aleksa's YouTube tutorial

More detailed references and links are in the comments of the .py files.
