From a5726c75068905e7ba4509ba7aed0e159a0405e9 Mon Sep 17 00:00:00 2001 From: Xu <34770031+Blinue@users.noreply.github.com> Date: Sun, 10 Mar 2024 16:22:37 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E4=BB=BB=E6=84=8F?= =?UTF-8?q?=E7=BC=A9=E6=94=BE=E5=80=8D=E7=8E=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/Magpie.Core/CudaInferenceBackend.cpp | 25 ++++++------- src/Magpie.Core/CudaInferenceBackend.h | 2 ++ src/Magpie.Core/DirectMLInferenceBackend.cpp | 37 +++++++++++--------- src/Magpie.Core/DirectMLInferenceBackend.h | 5 ++- src/Magpie.Core/InferenceBackendBase.h | 1 + src/Magpie.Core/OnnxEffectDrawer.cpp | 22 ++++++++++-- 6 files changed, 59 insertions(+), 33 deletions(-) diff --git a/src/Magpie.Core/CudaInferenceBackend.cpp b/src/Magpie.Core/CudaInferenceBackend.cpp index d79280b5f..44efd18ee 100644 --- a/src/Magpie.Core/CudaInferenceBackend.cpp +++ b/src/Magpie.Core/CudaInferenceBackend.cpp @@ -7,7 +7,6 @@ #include "BackendDescriptorStore.h" #include "Logger.h" #include "DirectXHelper.h" -#include #include "Utils.h" #pragma comment(lib, "cudart.lib") @@ -29,6 +28,7 @@ CudaInferenceBackend::~CudaInferenceBackend() { bool CudaInferenceBackend::Initialize( const wchar_t* modelPath, + uint32_t scale, DeviceResources& deviceResources, BackendDescriptorStore& descriptorStore, ID3D11Texture2D* input, @@ -59,7 +59,6 @@ bool CudaInferenceBackend::Initialize( Ort::SessionOptions sessionOptions; sessionOptions.SetIntraOpNumThreads(1); - sessionOptions.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); Ort::ThrowOnError(ortApi.AddFreeDimensionOverride(sessionOptions, "DATA_BATCH", 1)); @@ -83,13 +82,14 @@ bool CudaInferenceBackend::Initialize( _d3dDC = deviceResources.GetD3DDC(); _inputSize = DirectXHelper::GetTextureSize(input); + _outputSize = SIZE{ _inputSize.cx * (LONG)scale, _inputSize.cy * (LONG)scale }; // 创建输出纹理 winrt::com_ptr outputTex = DirectXHelper::CreateTexture2D( d3dDevice, DXGI_FORMAT_R8G8B8A8_UNORM, - _inputSize.cx * 2, - _inputSize.cy * 2, + _outputSize.cx, + _outputSize.cy, D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_UNORDERED_ACCESS ); if (!outputTex) { @@ -98,13 +98,14 @@ bool CudaInferenceBackend::Initialize( } *output = outputTex.get(); - const uint32_t elemCount = uint32_t(_inputSize.cx * _inputSize.cy * 3); + const uint32_t inputElemCount = uint32_t(_inputSize.cx * _inputSize.cy * 3); + const uint32_t outputElemCount = uint32_t(_outputSize.cx * _outputSize.cy * 3); winrt::com_ptr inputBuffer; winrt::com_ptr outputBuffer; { D3D11_BUFFER_DESC desc{ - .ByteWidth = _isFP16Data ? ((elemCount + 1) / 2 * 4) : (elemCount * 4), + .ByteWidth = _isFP16Data ? ((inputElemCount + 1) / 2 * 4) : (inputElemCount * 4), .BindFlags = D3D11_BIND_UNORDERED_ACCESS }; HRESULT hr = d3dDevice->CreateBuffer(&desc, nullptr, inputBuffer.put()); @@ -113,7 +114,7 @@ bool CudaInferenceBackend::Initialize( return false; } - desc.ByteWidth = elemCount * 4 * (_isFP16Data ? 2 : 4); + desc.ByteWidth = _isFP16Data ? ((outputElemCount + 1) / 2 * 4) : (outputElemCount * 4); desc.BindFlags = D3D11_BIND_SHADER_RESOURCE; hr = d3dDevice->CreateBuffer(&desc, nullptr, outputBuffer.put()); if (FAILED(hr)) { @@ -140,7 +141,7 @@ bool CudaInferenceBackend::Initialize( .Format = _isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT, .ViewDimension = D3D11_UAV_DIMENSION_BUFFER, .Buffer{ - .NumElements = elemCount + .NumElements = inputElemCount } }; @@ -157,7 +158,7 @@ bool CudaInferenceBackend::Initialize( .Format = _isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT, .ViewDimension = D3D11_SRV_DIMENSION_BUFFER, .Buffer{ - .NumElements = elemCount * 4 + .NumElements = outputElemCount } }; @@ -202,8 +203,8 @@ bool CudaInferenceBackend::Initialize( (_inputSize.cy + TEX_TO_TENSOR_BLOCK_SIZE.second - 1) / TEX_TO_TENSOR_BLOCK_SIZE.second }; _tensorToTexDispatchCount = { - (_inputSize.cx * 2 + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first, - (_inputSize.cy * 2 + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second + (_outputSize.cx + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first, + (_outputSize.cy + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second }; cudaResult = cudaGraphicsD3D11RegisterResource( @@ -275,7 +276,7 @@ void CudaInferenceBackend::Evaluate() noexcept { std::size(inputShape), _isFP16Data ? ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 : ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT ); - const int64_t outputShape[]{ 1,3,_inputSize.cy * 2,_inputSize.cx * 2 }; + const int64_t outputShape[]{ 1,3,_outputSize.cy,_outputSize.cx }; Ort::Value outputValue = Ort::Value::CreateTensor( _cudaMemInfo, outputMem, diff --git a/src/Magpie.Core/CudaInferenceBackend.h b/src/Magpie.Core/CudaInferenceBackend.h index f50d37f1e..7ee8873fa 100644 --- a/src/Magpie.Core/CudaInferenceBackend.h +++ b/src/Magpie.Core/CudaInferenceBackend.h @@ -15,6 +15,7 @@ class CudaInferenceBackend : public InferenceBackendBase { bool Initialize( const wchar_t* modelPath, + uint32_t scale, DeviceResources& deviceResources, BackendDescriptorStore& descriptorStore, ID3D11Texture2D* input, @@ -56,6 +57,7 @@ class CudaInferenceBackend : public InferenceBackendBase { Ort::MemoryInfo _cudaMemInfo{ nullptr }; SIZE _inputSize{}; + SIZE _outputSize{}; const char* _inputName = nullptr; const char* _outputName = nullptr; diff --git a/src/Magpie.Core/DirectMLInferenceBackend.cpp b/src/Magpie.Core/DirectMLInferenceBackend.cpp index 7023ad795..45a6599b6 100644 --- a/src/Magpie.Core/DirectMLInferenceBackend.cpp +++ b/src/Magpie.Core/DirectMLInferenceBackend.cpp @@ -5,7 +5,6 @@ #include "shaders/TensorToTextureCS.h" #include "shaders/TextureToTensorCS.h" #include "Logger.h" -#include #include #include "Win32Utils.h" @@ -100,6 +99,7 @@ static winrt::com_ptr AllocateD3D12Resource(const OrtDmlApi* ortDmlApi bool DirectMLInferenceBackend::Initialize( const wchar_t* modelPath, + uint32_t scale, DeviceResources& deviceResources, BackendDescriptorStore& /*descriptorStore*/, ID3D11Texture2D* input, @@ -109,13 +109,14 @@ bool DirectMLInferenceBackend::Initialize( _d3d11DC = deviceResources.GetD3DDC(); const SIZE inputSize = DirectXHelper::GetTextureSize(input); + const SIZE outputSize{ inputSize.cx * (LONG)scale, inputSize.cy * (LONG)scale }; // 创建输出纹理 _outputTex = DirectXHelper::CreateTexture2D( d3d11Device, DXGI_FORMAT_R8G8B8A8_UNORM, - inputSize.cx * 2, - inputSize.cy * 2, + outputSize.cx, + outputSize.cy, D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_UNORDERED_ACCESS, D3D11_USAGE_DEFAULT, D3D11_RESOURCE_MISC_SHARED | D3D11_RESOURCE_MISC_SHARED_NTHANDLE @@ -126,7 +127,8 @@ bool DirectMLInferenceBackend::Initialize( } *output = _outputTex.get(); - const uint32_t elemCount = uint32_t(inputSize.cx * inputSize.cy * 3); + const uint32_t inputElemCount = uint32_t(inputSize.cx * inputSize.cy * 3); + const uint32_t outputElemCount = uint32_t(outputSize.cx * outputSize.cy * 3); winrt::com_ptr d3d12Device = CreateD3D12Device(deviceResources.GetGraphicsAdapter()); if (!d3d12Device) { @@ -160,7 +162,6 @@ bool DirectMLInferenceBackend::Initialize( sessionOptions.SetIntraOpNumThreads(1); sessionOptions.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL); sessionOptions.DisableMemPattern(); - sessionOptions.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); Ort::ThrowOnError(ortApi.AddFreeDimensionOverride(sessionOptions, "DATA_BATCH", 1)); @@ -187,7 +188,7 @@ bool DirectMLInferenceBackend::Initialize( }; D3D12_RESOURCE_DESC resDesc{ .Dimension = D3D12_RESOURCE_DIMENSION_BUFFER, - .Width = elemCount * (isFP16Data ? 2 : 4), + .Width = inputElemCount * (isFP16Data ? 2 : 4), .Height = 1, .DepthOrArraySize = 1, .MipLevels = 1, @@ -209,7 +210,7 @@ bool DirectMLInferenceBackend::Initialize( return false; } - resDesc.Width *= 4; + resDesc.Width = UINT64(outputElemCount * (isFP16Data ? 2 : 4)); hr = d3d12Device->CreateCommittedResource( &heapDesc, D3D12_HEAP_FLAG_CREATE_NOT_ZEROED, @@ -241,18 +242,18 @@ bool DirectMLInferenceBackend::Initialize( _ioBinding.BindInput("input", Ort::Value::CreateTensor( memoryInfo, _allocatedInput.get(), - size_t(elemCount * (isFP16Data ? 2 : 4)), + size_t(inputElemCount * (isFP16Data ? 2 : 4)), inputShape, std::size(inputShape), dataType )); - const int64_t outputShape[]{ 1,3,inputSize.cy * 2,inputSize.cx * 2 }; + const int64_t outputShape[]{ 1,3,outputSize.cy,outputSize.cx }; _allocatedOutput = AllocateD3D12Resource(ortDmlApi, _outputBuffer.get()); _ioBinding.BindOutput("output", Ort::Value::CreateTensor( memoryInfo, _allocatedOutput.get(), - size_t(elemCount * 4 * (isFP16Data ? 2 : 4)), + size_t(outputElemCount * (isFP16Data ? 2 : 4)), outputShape, std::size(outputShape), dataType @@ -276,7 +277,7 @@ bool DirectMLInferenceBackend::Initialize( } UINT descriptorSize; - if (!_CreateCBVHeap(d3d12Device.get(), elemCount, isFP16Data, descriptorSize)) { + if (!_CreateCBVHeap(d3d12Device.get(), inputElemCount, outputElemCount, isFP16Data, descriptorSize)) { Logger::Get().Error("_CreateCBVHeap 失败"); return false; } @@ -286,7 +287,7 @@ bool DirectMLInferenceBackend::Initialize( return false; } - if (!_CalcCommandLists(d3d12Device.get(), inputSize, descriptorSize)) { + if (!_CalcCommandLists(d3d12Device.get(), inputSize, outputSize, descriptorSize)) { Logger::Get().Error("_CalcCommandLists 失败"); return false; } @@ -368,7 +369,8 @@ bool DirectMLInferenceBackend::_CreateFence(ID3D11Device5* d3d11Device, ID3D12De bool DirectMLInferenceBackend::_CreateCBVHeap( ID3D12Device* d3d12Device, - uint32_t elemCount, + uint32_t inputElemCount, + uint32_t outputElemCount, bool isFP16Data, UINT& descriptorSize ) noexcept { @@ -398,7 +400,7 @@ bool DirectMLInferenceBackend::_CreateCBVHeap( .Format = isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT, .ViewDimension = D3D12_UAV_DIMENSION_BUFFER, .Buffer{ - .NumElements = elemCount + .NumElements = inputElemCount } }; d3d12Device->CreateUnorderedAccessView(_inputBuffer.get(), nullptr, &desc, cbvHandle); @@ -411,7 +413,7 @@ bool DirectMLInferenceBackend::_CreateCBVHeap( .ViewDimension = D3D12_SRV_DIMENSION_BUFFER, .Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING, .Buffer{ - .NumElements = elemCount * 4 + .NumElements = outputElemCount } }; d3d12Device->CreateShaderResourceView(_outputBuffer.get(), &desc, cbvHandle); @@ -511,6 +513,7 @@ bool DirectMLInferenceBackend::_CreatePipelineStates(ID3D12Device* d3d12Device) bool DirectMLInferenceBackend::_CalcCommandLists( ID3D12Device* d3d12Device, SIZE inputSize, + SIZE outputSize, UINT descriptorSize ) noexcept { winrt::com_ptr d3d12CommandAllocator; @@ -579,8 +582,8 @@ bool DirectMLInferenceBackend::_CalcCommandLists( static constexpr std::pair TENSOR_TO_TEX_BLOCK_SIZE{ 8, 8 }; _tensor2TexCommandList->Dispatch( - (inputSize.cx * 2 + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first, - (inputSize.cy * 2 + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second, + (outputSize.cx + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first, + (outputSize.cy + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second, 1 ); hr = _tensor2TexCommandList->Close(); diff --git a/src/Magpie.Core/DirectMLInferenceBackend.h b/src/Magpie.Core/DirectMLInferenceBackend.h index 39b18e065..cb557535b 100644 --- a/src/Magpie.Core/DirectMLInferenceBackend.h +++ b/src/Magpie.Core/DirectMLInferenceBackend.h @@ -14,6 +14,7 @@ class DirectMLInferenceBackend : public InferenceBackendBase { bool Initialize( const wchar_t* modelPath, + uint32_t scale, DeviceResources& deviceResources, BackendDescriptorStore& descriptorStore, ID3D11Texture2D* input, @@ -27,7 +28,8 @@ class DirectMLInferenceBackend : public InferenceBackendBase { bool _CreateCBVHeap( ID3D12Device* d3d12Device, - uint32_t elemCount, + uint32_t inputElemCount, + uint32_t outputElemCount, bool isFP16Data, UINT& descriptorSize ) noexcept; @@ -37,6 +39,7 @@ class DirectMLInferenceBackend : public InferenceBackendBase { bool _CalcCommandLists( ID3D12Device* d3d12Device, SIZE inputSize, + SIZE outputSize, UINT descriptorSize ) noexcept; diff --git a/src/Magpie.Core/InferenceBackendBase.h b/src/Magpie.Core/InferenceBackendBase.h index 077be5e23..a17c0e72a 100644 --- a/src/Magpie.Core/InferenceBackendBase.h +++ b/src/Magpie.Core/InferenceBackendBase.h @@ -16,6 +16,7 @@ class InferenceBackendBase { virtual bool Initialize( const wchar_t* modelPath, + uint32_t scale, DeviceResources& deviceResources, BackendDescriptorStore& descriptorStore, ID3D11Texture2D* input, diff --git a/src/Magpie.Core/OnnxEffectDrawer.cpp b/src/Magpie.Core/OnnxEffectDrawer.cpp index 299c2beac..c470730de 100644 --- a/src/Magpie.Core/OnnxEffectDrawer.cpp +++ b/src/Magpie.Core/OnnxEffectDrawer.cpp @@ -14,7 +14,12 @@ OnnxEffectDrawer::OnnxEffectDrawer() {} OnnxEffectDrawer::~OnnxEffectDrawer() {} -static bool ReadJson(const rapidjson::Document& doc, std::string& modelPath, std::string& backend) noexcept { +static bool ReadJson( + const rapidjson::Document& doc, + std::string& modelPath, + uint32_t& scale, + std::string& backend +) noexcept { if (!doc.IsObject()) { Logger::Get().Error("根元素不是 Object"); return false; @@ -32,6 +37,16 @@ static bool ReadJson(const rapidjson::Document& doc, std::string& modelPath, std modelPath = node->value.GetString(); } + { + auto node = root.FindMember("scale"); + if (node == root.MemberEnd() || !node->value.IsUint()) { + Logger::Get().Error("解析 scale 失败"); + return false; + } + + scale = node->value.GetUint(); + } + { auto node = root.FindMember("backend"); if (node == root.MemberEnd() || !node->value.IsString()) { @@ -62,6 +77,7 @@ bool OnnxEffectDrawer::Initialize( } std::string modelPath; + uint32_t scale = 1; std::string backend; { rapidjson::Document doc; @@ -71,7 +87,7 @@ bool OnnxEffectDrawer::Initialize( return false; } - if (!ReadJson(doc, modelPath, backend)) { + if (!ReadJson(doc, modelPath, scale, backend)) { Logger::Get().Error("ReadJson 失败"); return false; } @@ -90,7 +106,7 @@ bool OnnxEffectDrawer::Initialize( } std::wstring modelPathW = StrUtils::UTF8ToUTF16(modelPath); - if (!_inferenceBackend->Initialize(modelPathW.c_str(), deviceResources, descriptorStore, *inOutTexture, inOutTexture)) { + if (!_inferenceBackend->Initialize(modelPathW.c_str(), scale, deviceResources, descriptorStore, *inOutTexture, inOutTexture)) { return false; }