diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index c39927b3cc26b..d3fa00e5fe32b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -36,13 +36,40 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, NodeAttrHelper helper(node); uint32_t axis = static_cast(HandleNegativeAxis(helper.Get("axis", 1), rank)); - emscripten::val inputs = emscripten::val::array(); - for (const auto* input : node.InputDefs()) { + const size_t num_inputs = input_defs.size(); + std::vector inputs; + for (const auto* input : input_defs) { LOGS(logger, VERBOSE) << "input name " << input->Name(); - inputs.call("push", model_builder.GetOperand(input->Name())); + inputs.push_back(model_builder.GetOperand(input->Name())); } - emscripten::val output = model_builder.GetBuilder().call("concat", inputs, axis); + emscripten::val output = emscripten::val::undefined(); + if (num_inputs <= 4 || model_builder.GetPreferredLayout() == DataLayout::NCHW) { + output = model_builder.GetBuilder().call("concat", emscripten::val::array(inputs), axis); + } else { + // WebNN XNNPack backend only supports the concat with inputs number <= 4, + // decomposing the Concat with inputs number > 4 into multiple WebNN concat ops. + size_t remaining_inputs = num_inputs; + size_t max_inputs = 4; + while (remaining_inputs > 0) { + std::vector chunk_inputs; + + // Push the last concated output to the next chunk_inputs. + if (output != emscripten::val::undefined()) { + chunk_inputs.push_back(output); + max_inputs = 3; + } + + size_t chunk_size = std::min(remaining_inputs, max_inputs); + + for (size_t i = 0; i < chunk_size; i++) { + chunk_inputs.push_back(inputs[num_inputs - remaining_inputs + i]); + } + + output = model_builder.GetBuilder().call("concat", emscripten::val::array(chunk_inputs), axis); + remaining_inputs -= chunk_size; + } + } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK();