From fac45134bf0be34ba6078c7298382d7495470da9 Mon Sep 17 00:00:00 2001
From: Zhiyu Li
Date: Fri, 8 Mar 2024 16:27:33 -0800
Subject: [PATCH] fix dtype bug in adam_pax

---
 MaxText/optimizers.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/MaxText/optimizers.py b/MaxText/optimizers.py
index 0a8150293..e2a23abda 100644
--- a/MaxText/optimizers.py
+++ b/MaxText/optimizers.py
@@ -117,8 +117,13 @@ def __init__(self, mu, nu):
         self.nu = nu
 
     def _update_momentum(update, mu, nu):
-      beta1_decay = bias_corrected_decay(count, beta1)
-      beta2_decay = bias_corrected_decay(count, beta2)
+      # The conversion to the data type of the update ensures that bfloat16 remains
+      # bfloat16 in the optimizer state. This conversion has to be done after
+      # `bias_corrected_decay` is calculated as calculating `jnp.power(decay, t)` in low
+      # precision can result in it being rounded to 1 and subsequently a
+      # "division by zero" error.
+      beta1_decay = bias_corrected_decay(count, beta1).astype(update.dtype)
+      beta2_decay = bias_corrected_decay(count, beta2).astype(update.dtype)
       mu = (1.0 - beta1_decay) * update + beta1_decay * mu
       nu = (1.0 - beta2_decay) * (update**2) + beta2_decay * nu
       return _slot_opt_state(mu=mu, nu=nu)