Skip to content

Commit

Permalink
[CANN] Support cpu offload optimizer for Ascend NPU (microsoft#4568)
Browse files Browse the repository at this point in the history
Support cpu_adam, cpu_adagrad and cpu_lion optimizer for Ascend NPU. All
these optimizer are running on host, the difference between each backend
is the way to copy params back to device. This commit add a new symbol
called "__ENABLE_CANN__". This symbol can compile code adapted to NPU.
The NPU builder adds the required header files and libraries for
compiling, according to CANN's compilation manual.
Note that there's no FusedLion implementation for NPU, test_cpu_lion
test case should disabled until FusedLion optimizer implemented.

Besides, when NPU is selected as the accelerator, ds_report will show
torch_npu and CANN informations.

With this PR, deepspeed test cases in
[huggingface/accelerate](https://github.com/huggingface/accelerate/tree/main/tests/deepspeed)
are all passed.

It's a part of feature list for Ascend NPU support, @see microsoft#4567

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
hipudding and tjruwase authored Nov 14, 2023
1 parent efd4556 commit c1ba6a1
Show file tree
Hide file tree
Showing 13 changed files with 343 additions and 29 deletions.
19 changes: 16 additions & 3 deletions csrc/adagrad/cpu_adagrad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ void Adagrad_Optimizer::Step_1(float* _params,
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
Expand All @@ -62,7 +64,7 @@ void Adagrad_Optimizer::Step_1(float* _params,
grad += _eps;
grad = momentum / grad;
param = grad * step_size + param;
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
#endif
if (half_precision)
Expand All @@ -79,6 +81,17 @@ void Adagrad_Optimizer::Step_1(float* _params,
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
}
Expand Down Expand Up @@ -180,7 +193,7 @@ int ds_adagrad_step(int optimizer_id,
opt->update_state(lr, epsilon, weight_decay);
opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel());

#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
opt->SynchronizeStreams();
#endif
return 0;
Expand All @@ -196,7 +209,7 @@ int ds_adagrad_step_plus_copy(int optimizer_id,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params)
{
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
Expand Down
27 changes: 20 additions & 7 deletions csrc/adam/cpu_adam_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ void Adam_Optimizer::Step_1(float* _params,
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
Expand All @@ -81,7 +83,7 @@ void Adam_Optimizer::Step_1(float* _params,
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
param = grad * step_size + param;
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
#endif
if (half_precision)
Expand All @@ -96,6 +98,17 @@ void Adam_Optimizer::Step_1(float* _params,
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);

_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
Expand Down Expand Up @@ -239,7 +252,7 @@ int ds_adam_step(int optimizer_id,
nullptr,
(params.options().dtype() == at::kHalf));

#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
opt->SynchronizeStreams();
#endif
return 0;
Expand All @@ -257,18 +270,18 @@ int ds_adam_step_plus_copy(int optimizer_id,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params)
torch::Tensor& device_params)
{
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto device_params_c = device_params.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
auto grads_c = grads.contiguous();

float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
ds_half_precision_t* device_params_ptr = (ds_half_precision_t*)device_params_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

Expand All @@ -281,7 +294,7 @@ int ds_adam_step_plus_copy(int optimizer_id,
exp_avg_ptr,
exp_avg_sq_ptr,
params_c.numel(),
gpu_params_ptr,
device_params_ptr,
(params.options().dtype() == at::kHalf));

opt->SynchronizeStreams();
Expand Down
37 changes: 36 additions & 1 deletion csrc/includes/cpu_adagrad.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#elif defined(__ENABLE_CANN__)
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
typedef c10::Half ds_half_precision_t;
#else
typedef unsigned short ds_half_precision_t;
#endif
Expand All @@ -41,6 +45,11 @@ class Adagrad_Optimizer {

_streams[0] = TrainingContext::Instance().GetCurrentStream();
_streams[1] = TrainingContext::Instance().GetNewStream();
_buf_index = false;
#elif defined(__ENABLE_CANN__)
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

_buf_index = false;
#endif
}
Expand All @@ -49,6 +58,9 @@ class Adagrad_Optimizer {
#if defined(__ENABLE_CUDA__)
cudaFreeHost(_doubled_buffer[0]);
cudaFreeHost(_doubled_buffer[1]);
#elif defined(__ENABLE_CANN__)
aclrtFreeHost(_doubled_buffer[0]);
aclrtFreeHost(_doubled_buffer[1]);
#endif
}
#if defined(__AVX512__) or defined(__AVX256__)
Expand All @@ -69,6 +81,11 @@ class Adagrad_Optimizer {
{
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
}
#elif defined(__ENABLE_CANN__)
inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
}
#endif
inline void IncrementStep(size_t step)
{
Expand All @@ -95,6 +112,11 @@ class Adagrad_Optimizer {
bool _buf_index;
float* _doubled_buffer[2];
cudaStream_t _streams[2];
#elif defined(__ENABLE_CANN__)
float* _doubled_buffer[2];
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
c10_npu::getNPUStreamFromPool()};
bool _buf_index;
#endif
};

Expand Down Expand Up @@ -125,6 +147,8 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
Expand All @@ -149,7 +173,7 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
simd_fma<span>(param_4, grad_4, step_size_4, param_4);

simd_store<span>(_params + i, param_4, half_precision);
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) {
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
}
Expand All @@ -167,6 +191,17 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,

_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
if (half_precision) memoryCopySize /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
#endif
}
*rounded_size = new_rounded_size;
Expand Down
38 changes: 37 additions & 1 deletion csrc/includes/cpu_adam.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#elif defined(__ENABLE_CANN__)
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
typedef c10::Half ds_half_precision_t;
#else
#include <cmath>
typedef unsigned short ds_half_precision_t;
Expand Down Expand Up @@ -57,6 +61,11 @@ class Adam_Optimizer {

_streams[0] = TrainingContext::Instance().GetCurrentStream();
_streams[1] = TrainingContext::Instance().GetNewStream();
_buf_index = false;
#elif defined(__ENABLE_CANN__)
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

_buf_index = false;
#endif
}
Expand All @@ -65,6 +74,9 @@ class Adam_Optimizer {
#if defined(__ENABLE_CUDA__)
cudaFreeHost(_doubled_buffer[0]);
cudaFreeHost(_doubled_buffer[1]);
#elif defined(__ENABLE_CANN__)
aclrtFreeHost(_doubled_buffer[0]);
aclrtFreeHost(_doubled_buffer[1]);
#endif
}

Expand All @@ -87,6 +99,11 @@ class Adam_Optimizer {
{
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
}
#elif defined(__ENABLE_CANN__)
inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
}
#endif
inline void IncrementStep(size_t step, float beta1, float beta2)
{
Expand Down Expand Up @@ -142,6 +159,11 @@ class Adam_Optimizer {
float* _doubled_buffer[2];
cudaStream_t _streams[2];
bool _buf_index;
#elif defined(__ENABLE_CANN__)
float* _doubled_buffer[2];
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
c10_npu::getNPUStreamFromPool()};
bool _buf_index;
#endif
};

Expand Down Expand Up @@ -192,6 +214,9 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream((_streams[_buf_index].stream());
}
#endif
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
Expand Down Expand Up @@ -227,7 +252,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
simd_fma<span>(param_4, grad_4, step_size_4, param_4);

simd_store<span>(_params + (i >> rshft), param_4, half_precision);
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) {
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
}
Expand All @@ -246,6 +271,17 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,

_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
if (half_precision) memoryCopySize /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
#endif
}
*rounded_size = new_rounded_size;
Expand Down
Loading

0 comments on commit c1ba6a1

Please sign in to comment.