Skip to content

Commit

Permalink
[pt-vulkan] Address CLANGTIDY warnings in api, graph, and impl
Browse files Browse the repository at this point in the history
…folders (pytorch#116431)

## Context

**Currently, `*.h` and `*.cpp` produces many lint warnings/errors from `clang-tidy` in the Meta internal Phabricator mirror**. These changes address all the lint warnings in the `api`, `graph`, and `impl` folders in preparation for upcoming planned work.

## Review Guide

* Most changes are the result of automatically applied patches from `clang-tidy`
  * However, some warnings had to be manually addressed
  * There should be no functional changes
* Many of the `clang-tidy` warnings arose from the `facebook-hte-BadMemberName` rule which checks for compliance with variable naming rules from Meta's internal C++ style guide
  * However, the rest of the ATen codebase does not conform to this rule, and PyTorch Vulkan was written to be consisten with ATen's naming conventions; thus, to stay consistent with the rest of ATen, this rule is disabled wherever relevant using `// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName`
* Lint was disabled entirely for`vulkan_api_test.cpp` since there are too many warnings to address at the moment. Addressing all of them will be a small project of its own; thus, in the interim lint will be disabled to reduce distracting signals for developers.

Internal:

## Notes for Internal Reviewers

This diff was largely created with

```
cd ~/fbsource/xplat/caffe2/aten/src/ATen/native/vulkan
arc lint -e extra -a --take CLANGTIDY * 2>&1 | tee ~/scratch/lint.txt
```

The above command automatically applied patches suggested by `clang-tidy`, and the rest of the warnings were addressed manually.

To disable `facebook-hte-BadMemberName`, I found that disabling it via a `.clang-tidy` file didn't work with `arc lint`, and the only way that worked was through the adding a comment

```
// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
```

Differential Revision: [D50336057](https://our.internmc.facebook.com/intern/diff/D50336057/)
Pull Request resolved: pytorch#116431
Approved by: https://github.com/GregoryComer, https://github.com/kirklandsign
  • Loading branch information
SS-JIA authored and pytorchmergebot committed Dec 27, 2023
1 parent bbe3261 commit 8d84b50
Show file tree
Hide file tree
Showing 35 changed files with 246 additions and 208 deletions.
23 changes: 21 additions & 2 deletions aten/src/ATen/native/vulkan/VulkanGuardImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
namespace at {
namespace detail {

namespace {

struct VulkanGuardImpl final : public c10::impl::DeviceGuardImplInterface {
VulkanGuardImpl() {}
VulkanGuardImpl() = default;

// NOLINTNEXTLINE
explicit VulkanGuardImpl(DeviceType t) {
TORCH_INTERNAL_ASSERT(t == DeviceType::Vulkan);
}
Expand All @@ -25,14 +28,17 @@ struct VulkanGuardImpl final : public c10::impl::DeviceGuardImplInterface {
// no-op
}
void uncheckedSetDevice(Device d) const noexcept override {
(void)d;
// no-op
}
Stream getStream(Device d) const noexcept override {
(void)d;
// no-op
return Stream(Stream::DEFAULT, Device(DeviceType::Vulkan, -1));
}
// NB: These do NOT set the current device
Stream exchangeStream(Stream s) const noexcept override {
(void)s;
// no-op
return Stream(Stream::DEFAULT, Device(DeviceType::Vulkan, -1));
}
Expand All @@ -46,18 +52,31 @@ struct VulkanGuardImpl final : public c10::impl::DeviceGuardImplInterface {
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
(void)event;
(void)stream;
(void)device_index;
(void)flag;
TORCH_CHECK(false, "VULKAN backend doesn't support events.");
}
void block(void* event, const Stream& stream) const override {
(void)event;
(void)stream;
TORCH_CHECK(false, "VULKAN backend doesn't support events.")
}
bool queryEvent(void* event) const override {
(void)event;
TORCH_CHECK(false, "VULKAN backend doesn't support events.")
}
void destroyEvent(void* event, const DeviceIndex device_index)
const noexcept override {}
const noexcept override {
(void)event;
(void)device_index;
// no-op
}
};

} // namespace

C10_REGISTER_GUARD_IMPL(Vulkan, VulkanGuardImpl);

} // namespace detail
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ struct VulkanOpaqueTensorImpl : public OpaqueTensorImpl<OpaqueHandle> {
}

bool is_contiguous_custom(c10::MemoryFormat memory_format) const override {
(void)memory_format;
return true;
}

Expand Down
30 changes: 14 additions & 16 deletions aten/src/ATen/native/vulkan/api/Adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,26 @@
#include <bitset>
#include <iomanip>
#include <sstream>
#include <utility>

namespace at {
namespace native {
namespace vulkan {
namespace api {

PhysicalDevice::PhysicalDevice(const VkPhysicalDevice physical_device_handle)
PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
: handle(physical_device_handle),
properties{},
memory_properties{},
queue_families{},
num_compute_queues(0),
has_unified_memory(false),
has_timestamps(false),
timestamp_period(0) {
has_timestamps(properties.limits.timestampComputeAndGraphics),
timestamp_period(properties.limits.timestampPeriod) {
// Extract physical device properties
vkGetPhysicalDeviceProperties(handle, &properties);
vkGetPhysicalDeviceMemoryProperties(handle, &memory_properties);

has_timestamps = properties.limits.timestampComputeAndGraphics;
timestamp_period = properties.limits.timestampPeriod;

// Check if there are any memory types have both the HOST_VISIBLE and the
// DEVICE_LOCAL property flags
const VkMemoryPropertyFlags unified_memory_flags =
Expand Down Expand Up @@ -101,7 +99,7 @@ VkDevice create_logical_device(
for (const uint32_t family_i :
c10::irange(physical_device.queue_families.size())) {
const VkQueueFamilyProperties& queue_properties =
physical_device.queue_families[family_i];
physical_device.queue_families.at(family_i);
// Check if this family has compute capability
if (queue_properties.queueFlags & VK_QUEUE_COMPUTE_BIT) {
const uint32_t queues_to_init =
Expand Down Expand Up @@ -159,7 +157,7 @@ VkDevice create_logical_device(
nullptr, // pEnabledFeatures
};

VkDevice handle;
VkDevice handle = nullptr;
VK_CHECK(vkCreateDevice(
physical_device.handle, &device_create_info, nullptr, &handle));

Expand All @@ -172,7 +170,7 @@ VkDevice create_logical_device(
for (const std::pair<uint32_t, uint32_t>& queue_idx : queues_to_get) {
VkQueue queue_handle = VK_NULL_HANDLE;
VkQueueFlags flags =
physical_device.queue_families[queue_idx.first].queueFlags;
physical_device.queue_families.at(queue_idx.first).queueFlags;
vkGetDeviceQueue(handle, queue_idx.first, queue_idx.second, &queue_handle);
queues.push_back({queue_idx.first, queue_idx.second, flags, queue_handle});
// Initial usage value
Expand Down Expand Up @@ -243,7 +241,7 @@ std::string get_queue_family_properties_str(const VkQueueFlags flags) {
// DeviceHandle
//

DeviceHandle::DeviceHandle(const VkDevice device) : handle_(device) {}
DeviceHandle::DeviceHandle(VkDevice device) : handle_(device) {}

DeviceHandle::DeviceHandle(DeviceHandle&& other) noexcept
: handle_(other.handle_) {
Expand All @@ -262,11 +260,11 @@ DeviceHandle::~DeviceHandle() {
//

Adapter::Adapter(
const VkInstance instance,
const PhysicalDevice& physical_device,
VkInstance instance,
PhysicalDevice physical_device,
const uint32_t num_queues)
: queue_usage_mutex_{},
physical_device_(physical_device),
physical_device_(std::move(physical_device)),
queues_{},
queue_usage_{},
queue_mutexes_{},
Expand Down Expand Up @@ -313,8 +311,8 @@ void Adapter::return_queue(Adapter::Queue& compute_queue) {

void Adapter::submit_cmd(
const Adapter::Queue& device_queue,
const VkCommandBuffer cmd,
const VkFence fence) {
VkCommandBuffer cmd,
VkFence fence) {
const VkSubmitInfo submit_info{
VK_STRUCTURE_TYPE_SUBMIT_INFO, // sType
nullptr, // pNext
Expand All @@ -336,7 +334,7 @@ void Adapter::submit_cmd(
void Adapter::submit_cmds(
const Adapter::Queue& device_queue,
const std::vector<VkCommandBuffer>& cmds,
const VkFence fence) {
VkFence fence) {
const VkSubmitInfo submit_info{
VK_STRUCTURE_TYPE_SUBMIT_INFO, // sType
nullptr, // pNext
Expand Down
16 changes: 9 additions & 7 deletions aten/src/ATen/native/vulkan/api/Adapter.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/Common.h>
Expand Down Expand Up @@ -30,12 +32,12 @@ struct PhysicalDevice final {
bool has_timestamps;
float timestamp_period;

explicit PhysicalDevice(const VkPhysicalDevice);
explicit PhysicalDevice(VkPhysicalDevice);
};

class DeviceHandle final {
public:
explicit DeviceHandle(const VkDevice device);
explicit DeviceHandle(VkDevice device);

DeviceHandle(const DeviceHandle&) = delete;
DeviceHandle& operator=(const DeviceHandle&) = delete;
Expand Down Expand Up @@ -80,8 +82,8 @@ class DeviceHandle final {
class Adapter final {
public:
explicit Adapter(
const VkInstance instance,
const PhysicalDevice& physical_device,
VkInstance instance,
PhysicalDevice physical_device,
const uint32_t num_queues);

Adapter(const Adapter&) = delete;
Expand Down Expand Up @@ -185,13 +187,13 @@ class Adapter final {

void submit_cmd(
const Queue&,
const VkCommandBuffer,
const VkFence fence = VK_NULL_HANDLE);
VkCommandBuffer,
VkFence fence = VK_NULL_HANDLE);

void submit_cmds(
const Adapter::Queue&,
const std::vector<VkCommandBuffer>&,
const VkFence fence = VK_NULL_HANDLE);
VkFence fence = VK_NULL_HANDLE);

// Miscellaneous

Expand Down
24 changes: 11 additions & 13 deletions aten/src/ATen/native/vulkan/api/Command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace api {
//

CommandBuffer::CommandBuffer(
const VkCommandBuffer handle,
VkCommandBuffer handle,
const VkCommandBufferUsageFlags flags)
: handle_(handle),
flags_(flags),
Expand All @@ -24,11 +24,10 @@ CommandBuffer::CommandBuffer(
CommandBuffer::CommandBuffer(CommandBuffer&& other) noexcept
: handle_(other.handle_),
flags_(other.flags_),
state_(other.state_),
state_(CommandBuffer::State::INVALID),
bound_(other.bound_) {
other.handle_ = VK_NULL_HANDLE;
other.bound_.reset();
state_ = CommandBuffer::State::INVALID;
}

CommandBuffer& CommandBuffer::operator=(CommandBuffer&& other) noexcept {
Expand Down Expand Up @@ -75,8 +74,8 @@ void CommandBuffer::end() {
}

void CommandBuffer::bind_pipeline(
const VkPipeline pipeline,
const VkPipelineLayout pipeline_layout,
VkPipeline pipeline,
VkPipelineLayout pipeline_layout,
const utils::uvec3 local_workgroup_size) {
TORCH_CHECK(
state_ == CommandBuffer::State::RECORDING,
Expand All @@ -95,7 +94,7 @@ void CommandBuffer::bind_pipeline(
state_ = CommandBuffer::State::PIPELINE_BOUND;
}

void CommandBuffer::bind_descriptors(const VkDescriptorSet descriptors) {
void CommandBuffer::bind_descriptors(VkDescriptorSet descriptors) {
TORCH_CHECK(
state_ == CommandBuffer::State::PIPELINE_BOUND,
"Vulkan CommandBuffer: called bind_descriptors() on a command buffer whose state "
Expand Down Expand Up @@ -317,9 +316,8 @@ void CommandBuffer::copy_buffer_to_texture(
state_ = CommandBuffer::State::RECORDING;
}

void CommandBuffer::write_timestamp(
const VkQueryPool querypool,
const uint32_t idx) const {
void CommandBuffer::write_timestamp(VkQueryPool querypool, const uint32_t idx)
const {
TORCH_CHECK(
state_ == CommandBuffer::State::RECORDING,
"Vulkan CommandBuffer: called write_timestamp() on a command buffer whose state "
Expand All @@ -330,7 +328,7 @@ void CommandBuffer::write_timestamp(
}

void CommandBuffer::reset_querypool(
const VkQueryPool querypool,
VkQueryPool querypool,
const uint32_t first_idx,
const uint32_t count) const {
TORCH_CHECK(
Expand All @@ -347,7 +345,7 @@ VkCommandBuffer CommandBuffer::get_submit_handle(const bool final_use) {
"Vulkan CommandBuffer: called begin() on a command buffer whose state "
"is not READY.");

const VkCommandBuffer handle = handle_;
VkCommandBuffer handle = handle_;

if (!is_reusable() || final_use) {
invalidate();
Expand All @@ -362,7 +360,7 @@ VkCommandBuffer CommandBuffer::get_submit_handle(const bool final_use) {
//

CommandPool::CommandPool(
const VkDevice device,
VkDevice device,
const uint32_t queue_family_idx,
const CommandPoolConfig& config)
: device_(device),
Expand Down Expand Up @@ -398,7 +396,7 @@ CommandBuffer CommandPool::get_new_cmd(bool reusable) {
// No-ops if there are command buffers available
allocate_new_batch(config_.cmdPoolBatchSize);

const VkCommandBuffer handle = buffers_[in_use_];
VkCommandBuffer handle = buffers_[in_use_];

VkCommandBufferUsageFlags cmd_flags = 0u;
if (!reusable) {
Expand Down
22 changes: 8 additions & 14 deletions aten/src/ATen/native/vulkan/api/Command.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/Common.h>
Expand All @@ -16,9 +18,7 @@ namespace api {

class CommandBuffer final {
public:
explicit CommandBuffer(
const VkCommandBuffer,
const VkCommandBufferUsageFlags);
explicit CommandBuffer(VkCommandBuffer, const VkCommandBufferUsageFlags);

CommandBuffer(const CommandBuffer&) = delete;
CommandBuffer& operator=(const CommandBuffer&) = delete;
Expand Down Expand Up @@ -80,11 +80,8 @@ class CommandBuffer final {
void begin();
void end();

void bind_pipeline(
const VkPipeline,
const VkPipelineLayout,
const utils::uvec3);
void bind_descriptors(const VkDescriptorSet);
void bind_pipeline(VkPipeline, VkPipelineLayout, const utils::uvec3);
void bind_descriptors(VkDescriptorSet);

void insert_barrier(const PipelineBarrier& pipeline_barrier);
void dispatch(const utils::uvec3&);
Expand Down Expand Up @@ -117,8 +114,8 @@ class CommandBuffer final {
const api::utils::uvec3&,
const api::utils::uvec3&);

void write_timestamp(const VkQueryPool, const uint32_t) const;
void reset_querypool(const VkQueryPool, const uint32_t, const uint32_t) const;
void write_timestamp(VkQueryPool, const uint32_t) const;
void reset_querypool(VkQueryPool, const uint32_t, const uint32_t) const;

VkCommandBuffer get_submit_handle(const bool final_use = false);

Expand All @@ -134,10 +131,7 @@ struct CommandPoolConfig final {

class CommandPool final {
public:
explicit CommandPool(
const VkDevice,
const uint32_t,
const CommandPoolConfig&);
explicit CommandPool(VkDevice, const uint32_t, const CommandPoolConfig&);

CommandPool(const CommandPool&) = delete;
CommandPool& operator=(const CommandPool&) = delete;
Expand Down
Loading

0 comments on commit 8d84b50

Please sign in to comment.