Commit

update for schedulefree
sdbds committed Sep 28, 2024
1 parent 9829362 commit 7578a9e
Showing 1 changed file with 17 additions and 0 deletions.
library/train_util.py (17 additions, 0 deletions)
@@ -4842,6 +4842,23 @@ def get_optimizer(args, trainable_params, model=None):
raise ImportError("No sara / sara がインストールされていないようです")
optimizer = optimizer_class(model, **optimizer_kwargs)

    elif optimizer_type.endswith("schedulefree".lower()):
        try:
            import schedulefree as sf
        except ImportError:
            raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
        if optimizer_type == "AdamWScheduleFree".lower():
            optimizer_class = sf.AdamWScheduleFree
            logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
        elif optimizer_type == "SGDScheduleFree".lower():
            optimizer_class = sf.SGDScheduleFree
            logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
        else:
            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
        # put the optimizer into train mode here; train() does not need to be called again because eval() is never called in the training loop
        optimizer.train()

    if optimizer is None:
        # 任意のoptimizerを使う
        case_sensitive_optimizer_type = args.optimizer_type  # not lower
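The optimizer.train() call added at the end of the new branch matters because ScheduleFree optimizers have distinct train and eval behaviours that must be switched explicitly; the commit puts the optimizer into train mode once, on the assumption that eval() is never called during the training loop. A minimal, hypothetical sketch of the usual schedulefree train/eval pattern (the toy model and data are placeholders, not part of this commit):

import torch
import schedulefree

model = torch.nn.Linear(10, 2)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

optimizer.train()  # optimizer must be in train mode while steps are taken
for _ in range(100):
    x = torch.randn(8, 10)         # placeholder batch
    y = torch.randint(0, 2, (8,))  # placeholder labels
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()

optimizer.eval()  # switch to eval mode before validation or checkpointing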
