Skip to content

Commit

Permalink
Bugfix: Vulkan windows compilation error.
Browse files Browse the repository at this point in the history
  • Loading branch information
wangshuaikang.wsk authored and xiaying committed Nov 19, 2024
1 parent e460135 commit 47a17f5
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
24 changes: 12 additions & 12 deletions source/backend/vulkan/image/execution/VulkanSoftmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
namespace MNN {

struct SoftmaxConstBuffer {
uint N;
uint H;
uint W;
uint C4;
uint CLeft;
uint32_t N;
uint32_t H;
uint32_t W;
uint32_t C4;
uint32_t CLeft;
};

VulkanSoftmax::VulkanSoftmax(const Op* op, Backend* bn, const uint axisIndex) : VulkanBasicExecution(bn) {
VulkanSoftmax::VulkanSoftmax(const Op* op, Backend* bn, const uint32_t axisIndex) : VulkanBasicExecution(bn) {
mAxisIndex = axisIndex;
auto vkBn = (VulkanBackend*)backend();
std::string shaderName = "glsl_softmaxImage_";
Expand Down Expand Up @@ -55,7 +55,7 @@ ErrorCode VulkanSoftmax::onEncode(const std::vector<Tensor*>& inputs, const std:
auto input = inputs[0];
auto output = outputs[0];
auto inputShapeNHWC = VulkanTensor::tensorShapeFormat(input);
std::vector<uint> cpuSoftmaxConstBuffer = {(uint)inputShapeNHWC[0], (uint)inputShapeNHWC[1], (uint)inputShapeNHWC[2], (uint)UP_DIV(inputShapeNHWC[3], 4), (uint)ROUND_UP(inputShapeNHWC[3], 4) - inputShapeNHWC[3]};
std::vector<uint32_t> cpuSoftmaxConstBuffer = {(uint32_t)inputShapeNHWC[0], (uint32_t)inputShapeNHWC[1], (uint32_t)inputShapeNHWC[2], (uint32_t)UP_DIV(inputShapeNHWC[3], 4), (uint32_t)ROUND_UP(inputShapeNHWC[3], 4) - inputShapeNHWC[3]};

{
auto softmaxConst = reinterpret_cast<SoftmaxConstBuffer*>(mSoftmaxConstBuffer->map());
Expand All @@ -69,8 +69,8 @@ ErrorCode VulkanSoftmax::onEncode(const std::vector<Tensor*>& inputs, const std:
}

// N * H * W * C4
uint numTotal = cpuSoftmaxConstBuffer[0] * cpuSoftmaxConstBuffer[1] * cpuSoftmaxConstBuffer[2] * cpuSoftmaxConstBuffer[3];
uint numY = numTotal / cpuSoftmaxConstBuffer[mAxisIndex];
uint32_t numTotal = cpuSoftmaxConstBuffer[0] * cpuSoftmaxConstBuffer[1] * cpuSoftmaxConstBuffer[2] * cpuSoftmaxConstBuffer[3];
uint32_t numY = numTotal / cpuSoftmaxConstBuffer[mAxisIndex];

auto vkOutput = (VulkanTensor*)output->deviceId();
auto vkInput = (VulkanTensor*)input->deviceId();
Expand Down Expand Up @@ -98,7 +98,7 @@ class VulkanSoftmaxCreator : public VulkanBackend::Creator {
Backend* backend) const override {
auto input = inputs[0];

uint dimension = input->dimensions();
uint32_t dimension = input->dimensions();
if (dimension > 4) {
return nullptr;
}
Expand All @@ -109,7 +109,7 @@ class VulkanSoftmaxCreator : public VulkanBackend::Creator {
if (axis < 0) {
axis = input->dimensions() + axis;
}
std::vector<uint> axisMap;
std::vector<uint32_t> axisMap;

if (dimension == 4) {
if (format == MNN_DATA_FORMAT_NCHW) {
Expand All @@ -130,7 +130,7 @@ class VulkanSoftmaxCreator : public VulkanBackend::Creator {
} else {
return nullptr;
}
uint axisIndex = axisMap[axis];
uint32_t axisIndex = axisMap[axis];

return new VulkanSoftmax(op, backend, axisIndex);
}
Expand Down
4 changes: 2 additions & 2 deletions source/backend/vulkan/image/execution/VulkanSoftmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
namespace MNN {
class VulkanSoftmax : public VulkanBasicExecution {
public:
VulkanSoftmax(const Op* op, Backend* bn, const uint axisIndex);
VulkanSoftmax(const Op* op, Backend* bn, const uint32_t axisIndex);
virtual ~VulkanSoftmax();
ErrorCode onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
const VulkanCommandPool::Buffer* cmdBuffer) override;
Expand All @@ -25,7 +25,7 @@ class VulkanSoftmax : public VulkanBasicExecution {
std::shared_ptr<VulkanBuffer> mSoftmaxConstBuffer;
const VulkanPipeline* mSoftmaxPipeline;
std::shared_ptr<VulkanLayout::DescriptorSet> mDescriptorSet;
uint mAxisIndex;
uint32_t mAxisIndex;
};

} // namespace MNN
Expand Down

0 comments on commit 47a17f5

Please sign in to comment.