Skip to content

Commit

Permalink
feat: 支持任意缩放倍率
Browse files Browse the repository at this point in the history
  • Loading branch information
Blinue committed Mar 10, 2024
1 parent 69416af commit a5726c7
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 33 deletions.
25 changes: 13 additions & 12 deletions src/Magpie.Core/CudaInferenceBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "BackendDescriptorStore.h"
#include "Logger.h"
#include "DirectXHelper.h"
#include <onnxruntime/core/session/onnxruntime_session_options_config_keys.h>
#include "Utils.h"

#pragma comment(lib, "cudart.lib")
Expand All @@ -29,6 +28,7 @@ CudaInferenceBackend::~CudaInferenceBackend() {

bool CudaInferenceBackend::Initialize(
const wchar_t* modelPath,
uint32_t scale,
DeviceResources& deviceResources,
BackendDescriptorStore& descriptorStore,
ID3D11Texture2D* input,
Expand Down Expand Up @@ -59,7 +59,6 @@ bool CudaInferenceBackend::Initialize(

Ort::SessionOptions sessionOptions;
sessionOptions.SetIntraOpNumThreads(1);
sessionOptions.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1");

Ort::ThrowOnError(ortApi.AddFreeDimensionOverride(sessionOptions, "DATA_BATCH", 1));

Expand All @@ -83,13 +82,14 @@ bool CudaInferenceBackend::Initialize(
_d3dDC = deviceResources.GetD3DDC();

_inputSize = DirectXHelper::GetTextureSize(input);
_outputSize = SIZE{ _inputSize.cx * (LONG)scale, _inputSize.cy * (LONG)scale };

// 创建输出纹理
winrt::com_ptr<ID3D11Texture2D> outputTex = DirectXHelper::CreateTexture2D(
d3dDevice,
DXGI_FORMAT_R8G8B8A8_UNORM,
_inputSize.cx * 2,
_inputSize.cy * 2,
_outputSize.cx,
_outputSize.cy,
D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_UNORDERED_ACCESS
);
if (!outputTex) {
Expand All @@ -98,13 +98,14 @@ bool CudaInferenceBackend::Initialize(
}
*output = outputTex.get();

const uint32_t elemCount = uint32_t(_inputSize.cx * _inputSize.cy * 3);
const uint32_t inputElemCount = uint32_t(_inputSize.cx * _inputSize.cy * 3);
const uint32_t outputElemCount = uint32_t(_outputSize.cx * _outputSize.cy * 3);

winrt::com_ptr<ID3D11Buffer> inputBuffer;
winrt::com_ptr<ID3D11Buffer> outputBuffer;
{
D3D11_BUFFER_DESC desc{
.ByteWidth = _isFP16Data ? ((elemCount + 1) / 2 * 4) : (elemCount * 4),
.ByteWidth = _isFP16Data ? ((inputElemCount + 1) / 2 * 4) : (inputElemCount * 4),
.BindFlags = D3D11_BIND_UNORDERED_ACCESS
};
HRESULT hr = d3dDevice->CreateBuffer(&desc, nullptr, inputBuffer.put());
Expand All @@ -113,7 +114,7 @@ bool CudaInferenceBackend::Initialize(
return false;
}

desc.ByteWidth = elemCount * 4 * (_isFP16Data ? 2 : 4);
desc.ByteWidth = _isFP16Data ? ((outputElemCount + 1) / 2 * 4) : (outputElemCount * 4);
desc.BindFlags = D3D11_BIND_SHADER_RESOURCE;
hr = d3dDevice->CreateBuffer(&desc, nullptr, outputBuffer.put());
if (FAILED(hr)) {
Expand All @@ -140,7 +141,7 @@ bool CudaInferenceBackend::Initialize(
.Format = _isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT,
.ViewDimension = D3D11_UAV_DIMENSION_BUFFER,
.Buffer{
.NumElements = elemCount
.NumElements = inputElemCount
}
};

Expand All @@ -157,7 +158,7 @@ bool CudaInferenceBackend::Initialize(
.Format = _isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT,
.ViewDimension = D3D11_SRV_DIMENSION_BUFFER,
.Buffer{
.NumElements = elemCount * 4
.NumElements = outputElemCount
}
};

Expand Down Expand Up @@ -202,8 +203,8 @@ bool CudaInferenceBackend::Initialize(
(_inputSize.cy + TEX_TO_TENSOR_BLOCK_SIZE.second - 1) / TEX_TO_TENSOR_BLOCK_SIZE.second
};
_tensorToTexDispatchCount = {
(_inputSize.cx * 2 + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first,
(_inputSize.cy * 2 + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second
(_outputSize.cx + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first,
(_outputSize.cy + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second
};

cudaResult = cudaGraphicsD3D11RegisterResource(
Expand Down Expand Up @@ -275,7 +276,7 @@ void CudaInferenceBackend::Evaluate() noexcept {
std::size(inputShape),
_isFP16Data ? ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 : ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
);
const int64_t outputShape[]{ 1,3,_inputSize.cy * 2,_inputSize.cx * 2 };
const int64_t outputShape[]{ 1,3,_outputSize.cy,_outputSize.cx };
Ort::Value outputValue = Ort::Value::CreateTensor(
_cudaMemInfo,
outputMem,
Expand Down
2 changes: 2 additions & 0 deletions src/Magpie.Core/CudaInferenceBackend.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class CudaInferenceBackend : public InferenceBackendBase {

bool Initialize(
const wchar_t* modelPath,
uint32_t scale,
DeviceResources& deviceResources,
BackendDescriptorStore& descriptorStore,
ID3D11Texture2D* input,
Expand Down Expand Up @@ -56,6 +57,7 @@ class CudaInferenceBackend : public InferenceBackendBase {
Ort::MemoryInfo _cudaMemInfo{ nullptr };

SIZE _inputSize{};
SIZE _outputSize{};

const char* _inputName = nullptr;
const char* _outputName = nullptr;
Expand Down
37 changes: 20 additions & 17 deletions src/Magpie.Core/DirectMLInferenceBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include "shaders/TensorToTextureCS.h"
#include "shaders/TextureToTensorCS.h"
#include "Logger.h"
#include <onnxruntime/core/session/onnxruntime_session_options_config_keys.h>
#include <onnxruntime/core/providers/dml/dml_provider_factory.h>
#include "Win32Utils.h"

Expand Down Expand Up @@ -100,6 +99,7 @@ static winrt::com_ptr<IUnknown> AllocateD3D12Resource(const OrtDmlApi* ortDmlApi

bool DirectMLInferenceBackend::Initialize(
const wchar_t* modelPath,
uint32_t scale,
DeviceResources& deviceResources,
BackendDescriptorStore& /*descriptorStore*/,
ID3D11Texture2D* input,
Expand All @@ -109,13 +109,14 @@ bool DirectMLInferenceBackend::Initialize(
_d3d11DC = deviceResources.GetD3DDC();

const SIZE inputSize = DirectXHelper::GetTextureSize(input);
const SIZE outputSize{ inputSize.cx * (LONG)scale, inputSize.cy * (LONG)scale };

// 创建输出纹理
_outputTex = DirectXHelper::CreateTexture2D(
d3d11Device,
DXGI_FORMAT_R8G8B8A8_UNORM,
inputSize.cx * 2,
inputSize.cy * 2,
outputSize.cx,
outputSize.cy,
D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_UNORDERED_ACCESS,
D3D11_USAGE_DEFAULT,
D3D11_RESOURCE_MISC_SHARED | D3D11_RESOURCE_MISC_SHARED_NTHANDLE
Expand All @@ -126,7 +127,8 @@ bool DirectMLInferenceBackend::Initialize(
}
*output = _outputTex.get();

const uint32_t elemCount = uint32_t(inputSize.cx * inputSize.cy * 3);
const uint32_t inputElemCount = uint32_t(inputSize.cx * inputSize.cy * 3);
const uint32_t outputElemCount = uint32_t(outputSize.cx * outputSize.cy * 3);

winrt::com_ptr<ID3D12Device> d3d12Device = CreateD3D12Device(deviceResources.GetGraphicsAdapter());
if (!d3d12Device) {
Expand Down Expand Up @@ -160,7 +162,6 @@ bool DirectMLInferenceBackend::Initialize(
sessionOptions.SetIntraOpNumThreads(1);
sessionOptions.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
sessionOptions.DisableMemPattern();
sessionOptions.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1");

Ort::ThrowOnError(ortApi.AddFreeDimensionOverride(sessionOptions, "DATA_BATCH", 1));

Expand All @@ -187,7 +188,7 @@ bool DirectMLInferenceBackend::Initialize(
};
D3D12_RESOURCE_DESC resDesc{
.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER,
.Width = elemCount * (isFP16Data ? 2 : 4),
.Width = inputElemCount * (isFP16Data ? 2 : 4),
.Height = 1,
.DepthOrArraySize = 1,
.MipLevels = 1,
Expand All @@ -209,7 +210,7 @@ bool DirectMLInferenceBackend::Initialize(
return false;
}

resDesc.Width *= 4;
resDesc.Width = UINT64(outputElemCount * (isFP16Data ? 2 : 4));
hr = d3d12Device->CreateCommittedResource(
&heapDesc,
D3D12_HEAP_FLAG_CREATE_NOT_ZEROED,
Expand Down Expand Up @@ -241,18 +242,18 @@ bool DirectMLInferenceBackend::Initialize(
_ioBinding.BindInput("input", Ort::Value::CreateTensor(
memoryInfo,
_allocatedInput.get(),
size_t(elemCount * (isFP16Data ? 2 : 4)),
size_t(inputElemCount * (isFP16Data ? 2 : 4)),
inputShape,
std::size(inputShape),
dataType
));

const int64_t outputShape[]{ 1,3,inputSize.cy * 2,inputSize.cx * 2 };
const int64_t outputShape[]{ 1,3,outputSize.cy,outputSize.cx };
_allocatedOutput = AllocateD3D12Resource(ortDmlApi, _outputBuffer.get());
_ioBinding.BindOutput("output", Ort::Value::CreateTensor(
memoryInfo,
_allocatedOutput.get(),
size_t(elemCount * 4 * (isFP16Data ? 2 : 4)),
size_t(outputElemCount * (isFP16Data ? 2 : 4)),
outputShape,
std::size(outputShape),
dataType
Expand All @@ -276,7 +277,7 @@ bool DirectMLInferenceBackend::Initialize(
}

UINT descriptorSize;
if (!_CreateCBVHeap(d3d12Device.get(), elemCount, isFP16Data, descriptorSize)) {
if (!_CreateCBVHeap(d3d12Device.get(), inputElemCount, outputElemCount, isFP16Data, descriptorSize)) {
Logger::Get().Error("_CreateCBVHeap 失败");
return false;
}
Expand All @@ -286,7 +287,7 @@ bool DirectMLInferenceBackend::Initialize(
return false;
}

if (!_CalcCommandLists(d3d12Device.get(), inputSize, descriptorSize)) {
if (!_CalcCommandLists(d3d12Device.get(), inputSize, outputSize, descriptorSize)) {
Logger::Get().Error("_CalcCommandLists 失败");
return false;
}
Expand Down Expand Up @@ -368,7 +369,8 @@ bool DirectMLInferenceBackend::_CreateFence(ID3D11Device5* d3d11Device, ID3D12De

bool DirectMLInferenceBackend::_CreateCBVHeap(
ID3D12Device* d3d12Device,
uint32_t elemCount,
uint32_t inputElemCount,
uint32_t outputElemCount,
bool isFP16Data,
UINT& descriptorSize
) noexcept {
Expand Down Expand Up @@ -398,7 +400,7 @@ bool DirectMLInferenceBackend::_CreateCBVHeap(
.Format = isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT,
.ViewDimension = D3D12_UAV_DIMENSION_BUFFER,
.Buffer{
.NumElements = elemCount
.NumElements = inputElemCount
}
};
d3d12Device->CreateUnorderedAccessView(_inputBuffer.get(), nullptr, &desc, cbvHandle);
Expand All @@ -411,7 +413,7 @@ bool DirectMLInferenceBackend::_CreateCBVHeap(
.ViewDimension = D3D12_SRV_DIMENSION_BUFFER,
.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING,
.Buffer{
.NumElements = elemCount * 4
.NumElements = outputElemCount
}
};
d3d12Device->CreateShaderResourceView(_outputBuffer.get(), &desc, cbvHandle);
Expand Down Expand Up @@ -511,6 +513,7 @@ bool DirectMLInferenceBackend::_CreatePipelineStates(ID3D12Device* d3d12Device)
bool DirectMLInferenceBackend::_CalcCommandLists(
ID3D12Device* d3d12Device,
SIZE inputSize,
SIZE outputSize,
UINT descriptorSize
) noexcept {
winrt::com_ptr<ID3D12CommandAllocator> d3d12CommandAllocator;
Expand Down Expand Up @@ -579,8 +582,8 @@ bool DirectMLInferenceBackend::_CalcCommandLists(

static constexpr std::pair<uint32_t, uint32_t> TENSOR_TO_TEX_BLOCK_SIZE{ 8, 8 };
_tensor2TexCommandList->Dispatch(
(inputSize.cx * 2 + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first,
(inputSize.cy * 2 + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second,
(outputSize.cx + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first,
(outputSize.cy + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second,
1
);
hr = _tensor2TexCommandList->Close();
Expand Down
5 changes: 4 additions & 1 deletion src/Magpie.Core/DirectMLInferenceBackend.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class DirectMLInferenceBackend : public InferenceBackendBase {

bool Initialize(
const wchar_t* modelPath,
uint32_t scale,
DeviceResources& deviceResources,
BackendDescriptorStore& descriptorStore,
ID3D11Texture2D* input,
Expand All @@ -27,7 +28,8 @@ class DirectMLInferenceBackend : public InferenceBackendBase {

bool _CreateCBVHeap(
ID3D12Device* d3d12Device,
uint32_t elemCount,
uint32_t inputElemCount,
uint32_t outputElemCount,
bool isFP16Data,
UINT& descriptorSize
) noexcept;
Expand All @@ -37,6 +39,7 @@ class DirectMLInferenceBackend : public InferenceBackendBase {
bool _CalcCommandLists(
ID3D12Device* d3d12Device,
SIZE inputSize,
SIZE outputSize,
UINT descriptorSize
) noexcept;

Expand Down
1 change: 1 addition & 0 deletions src/Magpie.Core/InferenceBackendBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class InferenceBackendBase {

virtual bool Initialize(
const wchar_t* modelPath,
uint32_t scale,
DeviceResources& deviceResources,
BackendDescriptorStore& descriptorStore,
ID3D11Texture2D* input,
Expand Down
22 changes: 19 additions & 3 deletions src/Magpie.Core/OnnxEffectDrawer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ OnnxEffectDrawer::OnnxEffectDrawer() {}

OnnxEffectDrawer::~OnnxEffectDrawer() {}

static bool ReadJson(const rapidjson::Document& doc, std::string& modelPath, std::string& backend) noexcept {
static bool ReadJson(
const rapidjson::Document& doc,
std::string& modelPath,
uint32_t& scale,
std::string& backend
) noexcept {
if (!doc.IsObject()) {
Logger::Get().Error("根元素不是 Object");
return false;
Expand All @@ -32,6 +37,16 @@ static bool ReadJson(const rapidjson::Document& doc, std::string& modelPath, std
modelPath = node->value.GetString();
}

{
auto node = root.FindMember("scale");
if (node == root.MemberEnd() || !node->value.IsUint()) {
Logger::Get().Error("解析 scale 失败");
return false;
}

scale = node->value.GetUint();
}

{
auto node = root.FindMember("backend");
if (node == root.MemberEnd() || !node->value.IsString()) {
Expand Down Expand Up @@ -62,6 +77,7 @@ bool OnnxEffectDrawer::Initialize(
}

std::string modelPath;
uint32_t scale = 1;
std::string backend;
{
rapidjson::Document doc;
Expand All @@ -71,7 +87,7 @@ bool OnnxEffectDrawer::Initialize(
return false;
}

if (!ReadJson(doc, modelPath, backend)) {
if (!ReadJson(doc, modelPath, scale, backend)) {
Logger::Get().Error("ReadJson 失败");
return false;
}
Expand All @@ -90,7 +106,7 @@ bool OnnxEffectDrawer::Initialize(
}

std::wstring modelPathW = StrUtils::UTF8ToUTF16(modelPath);
if (!_inferenceBackend->Initialize(modelPathW.c_str(), deviceResources, descriptorStore, *inOutTexture, inOutTexture)) {
if (!_inferenceBackend->Initialize(modelPathW.c_str(), scale, deviceResources, descriptorStore, *inOutTexture, inOutTexture)) {
return false;
}

Expand Down

0 comments on commit a5726c7

Please sign in to comment.