Skip to content

Commit

Permalink
[BE][MPS] Add MPS to clang format (pytorch#96562)
Browse files Browse the repository at this point in the history
I'm getting tired of asking to add a space after `if` and all that jazz, so let's let the linter do that.
Add a section for the Objective-C language, where the column width is extended to 120 characters and `AlignAfterOpenBracket` is set to `Align`.

All `.mm` changes in this PR are made by running linter as follows:
```
lintrunner --take CLANGFORMAT --all-files --apply-patches
```

Pull Request resolved: pytorch#96562
Approved by: https://github.com/seemethere, https://github.com/janeyx99, https://github.com/ZainRizvi, https://github.com/izaitsevfb, https://github.com/PaliC, https://github.com/albanD
  • Loading branch information
malfet authored and pytorchmergebot committed Mar 10, 2023
1 parent a7689e7 commit 4242e69
Show file tree
Hide file tree
Showing 48 changed files with 8,308 additions and 9,148 deletions.
10 changes: 7 additions & 3 deletions .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
Expand All @@ -85,4 +82,11 @@ SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
---
Language: ObjC
ColumnLimit: 120
AlignAfterOpenBracket: Align
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
...
2 changes: 2 additions & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ init_command = [
code = 'CLANGFORMAT'
include_patterns = [
'aten/src/ATen/*.h',
'aten/src/ATen/mps/**/*.mm',
'aten/src/ATen/native/mps/**/*.mm',
'aten/src/ATen/native/vulkan/**/*.h',
'aten/src/ATen/native/vulkan/**/*.cpp',
'c10/**/*.h',
Expand Down
276 changes: 153 additions & 123 deletions aten/src/ATen/mps/MPSAllocator.mm

Large diffs are not rendered by default.

83 changes: 45 additions & 38 deletions aten/src/ATen/mps/MPSDevice.mm
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

#include <c10/util/CallOnce.h>

#include <ATen/mps/IndexKernels.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/IndexKernels.h>

namespace at {
namespace mps {
Expand All @@ -23,35 +23,39 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
}

MPSDevice* MPSDevice::getInstance() {
c10::call_once(mpsdev_init, [] {
mps_device = std::unique_ptr<MPSDevice>(new MPSDevice());
});
c10::call_once(mpsdev_init, [] { mps_device = std::unique_ptr<MPSDevice>(new MPSDevice()); });
return mps_device.get();
}

id<MTLFunction> MPSDevice::metalIndexingFunction(const std::string& kernel, MTLFunctionConstantValues* constantValues) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
NSError* error = nil;
if (!_mtl_indexing_library) {
MTLCompileOptions *options = [MTLCompileOptions new];
[options setLanguageVersion: getMetalLanguageVersion(_mtl_device)];
[options setFastMathEnabled: YES];
_mtl_indexing_library = [_mtl_device newLibraryWithSource: [NSString stringWithCString: mps::indexing_metal_shaders encoding:NSASCIIStringEncoding]
options: options
error: &error];
MTLCompileOptions* options = [MTLCompileOptions new];
[options setLanguageVersion:getMetalLanguageVersion(_mtl_device)];
[options setFastMathEnabled:YES];
_mtl_indexing_library = [_mtl_device newLibraryWithSource:[NSString stringWithCString:mps::indexing_metal_shaders
encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]);
}

id<MTLFunction> indexFunction = nil;
if (constantValues) {
indexFunction = [[_mtl_indexing_library newFunctionWithName: [NSString stringWithUTF8String: kernel.c_str()]
constantValues: constantValues
error: &error] autorelease];
indexFunction = [[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]
constantValues:constantValues
error:&error] autorelease];
} else {
indexFunction = [[_mtl_indexing_library newFunctionWithName: [NSString stringWithUTF8String: kernel.c_str()]] autorelease];
indexFunction =
[[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease];
}

TORCH_CHECK(indexFunction, "Failed to create specialized function state object: ", kernel, ", error: ", [[error description] UTF8String]);
TORCH_CHECK(indexFunction,
"Failed to create specialized function state object: ",
kernel,
", error: ",
[[error description] UTF8String]);

return indexFunction;
}
Expand All @@ -63,49 +67,52 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
_mtl_indexing_library = nil;
}

MPSDevice::MPSDevice(): _mtl_device(nil), _mtl_indexing_library(nil) {
MPSDevice::MPSDevice() : _mtl_device(nil), _mtl_indexing_library(nil) {
// Check that MacOS 12.3+ version of MPS framework is available
// Create the MPSGraph and check method introduced in 12.3+
// which is used by MPS backend.
id mpsCD = NSClassFromString(@"MPSGraph");

if ([mpsCD instancesRespondToSelector:@selector(LSTMWithSourceTensor:
recurrentWeight:
inputWeight:
bias:
initState:
initCell:
descriptor:
name:)] == NO) {
if ([mpsCD instancesRespondToSelector:@selector
(LSTMWithSourceTensor:recurrentWeight:inputWeight:bias:initState:initCell:descriptor:name:)] == NO) {
return;
}

NSArray* devices = [MTLCopyAllDevices() autorelease];
for (unsigned long i = 0 ; i < [devices count] ; i++) {
id<MTLDevice> device = devices[i];
if(![device isLowPower]) { // exclude Intel GPUs
for (unsigned long i = 0; i < [devices count]; i++) {
id<MTLDevice> device = devices[i];
if (![device isLowPower]) { // exclude Intel GPUs
_mtl_device = [device retain];
break;
}
}
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);

}

bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
id mpsCD = NSClassFromString(@"MPSGraph");
static bool _macos_13_0_plus = [mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == YES;
static bool _macos_13_1_plus = [mpsCD instancesRespondToSelector:@selector(
sampleGridWithSourceTensor:coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode:samplingMode:constantValue:name:)] == YES;
static bool _macos_13_2_plus = [mpsCD instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)] == YES;
static bool _macos_13_0_plus = [mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:
axis:name:)] == YES;
static bool _macos_13_1_plus =
[mpsCD instancesRespondToSelector:@selector
(sampleGridWithSourceTensor:
coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode
:samplingMode:constantValue:name:)] == YES;
static bool _macos_13_2_plus =
[mpsCD instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)] == YES;
static bool _macos_13_3_plus = [_mtl_device respondsToSelector:@selector(maximumConcurrentCompilationTaskCount)];

switch (version) {
case MacOSVersion::MACOS_VER_13_0_PLUS: return _macos_13_0_plus;
case MacOSVersion::MACOS_VER_13_1_PLUS: return _macos_13_1_plus;
case MacOSVersion::MACOS_VER_13_2_PLUS: return _macos_13_2_plus;
case MacOSVersion::MACOS_VER_13_3_PLUS: return _macos_13_3_plus;
default: return false;
case MacOSVersion::MACOS_VER_13_0_PLUS:
return _macos_13_0_plus;
case MacOSVersion::MACOS_VER_13_1_PLUS:
return _macos_13_1_plus;
case MacOSVersion::MACOS_VER_13_2_PLUS:
return _macos_13_2_plus;
case MacOSVersion::MACOS_VER_13_3_PLUS:
return _macos_13_3_plus;
default:
return false;
}
}

Expand Down
44 changes: 25 additions & 19 deletions aten/src/ATen/mps/MPSFallback.mm
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,46 @@

namespace at {

void mps_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack)
{
TORCH_WARN_ONCE("The operator '", op.schema().operator_name(), "' is not currently supported ",
void mps_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
TORCH_WARN_ONCE("The operator '",
op.schema().operator_name(),
"' is not currently supported ",
"on the MPS backend and will fall back to run on the CPU.",
" This may have performance implications.");
native::cpu_fallback(op, stack);
}

void mps_error_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack)
{
TORCH_CHECK_NOT_IMPLEMENTED(false, "The operator '", op.schema().operator_name(), "' is not currently implemented ",
void mps_error_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack){TORCH_CHECK_NOT_IMPLEMENTED(
false,
"The operator '",
op.schema().operator_name(),
"' is not currently implemented ",
"for the MPS device. If you want this op to be added in priority during the prototype ",
"phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. ",
"As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` ",
"to use the CPU as a fallback for this op. WARNING: this will be slower than running natively ",
"on MPS.")
}

"on MPS.")}

// This dispatch should never be called for tensor on MPS but is frequently called
// If one of them are on CPU
Tensor slow_conv2d_forward_mps(
const Tensor &self,
const Tensor &weight,
IntArrayRef kernel_size,
const c10::optional<Tensor> &bias,
IntArrayRef stride,
IntArrayRef padding) {
TORCH_CHECK(self.device() == weight.device(), __func__, ": input(device='", self.device(), "') and weight(device=", weight.device(), "') must be on the same device");
TORCH_INTERNAL_ASSERT(false, __func__, " should not be called for both tensors on MPS device");
Tensor slow_conv2d_forward_mps(const Tensor& self,
const Tensor& weight,
IntArrayRef kernel_size,
const c10::optional<Tensor>& bias,
IntArrayRef stride,
IntArrayRef padding) {
TORCH_CHECK(self.device() == weight.device(),
__func__,
": input(device='",
self.device(),
"') and weight(device=",
weight.device(),
"') must be on the same device");
TORCH_INTERNAL_ASSERT(false, __func__, " should not be called for both tensors on MPS device");
}

TORCH_LIBRARY_IMPL(_, MPS, m) {
static const char *enable_mps_fallback = getenv("PYTORCH_ENABLE_MPS_FALLBACK");
static const char* enable_mps_fallback = getenv("PYTORCH_ENABLE_MPS_FALLBACK");
if (!enable_mps_fallback || std::stoi(enable_mps_fallback) == 0) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&mps_error_fallback>());
} else {
Expand Down
8 changes: 5 additions & 3 deletions aten/src/ATen/mps/MPSGeneratorImpl.mm
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ Generator createMPSGenerator(uint64_t seed_val) {
} // namespace mps

MPSGeneratorImpl::MPSGeneratorImpl(uint64_t seed_in)
: c10::GeneratorImpl{Device(DeviceType::MPS), DispatchKeySet(c10::DispatchKey::MPS)},
data_({.seed = seed_in}), engine_(seed_in, 0, 0) { }
: c10::GeneratorImpl{Device(DeviceType::MPS), DispatchKeySet(c10::DispatchKey::MPS)},
data_({.seed = seed_in}),
engine_(seed_in, 0, 0) {}

void MPSGeneratorImpl::set_current_seed(uint64_t seed) {
data_.seed = seed;
Expand Down Expand Up @@ -60,7 +61,8 @@ Generator createMPSGenerator(uint64_t seed_val) {
static const size_t seed_size = sizeof(uint64_t);
static const size_t total_size = states_size + seed_size;

auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
auto state_tensor = at::detail::empty_cpu(
{(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
auto rng_state = state_tensor.data_ptr<uint8_t>();
auto current_seed = this->current_seed();
memcpy(rng_state, this->data_.state.data(), states_size);
Expand Down
84 changes: 37 additions & 47 deletions aten/src/ATen/mps/MPSGuardImpl.mm
Original file line number Diff line number Diff line change
@@ -1,57 +1,47 @@
// Copyright © 2022 Apple Inc.

#include <ATen/mps/MPSGuardImpl.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSGuardImpl.h>

namespace at {
namespace mps {

void MPSGuardImpl::createEvent(
mpsEvent_t* event,
const EventFlag flag) const {
}

void MPSGuardImpl::destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept {
if (!event) return;
auto mps_event = static_cast<mpsEvent_t>(event);
mps_event->~MPSEvent();

}

void MPSGuardImpl::record(
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const {

TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
"Event device index ",
device_index,
" does not match recording stream's device index ",
stream.device_index(),
".");

auto mps_event = static_cast<mpsEvent_t>(*event);
MPSStream mps_stream{stream};
mps_event->recordEvent(true);
}

void MPSGuardImpl::block(
void* event,
const Stream& stream) const {

auto mps_event = static_cast<mpsEvent_t>(event);
MPSStream mps_stream{stream};

mps_event->waitForEvent(true);
}

bool MPSGuardImpl::queryEvent(void* event) const {
auto mps_event = static_cast<mpsEvent_t>(event);
return mps_event->queryEvent();
}
void MPSGuardImpl::createEvent(mpsEvent_t* event, const EventFlag flag) const {}

void MPSGuardImpl::destroyEvent(void* event, const DeviceIndex device_index) const noexcept {
if (!event)
return;
auto mps_event = static_cast<mpsEvent_t>(event);
mps_event->~MPSEvent();
}

void MPSGuardImpl::record(void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const {
TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
"Event device index ",
device_index,
" does not match recording stream's device index ",
stream.device_index(),
".");

auto mps_event = static_cast<mpsEvent_t>(*event);
MPSStream mps_stream{stream};
mps_event->recordEvent(true);
}

void MPSGuardImpl::block(void* event, const Stream& stream) const {
auto mps_event = static_cast<mpsEvent_t>(event);
MPSStream mps_stream{stream};

mps_event->waitForEvent(true);
}

bool MPSGuardImpl::queryEvent(void* event) const {
auto mps_event = static_cast<mpsEvent_t>(event);
return mps_event->queryEvent();
}

}
}
Loading

0 comments on commit 4242e69

Please sign in to comment.