From ab7b23187062db86d34fc82db95f7266a68ab5c4 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 25 Sep 2024 19:38:52 +0800 Subject: [PATCH] init --- library/train_util.py | 21 ++++++++++++++++++--- requirements.txt | 2 +- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 5a8da90e1..bdf7774e4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2994,7 +2994,7 @@ def int_or_float(value): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, AdEMAMix8bit, PagedAdEMAMix8bit, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", ) # backward compatibility @@ -4032,7 +4032,7 @@ def task(): def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type if args.use_8bit_adam: @@ -4141,7 +4141,22 @@ def get_optimizer(args, trainable_params): raise AttributeError( "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - + elif optimizer_type == "Ademamix8bit".lower(): + logger.info(f"use 8-bit Ademamix optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.AdEMAMix8bit + except AttributeError: + raise AttributeError( + "No Ademamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / Ademamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) + elif optimizer_type == "PagedAdemamix8bit".lower(): + logger.info(f"use 8-bit PagedAdemamix optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.PagedAdEMAMix8bit + except AttributeError: + raise AttributeError( + "No PagedAdemamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / PagedAdemamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): diff --git a/requirements.txt b/requirements.txt index 15e6e58f1..e6e1bf6fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ ftfy==6.1.1 opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 -bitsandbytes==0.43.0 +bitsandbytes==0.44.0 prodigyopt==1.0 lion-pytorch==0.0.6 tensorboard