fix optimizer args to use the same dtype on GPU
lvyufeng committed Nov 23, 2024
1 parent a4b0072 commit 2fcc528
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions mindnlp/core/ops/optim.py
@@ -1,19 +1,38 @@
"""optim op"""
import mindspore
from mindspore import ops
from mindspore.ops._primitive_cache import _get_cache_prim

DEVICE_TARGET = mindspore.get_context('device_target')

_adadelta = ops.ApplyAdadelta()
def raw_adadelta(param, square_avg, acc_delta, lr, rho, eps, grad):
return _adadelta(param, square_avg, acc_delta, lr, rho, eps, grad)

_adam = ops.Adam()
def raw_adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
# var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad
if DEVICE_TARGET == 'GPU' and param.dtype != mindspore.float32:
beta1_power, beta2_power, lr, beta1, beta2, epsilon = mindspore.tensor(beta1_power, dtype=param.dtype), \
mindspore.tensor(beta2_power, dtype=param.dtype), \
mindspore.tensor(lr, dtype=param.dtype), \
mindspore.tensor(beta1, dtype=param.dtype), \
mindspore.tensor(beta2, dtype=param.dtype), \
mindspore.tensor(epsilon, dtype=param.dtype)
return _adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)

_adam_amsgrad = ops.ApplyAdamWithAmsgradV2()
def raw_adam_amsgrad(param, exp_avg, exp_avg_sq, max_exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
# var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad

if DEVICE_TARGET == 'GPU' and param.dtype != mindspore.float32:
beta1_power, beta2_power, lr, beta1, beta2, epsilon = mindspore.tensor(beta1_power, dtype=param.dtype), \
mindspore.tensor(beta2_power, dtype=param.dtype), \
mindspore.tensor(lr, dtype=param.dtype), \
mindspore.tensor(beta1, dtype=param.dtype), \
mindspore.tensor(beta2, dtype=param.dtype), \
mindspore.tensor(epsilon, dtype=param.dtype)

return _adam_amsgrad(param, exp_avg, exp_avg_sq, max_exp_avg_sq,
beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
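
Context for the change (not part of the commit): when parameters are float16 on GPU, the underlying ops.Adam and ops.ApplyAdamWithAmsgradV2 kernels reject scalar arguments whose dtype differs from the parameter dtype, so the patch casts them before dispatch. Below is a minimal, hypothetical sketch of calling the patched raw_adam with float16 state; the shapes, hyperparameter values, and variable names are illustrative only, not taken from the commit.

# Hypothetical usage sketch, assuming a GPU device_target and float16 params.
import numpy as np
import mindspore
from mindnlp.core.ops.optim import raw_adam

dtype = mindspore.float16
param      = mindspore.Parameter(mindspore.Tensor(np.ones((2, 3)), dtype), name='w')
exp_avg    = mindspore.Parameter(mindspore.Tensor(np.zeros((2, 3)), dtype), name='m')
exp_avg_sq = mindspore.Parameter(mindspore.Tensor(np.zeros((2, 3)), dtype), name='v')
grad       = mindspore.Tensor(np.full((2, 3), 0.1), dtype)

# Plain Python floats: with this patch, raw_adam converts them to float16
# tensors before calling ops.Adam when DEVICE_TARGET == 'GPU'.
raw_adam(param, exp_avg, exp_avg_sq,
         0.9, 0.999,        # beta1_power, beta2_power
         1e-3, 0.9, 0.999,  # lr, beta1, beta2
         1e-8, grad)        # epsilon, grad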

