Skip to content

Commit

Permalink
[WebNN EP] Decompose Concat with input number > 4 for CPU backend (mi…
Browse files Browse the repository at this point in the history
…crosoft#18930)

WebNN XNNPack backend only supports the concat with inputs number <= 4,
decomposing the Concat with inputs number > 4 into multiple WebNN concat
ops.
  • Loading branch information
Honry authored Dec 29, 2023
1 parent a3626b6 commit 96d1f32
Showing 1 changed file with 31 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,40 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
NodeAttrHelper helper(node);
uint32_t axis = static_cast<uint32_t>(HandleNegativeAxis(helper.Get("axis", 1), rank));

emscripten::val inputs = emscripten::val::array();
for (const auto* input : node.InputDefs()) {
const size_t num_inputs = input_defs.size();
std::vector<emscripten::val> inputs;
for (const auto* input : input_defs) {
LOGS(logger, VERBOSE) << "input name " << input->Name();
inputs.call<void>("push", model_builder.GetOperand(input->Name()));
inputs.push_back(model_builder.GetOperand(input->Name()));
}

emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("concat", inputs, axis);
emscripten::val output = emscripten::val::undefined();
if (num_inputs <= 4 || model_builder.GetPreferredLayout() == DataLayout::NCHW) {
output = model_builder.GetBuilder().call<emscripten::val>("concat", emscripten::val::array(inputs), axis);
} else {
// WebNN XNNPack backend only supports the concat with inputs number <= 4,
// decomposing the Concat with inputs number > 4 into multiple WebNN concat ops.
size_t remaining_inputs = num_inputs;
size_t max_inputs = 4;
while (remaining_inputs > 0) {
std::vector<emscripten::val> chunk_inputs;

// Push the last concated output to the next chunk_inputs.
if (output != emscripten::val::undefined()) {
chunk_inputs.push_back(output);
max_inputs = 3;
}

size_t chunk_size = std::min(remaining_inputs, max_inputs);

for (size_t i = 0; i < chunk_size; i++) {
chunk_inputs.push_back(inputs[num_inputs - remaining_inputs + i]);
}

output = model_builder.GetBuilder().call<emscripten::val>("concat", emscripten::val::array(chunk_inputs), axis);
remaining_inputs -= chunk_size;
}
}

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
Expand Down

0 comments on commit 96d1f32

Please sign in to comment.