Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
liqiangxl committed Nov 3, 2024
1 parent 87f507d commit 4113ace
Showing 1 changed file with 18 additions and 20 deletions.
38 changes: 18 additions & 20 deletions csrc/scheduler/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,13 @@ void reduceProductTo(int64_t& z, int64_t& y, int64_t& x, const int64_t max) {
}
}

std::unique_ptr < ReductionParams >
2dInnerReductionHeuristic(
const int64_t total_reduction_numel,
const int64_t total_iteration_numel,
const int64_t inner_most_dimension_numel,
const int64_t n_tensor_inputs,
const int64_t max_input_dtype_size,
const size_t vectorize_factor) {
std::unique_ptr<ReductionParams> inner2dReductionHeuristic(
const int64_t total_reduction_numel,
const int64_t total_iteration_numel,
const int64_t inner_most_dimension_numel,
const int64_t n_tensor_inputs,
const int64_t max_input_dtype_size,
const size_t vectorize_factor) {
// Set some targets for parallelization

const int64_t n_elems = total_reduction_numel * total_iteration_numel;
Expand Down Expand Up @@ -493,7 +492,7 @@ std::unique_ptr < ReductionParams >
<< (rparams->unroll_factor_inner_reduction > 1) << ", "
<< rparams->cross_grid_inner_reduction << std::endl;
}
return innerReductionHeuristic(
return inner2dReductionHeuristic(
total_reduction_numel,
total_iteration_numel,
total_reduction_numel,
Expand All @@ -506,14 +505,13 @@ std::unique_ptr < ReductionParams >
return rparams;
}

std::unique_ptr < ReductionParams >
3dInnerReductionHeuristic(
const int64_t total_reduction_numel,
const int64_t total_iteration_numel,
const int64_t inner_most_dimension_numel,
const int64_t n_tensor_inputs,
const int64_t max_input_dtype_size,
const size_t vectorize_factor) {
std::unique_ptr<ReductionParams> inner3dReductionHeuristic(
const int64_t total_reduction_numel,
const int64_t total_iteration_numel,
const int64_t inner_most_dimension_numel,
const int64_t n_tensor_inputs,
const int64_t max_input_dtype_size,
const size_t vectorize_factor) {
// Set some targets for parallelization

const int64_t n_elems = total_reduction_numel * total_iteration_numel;
Expand Down Expand Up @@ -936,7 +934,7 @@ std::unique_ptr < ReductionParams >
<< (rparams->unroll_factor_inner_reduction > 1) << ", "
<< rparams->cross_grid_inner_reduction << std::endl;
}
return innerReductionHeuristic(
return inner2dReductionHeuristic(
total_reduction_numel,
total_iteration_numel,
total_reduction_numel,
Expand Down Expand Up @@ -1497,15 +1495,15 @@ std::unique_ptr<ReductionParams> reductionHeuristic(
const size_t vectorize_factor) {
if (fastest_dim_reduction) {
if (total_reduction_numel == inner_most_dimension_numel) {
return 2dInnerReductionHeuristic(
return inner2dReductionHeuristic(
total_reduction_numel,
total_iteration_numel,
inner_most_dimension_numel,
(int64_t)n_tensor_inputs,
(int64_t)max_input_dtype_size,
vectorize_factor);
} else {
return 3dInnerReductionHeuristic(
return inner3dReductionHeuristic(
total_reduction_numel,
total_iteration_numel,
inner_most_dimension_numel,
Expand Down

0 comments on commit 4113ace

Please sign in to comment.