Merge branch 'ademamix8bit' of https://github.com/sdbds/sd-scripts into qinglong

# Conflicts:
#	library/train_util.py
#	requirements.txt
sdbds committed Sep 25, 2024
2 parents 5747c20 + ab7b231 commit f77de6b
Showing 2 changed files with 19 additions and 4 deletions.
21 changes: 18 additions & 3 deletions library/train_util.py
@@ -3297,7 +3297,7 @@ def int_or_float(value):
"--optimizer_type",
type=str,
default="",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, AdEMAMix8bit, PagedAdEMAMix8bit, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
)

# backward compatibility
@@ -4410,7 +4410,7 @@ def task():


def get_optimizer(args, trainable_params, model=None):
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"

optimizer_type = args.optimizer_type
if args.use_8bit_adam:
@@ -4528,7 +4528,22 @@ def get_optimizer(args, trainable_params, model=None):
raise AttributeError(
"No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
)

elif optimizer_type == "Ademamix8bit".lower():
logger.info(f"use 8-bit Ademamix optimizer | {optimizer_kwargs}")
try:
optimizer_class = bnb.optim.AdEMAMix8bit
except AttributeError:
raise AttributeError(
"No Ademamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / Ademamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
)
elif optimizer_type == "PagedAdemamix8bit".lower():
logger.info(f"use 8-bit PagedAdemamix optimizer | {optimizer_kwargs}")
try:
optimizer_class = bnb.optim.PagedAdEMAMix8bit
except AttributeError:
raise AttributeError(
"No PagedAdemamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / PagedAdemamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

elif optimizer_type == "PagedAdamW".lower():
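The two new branches in get_optimizer follow the same guarded-lookup pattern as the existing bitsandbytes optimizers: the class is taken from bnb.optim, and a descriptive AttributeError is raised when the installed bitsandbytes predates it. Below is a minimal standalone sketch of that pattern; the linear module and learning rate are placeholders for illustration and are not part of this commit.

import torch
import bitsandbytes as bnb  # AdEMAMix8bit requires bitsandbytes 0.44.0 or later

model = torch.nn.Linear(128, 128)  # placeholder module, illustration only

try:
    # Resolve the 8-bit AdEMAMix class; older bitsandbytes releases do not define it.
    optimizer_class = bnb.optim.AdEMAMix8bit
except AttributeError:
    raise AttributeError("AdEMAMix8bit not found; install bitsandbytes 0.44.0 or later")

optimizer = optimizer_class(model.parameters(), lr=1e-4)

With the updated dependency installed, the new optimizers should be selectable in the training scripts via --optimizer_type AdEMAMix8bit or --optimizer_type PagedAdEMAMix8bit (matching is presumably case-insensitive, given the .lower() comparisons used for the other optimizer types).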
2 changes: 1 addition & 1 deletion requirements.txt
@@ -6,7 +6,7 @@ ftfy==6.1.1
opencv-python==4.8.1.78
einops==0.7.0
pytorch-lightning==1.9.0
bitsandbytes==0.43.3
bitsandbytes==0.44.0
prodigyopt==1.0
lion-pytorch==0.0.6
schedulefree==1.2.7
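Because the new classes only exist in recent bitsandbytes releases, a quick sanity check after installing the updated requirements is to confirm the version and that both optimizers are exposed. A small verification sketch, not part of the commit:

import bitsandbytes as bnb

# The error messages above ask for 0.44.0 or later, matching the updated pin.
print(bnb.__version__)
print(hasattr(bnb.optim, "AdEMAMix8bit"), hasattr(bnb.optim, "PagedAdEMAMix8bit"))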
