From 2fcc528f8c362742a6e58f092507665f3d93f158 Mon Sep 17 00:00:00 2001
From: lvyufeng
Date: Sat, 23 Nov 2024 23:21:22 +0800
Subject: [PATCH] fix optimizer args to use the same dtype on GPU

---
 mindnlp/core/ops/optim.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/mindnlp/core/ops/optim.py b/mindnlp/core/ops/optim.py
index cfe4c86c5..0f61d8d52 100644
--- a/mindnlp/core/ops/optim.py
+++ b/mindnlp/core/ops/optim.py
@@ -1,7 +1,10 @@
 """optim op"""
+import mindspore
 from mindspore import ops
 from mindspore.ops._primitive_cache import _get_cache_prim
 
+DEVICE_TARGET = mindspore.get_context('device_target')
+
 _adadelta = ops.ApplyAdadelta()
 def raw_adadelta(param, square_avg, acc_delta, lr, rho, eps, grad):
     return _adadelta(param, square_avg, acc_delta, lr, rho, eps, grad)
@@ -9,11 +12,27 @@ def raw_adadelta(param, square_avg, acc_delta, lr, rho, eps, grad):
 _adam = ops.Adam()
 def raw_adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
     # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad
+    if DEVICE_TARGET == 'GPU' and param.dtype != mindspore.float32:
+        beta1_power, beta2_power, lr, beta1, beta2, epsilon = mindspore.tensor(beta1_power, dtype=param.dtype), \
+            mindspore.tensor(beta2_power, dtype=param.dtype), \
+            mindspore.tensor(lr, dtype=param.dtype), \
+            mindspore.tensor(beta1, dtype=param.dtype), \
+            mindspore.tensor(beta2, dtype=param.dtype), \
+            mindspore.tensor(epsilon, dtype=param.dtype)
     return _adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
 
 _adam_amsgrad = ops.ApplyAdamWithAmsgradV2()
 def raw_adam_amsgrad(param, exp_avg, exp_avg_sq, max_exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
     # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad
+
+    if DEVICE_TARGET == 'GPU' and param.dtype != mindspore.float32:
+        beta1_power, beta2_power, lr, beta1, beta2, epsilon = mindspore.tensor(beta1_power, dtype=param.dtype), \
+            mindspore.tensor(beta2_power, dtype=param.dtype), \
+            mindspore.tensor(lr, dtype=param.dtype), \
+            mindspore.tensor(beta1, dtype=param.dtype), \
+            mindspore.tensor(beta2, dtype=param.dtype), \
+            mindspore.tensor(epsilon, dtype=param.dtype)
+
     return _adam_amsgrad(param, exp_avg, exp_avg_sq, max_exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
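
Reviewer note: both hunks duplicate the same six-way cast of the scalar hyperparameters. Below is a minimal sketch of how that cast could be factored into a single helper; the name _cast_hyperparams is hypothetical and not part of this patch, and the sketch assumes only the calls the patch itself already uses (mindspore.tensor, mindspore.get_context, ops.Adam).

    import mindspore
    from mindspore import ops

    DEVICE_TARGET = mindspore.get_context('device_target')

    def _cast_hyperparams(dtype, *scalars):
        # Hypothetical helper: wrap each scalar hyperparameter in a Tensor of
        # the parameter dtype so the fused GPU kernel receives uniformly typed
        # inputs, which is the constraint this patch works around when the
        # parameters are not float32 (e.g. float16 training).
        return tuple(mindspore.tensor(s, dtype=dtype) for s in scalars)

    _adam = ops.Adam()

    def raw_adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power,
                 lr, beta1, beta2, epsilon, grad):
        # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad
        if DEVICE_TARGET == 'GPU' and param.dtype != mindspore.float32:
            beta1_power, beta2_power, lr, beta1, beta2, epsilon = _cast_hyperparams(
                param.dtype, beta1_power, beta2_power, lr, beta1, beta2, epsilon)
        return _adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power,
                     lr, beta1, beta2, epsilon, grad)

As in the patch, the device target is read once at import time, so the extra casts only ever run on GPU with non-float32 parameters; raw_adam_amsgrad could call the same helper before invoking _adam_amsgrad.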