Skip to content

Commit

Permalink
Merge branch 'sd3' of https://github.com/kohya-ss/sd-scripts into sd3
Browse files Browse the repository at this point in the history
# Conflicts:
#	sd3_train.py
  • Loading branch information
sdbds committed Sep 23, 2024
2 parents cd4f7ee + fba7692 commit e44d8f9
Show file tree
Hide file tree
Showing 34 changed files with 2,774 additions and 448 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/typos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ jobs:
- uses: actions/checkout@v4

- name: typos-action
uses: crate-ci/typos@v1.21.0
uses: crate-ci/typos@v1.24.3
283 changes: 260 additions & 23 deletions README.md

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
Expand Down Expand Up @@ -456,7 +459,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
)

current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
Expand All @@ -469,7 +472,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if global_step >= args.max_train_steps:
break

if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

Expand Down
8 changes: 5 additions & 3 deletions finetune/tag_images_by_wd14_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tqdm import tqdm

import library.train_util as train_util
from library.utils import setup_logging
from library.utils import setup_logging, pil_resize

setup_logging()
import logging
Expand Down Expand Up @@ -42,8 +42,10 @@ def preprocess_image(image):
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
if size > IMAGE_SIZE:
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
else:
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))

image = image.astype(np.float32)
return image
Expand Down
Loading

0 comments on commit e44d8f9

Please sign in to comment.