Skip to content

Commit

Permalink
back to 2**10 for bert loss scaler (tinygrad#6934)
Browse files Browse the repository at this point in the history
getting 2 NaN for this, revert back to 2**10
  • Loading branch information
chenyuxyz authored Oct 7, 2024
1 parent 9250452 commit 102dfe5
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion examples/mlperf/model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def train_bert():
save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts")
init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)

loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**13 if dtypes.default_float == dtypes.float16 else 1.0)
loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**10 if dtypes.default_float == dtypes.float16 else 1.0)
decay = config["DECAY"] = getenv("DECAY", 0.01)
epsilon = config["EPSILON"] = getenv("EPSILON", 1e-6)
poly_power = config["POLY_POWER"] = getenv("POLY_POWER", 1.0)
Expand Down

0 comments on commit 102dfe5

Please sign in to comment.