enhancement along with autombw-rt #1

Open · wants to merge 8 commits into master
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
#
_*
81 changes: 81 additions & 0 deletions README.md
@@ -1,3 +1,84 @@
# auto-MBW-rt "runtime enhancement" by 6DammK9 (tested on 1.7.0)

Minor fixes and enhancements on top of [Xynonners' Fork](https://github.com/Xynonners/sd-webui-runtime-block-merge). Developed alongside [my fork of auto-MBW-rt](https://github.com/6DammK9/auto-MBW-rt/tree/master).

## Installation guide

- "Install as usual" (Install these extensions via "Extensions" > "Install from URL", or just clone it in `SD_DIR/extension`). There is no change on AI related stuffs. This is SE focused.

## Major change

- **Add SDXL support.** See the "Notes on merging SDXL models" section below.

- Fix the awful inconsistency between the UI, the runtime, and AutoMBW (a sketch of the mismatch follows the snippet below).

```py
# The two orderings that were in use; note where sl_OUT and sl_TIME_EMBED sit:
sl_ALL_nat = [*sl_INPUT, *sl_MID, sl_OUT, *sl_OUTPUT, sl_TIME_EMBED]
sl_ALL = [*sl_INPUT, *sl_MID, *sl_OUTPUT, sl_TIME_EMBED, sl_OUT]

# Should be: sl_ALL = [*sl_INPUT, *sl_MID, *sl_OUTPUT, sl_OUT, sl_TIME_EMBED]
```
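For illustration, a hedged sketch of how the two orderings misalign. The block names and the 27-slot layout (12 IN + 1 MID + 12 OUT + OUT + TIME_EMBED) are reconstructed from this README, not copied from the extension's code:

```py
# Hypothetical reconstruction of the slider orderings described above.
sl_INPUT = [f"IN{i:02d}" for i in range(12)]
sl_MID = ["M00"]
sl_OUTPUT = [f"OUT{i:02d}" for i in range(12)]
sl_OUT = "OUT"
sl_TIME_EMBED = "TIME_EMBED"

sl_ALL_nat = [*sl_INPUT, *sl_MID, sl_OUT, *sl_OUTPUT, sl_TIME_EMBED]
sl_ALL = [*sl_INPUT, *sl_MID, *sl_OUTPUT, sl_TIME_EMBED, sl_OUT]

# The same 27 weights land on different blocks from slot 13 onwards:
for i, (ui_name, rt_name) in enumerate(zip(sl_ALL_nat, sl_ALL)):
    if ui_name != rt_name:
        print(f"slot {i}: UI labels it {ui_name}, runtime applies it to {rt_name}")
```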

- Add a logger (`scripts/mbw_util/mbwrt_logger.py`), and reformat the log output where possible.

- Add `.gitignore`.

- Write the merge info into the `__metadata__` inside the model, [the same way the original Checkpoint Merger does](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/extras.py#L257):

```py
safetensors.torch.save_file(combined_state_dict, save_checkpoint_path, metadata=metadata if len(metadata)>0 else None)
```
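As a hedged sketch of how such a recipe could be embedded (the file names and state dict here are placeholders; safetensors metadata values must be strings, hence the `json.dumps`):

```py
import json

import safetensors.torch
import torch

# Placeholder state dict and recipe; real code would carry the merged tensors
# and the actual merge settings.
combined_state_dict = {"dummy.weight": torch.zeros(1)}
metadata = {
    "sd_mbwrt_recipe": json.dumps({
        "type": "auto-mbw-rt",
        "modelA": "A.safetensors",
        "modelB": "B.safetensors",
    }),
    "format": "pt",
}
safetensors.torch.save_file(
    combined_state_dict,
    "merged.safetensors",
    metadata=metadata if len(metadata) > 0 else None,
)
```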

## Notes on merging SDXL models

- Due to ~~the messy codebase~~ backward compatibility, it is advised to restart the WebUI if your next merge targets a different SD version.

- You will see loads of `Missing key(s) in state_dict:` errors when A1111 has not loaded the settings correctly. Keep switching the *UI's selected model* to a non-SDXL model and try again. If you see a `*.yaml` being loaded, the merge usually succeeds. Sadly this is handled by A1111 itself rather than by extensions.

- This extension supports most layers, but `label_emb` is left untouched. ~~I don't want to add another "slider" and overcomplicate things.~~

- **IN09-IN11 and OUT09-OUT11 are ignored when merging SDXL.** UI / algorithm mappings are untouched. AutoMBW will still include all 27 parameters in its optimization, but you can expect the ignored ones to be *irrelevant* to the merge, and the algorithm (e.g. `BayesianOptimization`) will take care of them. ~~A noisy environment is common in AI/ML, so no worries.~~

```py
# To support SDXL: IN09-IN11 and OUT09-OUT11 do not exist, so we skip them.
# Meanwhile, label_emb is SDXL-exclusive, so we ignore it as well.
# Before: current_block_index++
```
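A minimal sketch of the skip rule described above (the helper and block names are hypothetical, not the extension's actual code):

```py
# Hypothetical illustration of which blocks get merged per SD version.
SDXL_MISSING_BLOCKS = {"IN09", "IN10", "IN11", "OUT09", "OUT10", "OUT11"}

def blocks_to_merge(all_blocks: list[str], is_sdxl: bool) -> list[str]:
    kept = []
    for name in all_blocks:
        if is_sdxl and name in SDXL_MISSING_BLOCKS:
            continue  # these UNet blocks do not exist in SDXL
        if is_sdxl and name == "label_emb":
            continue  # SDXL-exclusive layer, deliberately left untouched
        kept.append(name)
    return kept
```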

## Read metadata inside the model

- Read the metadata with [safetensors_util](https://github.com/by321/safetensors_util), or see how `extras.py` does it.

- On success, you will see a JSON string:

```json
{
  "__metadata__": {
    "sd_mbwrt_recipe": {
      "type": "auto-mbw-rt",
      "modelA": "_03a-mzpikas_tmnd_enhanced-sd-v1-4.safetensors",
      "modelB": "_04a-dreamshaper_8-sd-v1-4.safetensors",
      "modelO": "03b-verify",
      "position_id_fix": "Keep Original",
      "output_mode": "Max Precision",
      "mbwrt_weights": "0.3,0.8,1,0.5,1,0,0.7,0.9,0.5,0.9,0.2,0.2,0.6,0,0.3,0.3,1,1,0.8,0.2,0.7,0,1,0.3,0.2,0.9,0.2",
      "mbwrt_weights_seq": "[*sl_INPUT, *sl_MID, *sl_OUTPUT, sl_OUT, sl_TIME_EMBED]"
    },
    "format": "pt"
  }
}
```
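If you prefer Python over safetensors_util, a minimal sketch for reading the metadata back (assuming a merged file named `merged.safetensors`, and that the recipe is stored as a JSON-encoded string, which is what safetensors metadata requires):

```py
import json

from safetensors import safe_open

with safe_open("merged.safetensors", framework="pt") as f:
    metadata = f.metadata()  # the raw string-to-string header metadata

# The recipe is stored as a JSON-encoded string inside the metadata.
recipe = json.loads(metadata["sd_mbwrt_recipe"])
print(recipe["mbwrt_weights"])
```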

## Some observations

- At the very bottom, the "merging algorithm" is just `torch.lerp`. Efficiency is therefore good, but it cannot support add-difference merging; see the sketch below.
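
A hedged sketch of what a per-tensor lerp merge boils down to (the names are hypothetical):

```py
import torch

def merge_tensor(theta_a: torch.Tensor, theta_b: torch.Tensor, alpha: float) -> torch.Tensor:
    # torch.lerp(a, b, alpha) computes a + alpha * (b - a) == (1 - alpha) * a + alpha * b.
    return torch.lerp(theta_a, theta_b, alpha)

# Add-difference merging (A + alpha * (B - C)) needs a third operand C,
# which a pure two-operand lerp cannot express.
a, b = torch.randn(4, 4), torch.randn(4, 4)
merged = merge_tensor(a, b, 0.3)
```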

## This is part of my research

- Just a hobby. [If tuning the numbers scares you, try "averaging" with simply 0.5, 0.33, 0.25... across 20 models. It works.](https://github.com/6DammK9/nai-anime-pure-negative-prompt/tree/main/ch05).

------

# sd-webui-runtime-block-merge

Runtime version of Merge Block Weighted - GUI; lets you preview the effect before you merge.
2 changes: 2 additions & 0 deletions metadata.ini
@@ -0,0 +1,2 @@
[Extension]
Name=sd-webui-runtime-block-merge
41 changes: 41 additions & 0 deletions scripts/mbw_util/mbwrt_logger.py
@@ -0,0 +1,41 @@
import copy
import logging
import sys

from modules import shared


class ColoredFormatter(logging.Formatter):
    COLORS = {
        "DEBUG": "\033[0;36m",  # CYAN
        "INFO": "\033[0;32m",  # GREEN
        "WARNING": "\033[0;33m",  # YELLOW
        "ERROR": "\033[0;31m",  # RED
        "CRITICAL": "\033[0;37;41m",  # WHITE ON RED
        "RESET": "\033[0m",  # RESET COLOR
    }

    def format(self, record):
        # Copy the record so other handlers still see the uncolored levelname.
        colored_record = copy.copy(record)
        levelname = colored_record.levelname
        seq = self.COLORS.get(levelname, self.COLORS["RESET"])
        colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
        return super().format(colored_record)


# Create a new logger
logger_mbwrt = logging.getLogger("MBWrt")
logger_mbwrt.propagate = False

# Add handler if we don't have one.
if not logger_mbwrt.handlers:
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    logger_mbwrt.addHandler(handler)

# Configure logger. Fall back to DEBUG when the configured name is not a valid
# level, since logger.setLevel(None) would raise a TypeError.
loglevel_string = getattr(shared.cmd_opts, "mbwrt_loglevel", "DEBUG")
loglevel = getattr(logging, loglevel_string.upper(), logging.DEBUG)
logger_mbwrt.setLevel(loglevel)
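
A hedged usage sketch; the exact import path depends on how WebUI loads extension modules:

```py
# Hypothetical usage of the logger defined above.
logger_mbwrt.info("Merging IN00 with weight 0.3")      # green level name
logger_mbwrt.warning("Falling back to Max Precision")  # yellow level name
```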