
Commit

Merge branch 'schedule_free' of https://github.com/sdbds/sd-scripts into qinglong
sdbds committed Sep 15, 2024
2 parents d59780f + 01f271b commit b514704
Showing 2 changed files with 20 additions and 16 deletions.
35 changes: 19 additions & 16 deletions library/train_util.py
@@ -3334,7 +3334,7 @@ def int_or_float(value):
     )
 
     parser.add_argument(
-        "--schedulefree_wrapper_kwargs",
+        "--schedulefree_wrapper_args",
         type=str,
         default=None,
         nargs="*",
@@ -4460,6 +4460,8 @@ def get_optimizer(args, trainable_params, model=None):
             optimizer_kwargs[key] = value
         # logger.info(f"optkwargs {optimizer}_{kwargs}")
 
+    schedulefree_wrapper_kwargs = {}
+
     lr = args.learning_rate
     optimizer = None

@@ -4800,30 +4802,31 @@ def get_optimizer(args, trainable_params, model=None):
         logger.info(f"use {optimizer_type} | {optimizer_kwargs}")
         if "." not in optimizer_type:
             optimizer_module = torch.optim
+            optimizer_class = getattr(optimizer_module, optimizer_type)
+            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+            if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree"):
+                try:
+                    import schedulefree as sf
+                except ImportError:
+                    raise ImportError("No schedulefree / schedulefree does not appear to be installed")
+
+                if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0:
+                    for arg in args.schedulefree_wrapper_args:
+                        key, value = arg.split("=")
+                        value = ast.literal_eval(value)
+                        schedulefree_wrapper_kwargs[key] = value
+                optimizer = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs)
         else:
             values = optimizer_type.split(".")
             optimizer_module = importlib.import_module(".".join(values[:-1]))
             optimizer_type = values[-1]
 
-        optimizer_class = getattr(optimizer_module, optimizer_type)
-        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+            optimizer_class = getattr(optimizer_module, optimizer_type)
+            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
     optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
     optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
 
-    if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree"):
-        try:
-            import schedulefree as sf
-        except ImportError:
-            raise ImportError("No schedulefree / schedulefree does not appear to be installed")
-        schedulefree_wrapper_kwargs = {}
-        if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0:
-            for arg in args.schedulefree_wrapper_kwargs:
-                key, value = arg.split("=")
-                value = ast.literal_eval(value)
-                schedulefree_wrapper_kwargs[key] = value
-        optimizer = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs)
-
     return optimizer_name, optimizer_args, optimizer


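Net effect of the merged branch: the old wrapper block (which referenced the removed args.schedulefree_wrapper_kwargs and would have broken at runtime) is dropped, and when --optimizer_schedulefree_wrapper is set and the chosen torch.optim optimizer is not already a schedule-free one, the freshly built optimizer is wrapped in ScheduleFreeWrapper; dotted optimizer types loaded via importlib are left unwrapped. A minimal sketch of the resulting usage, assuming the schedulefree package's wrapper API and its train()/eval() mode contract (the model, data, and momentum value are illustrative, not taken from this diff):

    import torch
    import schedulefree as sf

    model = torch.nn.Linear(4, 2)
    base_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    # momentum here stands in for what --schedulefree_wrapper_args momentum=0.9
    # would feed through; an assumed wrapper kwarg, not confirmed by the diff.
    optimizer = sf.ScheduleFreeWrapper(base_optimizer, momentum=0.9)

    optimizer.train()  # schedule-free optimizers track a train/eval mode
    for _ in range(10):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
    optimizer.eval()  # switch modes before validation or saving a checkpoint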
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,6 +9,7 @@ pytorch-lightning==1.9.0
 bitsandbytes==0.43.3
 prodigyopt==1.0
 lion-pytorch==0.0.6
+schedulefree==1.2.7
 torch-optimi==0.2.1
 adam-mini==1.0.1
 tensorboard
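The new pin supplies the schedulefree package that get_optimizer imports lazily. Whether the 1.2.7 release already ships ScheduleFreeWrapper is not established by this diff, so a quick sanity check after installing from this requirements file:

    from importlib.metadata import version

    import schedulefree

    print(version("schedulefree"))                       # expect 1.2.7 per the pin
    print(hasattr(schedulefree, "ScheduleFreeWrapper"))  # the symbol get_optimizer uses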
