Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge LoCo with Zero++ #6730

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
30 changes: 30 additions & 0 deletions csrc/includes/quantization.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,36 @@ void launch_swizzled_quant(int8_t* q_data,
int devices_per_node,
cudaStream_t stream);

void launch_loco_swizzled_quant(int8_t* quantized_data,
float* quantized_scales,
const __half* uncompressed_data,
__half* error_feedback,
const float err_beta,
int num_bits,
quantize::Type quant_type,
int groups,
int elems_per_group,
int pipelining,
int nodes,
int devices_per_node,
cudaStream_t stream);

void launch_loco_dequant_reduce(int8_t* reduced_data,
float* reduced_scales,
const int8_t* input_data,
const float* input_scales,
int num_gpus,
int num_bits,
quantize::Type quant_type,
int out_groups,
int elems_per_out_group,
int elems_per_in_tensor,
int groups_per_in_tensor,
int elems_per_in_group,
__half2* error_feedback,
const float err_beta,
cudaStream_t stream);

void launch_dequant_reduce(int8_t* reduced_data,
float* reduced_scales,
const int8_t* input_data,
Expand Down
1 change: 1 addition & 0 deletions csrc/includes/quantization_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ constexpr int max_threads = 1024;
Class to hold the quantization parameters for a given tensor.
Holds the implementation of the quantization operation.
*/

template <Type qType, int numBits>
class Params {
public:
Expand Down
106 changes: 106 additions & 0 deletions csrc/quantization/pt_binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,53 @@ at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in,
return output;
}

std::vector<at::Tensor> ds_loco_swizzle_quant(at::Tensor& input_vals,
at::Tensor& error_feedback,
float err_beta,
int groups,
int num_bits,
quantize::Type quant_type,
int pipeline_size,
int nodes,
int devices_per_node)
{
auto scales_options = at::TensorOptions()
.dtype(at::kFloat)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
auto scales = torch::empty({groups, scales_elems}, scales_options);

auto output_options = at::TensorOptions()
.dtype(at::kChar)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);

const int quantization_scalar = 8 / num_bits;
const int compressed_vals = at::numel(input_vals) / quantization_scalar;

auto output = torch::empty({compressed_vals}, output_options);
const int elems_per_group = at::numel(input_vals) / groups;

launch_loco_swizzled_quant(reinterpret_cast<int8_t*>(output.data_ptr()),
reinterpret_cast<float*>(scales.data_ptr()),
reinterpret_cast<const __half*>(input_vals.data_ptr()),
reinterpret_cast<__half*>(error_feedback.data_ptr()),
err_beta,
num_bits,
quant_type,
groups,
elems_per_group,
pipeline_size,
nodes,
devices_per_node,
at::cuda::getCurrentCUDAStream());

return {output, scales};
}

std::vector<at::Tensor> ds_swizzle_quant(at::Tensor& input_vals,
int groups,
int num_bits,
Expand Down Expand Up @@ -265,6 +312,61 @@ std::vector<at::Tensor> quantized_reduction(at::Tensor& input_vals,
return {output, scales};
}

std::vector<at::Tensor> loco_quantized_reduction(at::Tensor& input_vals,
at::Tensor& input_scales,
at::Tensor& error_feedback,
float err_beta,
int in_groups,
int out_groups,
int num_bits,
quantize::Type quant_type,
int devices_per_node)
{
auto scales_options = at::TensorOptions()
.dtype(at::kFloat)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);

const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;

auto scales = torch::empty({out_groups, scales_elems}, scales_options);

auto output_options = at::TensorOptions()
.dtype(at::kChar)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);

std::vector<int64_t> sz(input_vals.sizes().begin(), input_vals.sizes().end());
sz[sz.size() - 1] = sz.back() / devices_per_node;

const int elems_per_in_tensor = at::numel(input_vals) / devices_per_node;

auto output = torch::empty(sz, output_options);

const int elems_per_in_group = elems_per_in_tensor / (in_groups / devices_per_node);
const int elems_per_out_group = elems_per_in_tensor / out_groups;

launch_loco_dequant_reduce((int8_t*)output.data_ptr(),
(float*)scales.data_ptr(),
(const int8_t*)input_vals.data_ptr(),
(const float*)input_scales.data_ptr(),
devices_per_node,
num_bits,
quant_type,
out_groups,
elems_per_out_group,
elems_per_in_tensor,
in_groups / devices_per_node,
elems_per_in_group,
(__half2*)error_feedback.data_ptr(),
err_beta,
at::cuda::getCurrentCUDAStream());

return {output, scales};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
Expand Down Expand Up @@ -295,4 +397,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"Dequantize int8 to half (experimental)");
m.def("swizzle_quant", &ds_swizzle_quant);
m.def("quantized_reduction", &quantized_reduction);
m.def("loco_swizzle_quant", &ds_loco_swizzle_quant, "LoCo Swizzled Quantization Kernel");
m.def("loco_quantized_reduction",
&loco_quantized_reduction,
"LoCo Quantization and Reduction Kernel");
}
Loading