diff --git a/.clang-format b/.clang-format index 73304266bd671d..f789a97304fc6d 100644 --- a/.clang-format +++ b/.clang-format @@ -60,9 +60,6 @@ MacroBlockBegin: '' MacroBlockEnd: '' MaxEmptyLinesToKeep: 1 NamespaceIndentation: None -ObjCBlockIndentWidth: 2 -ObjCSpaceAfterProperty: false -ObjCSpaceBeforeProtocolList: false PenaltyBreakBeforeFirstCallParameter: 1 PenaltyBreakComment: 300 PenaltyBreakFirstLessLess: 120 @@ -85,4 +82,11 @@ SpacesInSquareBrackets: false Standard: Cpp11 TabWidth: 8 UseTab: Never +--- +Language: ObjC +ColumnLimit: 120 +AlignAfterOpenBracket: Align +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: false ... diff --git a/.lintrunner.toml b/.lintrunner.toml index 2642a69e3f2442..574cf4d244000b 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -49,6 +49,8 @@ init_command = [ code = 'CLANGFORMAT' include_patterns = [ 'aten/src/ATen/*.h', + 'aten/src/ATen/mps/**/*.mm', + 'aten/src/ATen/native/mps/**/*.mm', 'aten/src/ATen/native/vulkan/**/*.h', 'aten/src/ATen/native/vulkan/**/*.cpp', 'c10/**/*.h', diff --git a/aten/src/ATen/mps/MPSAllocator.mm b/aten/src/ATen/mps/MPSAllocator.mm index 274c193c5c6f3a..eeb62b1ba29d8a 100644 --- a/aten/src/ATen/mps/MPSAllocator.mm +++ b/aten/src/ATen/mps/MPSAllocator.mm @@ -1,10 +1,10 @@ // Copyright © 2022 Apple Inc. +#include +#include #include #include #include -#include -#include #include namespace at { @@ -19,25 +19,26 @@ void MPSHeapAllocatorImpl::init_allocator() { // debug verbosity flags (see DebugVerbosity enum) - static const char *verbosity_str = getenv("PYTORCH_DEBUG_MPS_ALLOCATOR"); + static const char* verbosity_str = getenv("PYTORCH_DEBUG_MPS_ALLOCATOR"); m_debug_verbosity = verbosity_str ? strtol(verbosity_str, nullptr, 0) : DebugVerbosity::SILENT; - static const char *high_watermark_ratio_str = getenv("PYTORCH_MPS_HIGH_WATERMARK_RATIO"); - const double high_watermark_ratio = high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) : - default_high_watermark_ratio; + static const char* high_watermark_ratio_str = getenv("PYTORCH_MPS_HIGH_WATERMARK_RATIO"); + const double high_watermark_ratio = + high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) : default_high_watermark_ratio; setHighWatermarkRatio(high_watermark_ratio); - const double default_low_watermark_ratio = m_device.hasUnifiedMemory ? default_low_watermark_ratio_unified : - default_low_watermark_ratio_discrete; - static const char *low_watermark_ratio_str = getenv("PYTORCH_MPS_LOW_WATERMARK_RATIO"); - const double low_watermark_ratio = low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio; + const double default_low_watermark_ratio = + m_device.hasUnifiedMemory ? default_low_watermark_ratio_unified : default_low_watermark_ratio_discrete; + static const char* low_watermark_ratio_str = getenv("PYTORCH_MPS_LOW_WATERMARK_RATIO"); + const double low_watermark_ratio = + low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio; setLowWatermarkRatio(low_watermark_ratio); } void MPSHeapAllocatorImpl::setHighWatermarkRatio(double ratio) { TORCH_CHECK(ratio >= 0.0 && ratio <= default_high_watermark_upper_bound, "invalid high watermark ratio ", ratio); - m_max_total_allowed_size = (ratio == 0.0) ? std::numeric_limits::max() : - static_cast(ratio * (double)max_device_size()); + m_max_total_allowed_size = + (ratio == 0.0) ? std::numeric_limits::max() : static_cast(ratio * (double)max_device_size()); if (m_debug_verbosity & DebugVerbosity::PROFILING) { std::cerr << "\nHigh watermark memory allocation limit: " << (ratio == 0.0 ? "unlimited" : format_size(m_max_total_allowed_size)) << "\n"; @@ -47,11 +48,12 @@ void MPSHeapAllocatorImpl::setLowWatermarkRatio(double ratio) { // used for comparison with lower_watermark_ratio - const double high_watermark_limit = m_high_watermark_ratio == 0.0 ? default_high_watermark_upper_bound : m_high_watermark_ratio; + const double high_watermark_limit = + m_high_watermark_ratio == 0.0 ? default_high_watermark_upper_bound : m_high_watermark_ratio; TORCH_CHECK(ratio >= 0.0 && ratio <= high_watermark_limit, "invalid low watermark ratio ", ratio); // we use this to detect if there's memory pressure - m_low_watermark_limit = (ratio == 0.0) ? std::numeric_limits::max() : - static_cast(ratio * (double)max_device_size()); + m_low_watermark_limit = + (ratio == 0.0) ? std::numeric_limits::max() : static_cast(ratio * (double)max_device_size()); if (m_debug_verbosity & DebugVerbosity::PROFILING) { std::cerr << "Low watermark memory allocation limit: " << (ratio == 0.0 ? "unlimited" : format_size(m_low_watermark_limit)) << "\n"; @@ -61,7 +63,7 @@ HeapBlock* MPSHeapAllocatorImpl::get_free_heap(AllocParams& params) { BufferPool& pool = *params.pool; - HeapBlock *heap_block = nullptr; + HeapBlock* heap_block = nullptr; HeapBlock search_key(params.size()); auto it = pool.heaps.lower_bound(&search_key); @@ -69,10 +71,8 @@ heap_block = HeapBlock::createHeapBlock(params, pool.device, pool.usage); if (heap_block) { if (m_debug_verbosity & DebugVerbosity::ALLOCATIONS) { - std::cerr << "\nAllocated " - << ((pool.usage & UsageFlags::SHARED) ? "shared " : "private ") - << " heap #" << heap_block->heap_id - << " of size " << format_size(heap_block->size.total) + std::cerr << "\nAllocated " << ((pool.usage & UsageFlags::SHARED) ? "shared " : "private ") << " heap #" + << heap_block->heap_id << " of size " << format_size(heap_block->size.total) << " (#heaps: " << (pool.heaps.size() + 1) << ", current allocated: " << format_size(current_allocated_size()) << ")\n"; } @@ -91,7 +91,7 @@ current_allocated_size() + params.size() > m_max_total_allowed_size) { return false; } - HeapBlock *heap = get_free_heap(params); + HeapBlock* heap = get_free_heap(params); if (!heap) { return false; // this will cause releasing pool buffers to free up memory } @@ -109,17 +109,14 @@ pool.n_buffers++; if ((m_debug_verbosity & DebugVerbosity::ALLOCATIONS) && - (!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) { - std::cerr << "Allocated " - << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private") - << ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") - << " buffer #" << params.buffer_block->buf_id - << " of size " << format_size(params.size()) - << " at " << params.buffer_block->buffer - << " from heap #" << heap->heap_id + (!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) { + std::cerr << "Allocated " << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private") + << ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") << " buffer #" + << params.buffer_block->buf_id << " of size " << format_size(params.size()) << " at " + << params.buffer_block->buffer << " from heap #" << heap->heap_id << " (requested: " << format_size(params.requested_size) - << ", heap: " << format_size(heap->size.available) - << ", total: " << format_size(m_total_allocated_memory) << ")\n"; + << ", heap: " << format_size(heap->size.available) << ", total: " << format_size(m_total_allocated_memory) + << ")\n"; } return true; } @@ -158,8 +155,8 @@ // this will skip unnecessary garbage collection as we'll reuse the newly released space params.has_memory_pressure = false; } else if (params.has_memory_pressure) { - // the oversized buffer is busy and not reusable at the moment. So release it (and potentially its heap container) - // in allocator, and ARC will later free up its backing memory when the busy command buffer finishes. + // the oversized buffer is busy and not reusable at the moment. So release it (and potentially its heap + // container) in allocator, and ARC will later free up its backing memory when the busy command buffer finishes. release_buffer(buffer_block, true); } else { // only if there's no memory pressure, we'll reuse the oversized buffer @@ -176,16 +173,13 @@ pool.available_size -= params.buffer_block->size; if ((m_debug_verbosity & DebugVerbosity::RECYCLES) && - (!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) { - std::cerr << "Reusing " - << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private") - << ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") - << " buffer #" << params.buffer_block->buf_id - << " of size " << format_size(params.buffer_block->size) - << " at " << params.buffer_block->buffer - << " (requested: " << format_size(params.requested_size) - << ", use#: " << params.buffer_block->use_count + 1 - << ", retain#: " << params.buffer_block->retainCount() << ")\n"; + (!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) { + std::cerr << "Reusing " << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private") + << ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") << " buffer #" + << params.buffer_block->buf_id << " of size " << format_size(params.buffer_block->size) << " at " + << params.buffer_block->buffer << " (requested: " << format_size(params.requested_size) + << ", use#: " << params.buffer_block->use_count + 1 << ", retain#: " << params.buffer_block->retainCount() + << ")\n"; } return true; } @@ -214,7 +208,8 @@ alloc_buffer(params) || // Callbacks might release more memory (eg. by forcing a GC in the host language) thus // we can retry getting a free buffer in the pool, before trying to alloc again. - (trigger_memory_callbacks(nullptr, IMpsAllocatorCallback::EventType::ALLOCATION_FAILED) && get_free_buffer(params)) || + (trigger_memory_callbacks(nullptr, IMpsAllocatorCallback::EventType::ALLOCATION_FAILED) && + get_free_buffer(params)) || // Free enough available cached blocks to satisfy alloc and retry alloc. (release_available_cached_buffers(params) && alloc_buffer(params)) || // Free all cached buffers and retry alloc. @@ -229,16 +224,30 @@ // chunk of requested size couldn't be found. if (!block_found || !buffer_block) { if (m_high_watermark_ratio > 0.0) { - TORCH_CHECK(false, "MPS backend out of memory (MPS allocated: ", format_size(m_total_allocated_memory), - ", other allocations: ", format_size(current_allocated_size() - m_total_allocated_memory), - ", max allowed: ", format_size(m_max_total_allowed_size), "). Tried to allocate ", format_size(alloc_size), - " on ", ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"), - " pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure)."); + TORCH_CHECK( + false, + "MPS backend out of memory (MPS allocated: ", + format_size(m_total_allocated_memory), + ", other allocations: ", + format_size(current_allocated_size() - m_total_allocated_memory), + ", max allowed: ", + format_size(m_max_total_allowed_size), + "). Tried to allocate ", + format_size(alloc_size), + " on ", + ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"), + " pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure)."); } else { - TORCH_CHECK(false, "MPS backend out of memory (MPS allocated: ", format_size(m_total_allocated_memory), - ", other allocations: ", format_size(current_allocated_size() - m_total_allocated_memory), - "). Tried to allocate ", format_size(alloc_size), - " on ", ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"), " pool."); + TORCH_CHECK(false, + "MPS backend out of memory (MPS allocated: ", + format_size(m_total_allocated_memory), + ", other allocations: ", + format_size(current_allocated_size() - m_total_allocated_memory), + "). Tried to allocate ", + format_size(alloc_size), + " on ", + ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"), + " pool."); } } buffer_block->in_use = true; @@ -270,7 +279,7 @@ } bool MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove_empty_heap) { - HeapBlock *heap_block = buffer_block->heap; + HeapBlock* heap_block = buffer_block->heap; BufferPool& pool = *heap_block->pool; m_total_allocated_memory -= buffer_block->size; pool.allocated_size -= buffer_block->size; @@ -283,13 +292,10 @@ uint32_t retainCount = heap_block->releaseMTLBuffer(buffer_block->buffer); if ((m_debug_verbosity & DebugVerbosity::RELEASES) && - (!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) { - std::cerr << "Released buffer #" << buffer_block->buf_id - << " of size " << format_size(buffer_block->size) - << " from heap #" << heap_block->heap_id - << " (heap size: " << format_size(heap_block->size.available) - << ", use#: " << buffer_block->use_count - << ", retain#: " << retainCount + (!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) { + std::cerr << "Released buffer #" << buffer_block->buf_id << " of size " << format_size(buffer_block->size) + << " from heap #" << heap_block->heap_id << " (heap size: " << format_size(heap_block->size.available) + << ", use#: " << buffer_block->use_count << ", retain#: " << retainCount << ", gc#: " << buffer_block->gc_count << ")\n"; } delete buffer_block; @@ -298,10 +304,9 @@ pool.heaps_pending_update.erase(heap_block); retainCount = heap_block->releaseMTLHeap(); if (m_debug_verbosity & DebugVerbosity::RELEASES) { - std::cerr << "Released heap #" << heap_block->heap_id - << " of size " << format_size(heap_block->size.total) - << " (current allocated: " << format_size(current_allocated_size()) - << ", retain#: " << retainCount << ")\n"; + std::cerr << "Released heap #" << heap_block->heap_id << " of size " << format_size(heap_block->size.total) + << " (current allocated: " << format_size(current_allocated_size()) << ", retain#: " << retainCount + << ")\n"; } delete heap_block; return true; @@ -312,7 +317,7 @@ if (retainCount > 1) { pool.heaps_pending_update.insert(heap_block); m_mutex.unlock(); - m_stream->addCompletedHandler(^(id ) { + m_stream->addCompletedHandler(^(id) { std::lock_guard lock(m_mutex); // check if the heap block still exists if (pool.heaps_pending_update.find(heap_block) != pool.heaps_pending_update.end()) { @@ -333,13 +338,11 @@ return; } if ((m_debug_verbosity & DebugVerbosity::RELEASES)) { - std::cerr << "Releasing " << pool.buffers.size() - << " buffers from " - << ((pool.usage & UsageFlags::SMALL ) ? "small " : "large ") + std::cerr << "Releasing " << pool.buffers.size() << " buffers from " + << ((pool.usage & UsageFlags::SMALL) ? "small " : "large ") << ((pool.usage & UsageFlags::SHARED) ? "shared" : "private") << ((pool.usage & UsageFlags::SCALAR) ? " scalar" : "") - << " pool (total size: " << format_size(pool.allocated_size) - << ", #buffers: " << pool.n_buffers << ")\n"; + << " pool (total size: " << format_size(pool.allocated_size) << ", #buffers: " << pool.n_buffers << ")\n"; } auto it = pool.buffers.begin(); while (it != pool.buffers.end()) { @@ -381,10 +384,8 @@ bool MPSHeapAllocatorImpl::release_cached_buffers() { if (m_debug_verbosity >= DebugVerbosity::PROFILING) { - std::cerr << "Attempting to release cached buffers (MPS allocated: " - << format_size(m_total_allocated_memory) - << ", other allocations: " - << format_size(current_allocated_size() - m_total_allocated_memory) << ")\n"; + std::cerr << "Attempting to release cached buffers (MPS allocated: " << format_size(m_total_allocated_memory) + << ", other allocations: " << format_size(current_allocated_size() - m_total_allocated_memory) << ")\n"; } // before releasing the buffers make sure the command buffer has finished. // we need to release the lock temporarily as synchronizing may cause deadlock with completion handlers. @@ -445,11 +446,10 @@ } } if (m_debug_verbosity & DebugVerbosity::RELEASES) { - std::cerr << "Garbage collected " << freed_count - << " buffers from large " + std::cerr << "Garbage collected " << freed_count << " buffers from large " << ((pool.usage & UsageFlags::SHARED) ? "shared" : "private") - << " pool (total reclaimed: " << format_size(gc_reclaimed) - << ", #buffers: " << pool.buffers.size() << ")\n"; + << " pool (total reclaimed: " << format_size(gc_reclaimed) << ", #buffers: " << pool.buffers.size() + << ")\n"; } } @@ -464,7 +464,7 @@ bool MPSHeapAllocatorImpl::isSharedBuffer(void* ptr) { std::lock_guard lock(m_mutex); - BufferBlock *buffer_block = get_allocated_buffer_block(ptr); + BufferBlock* buffer_block = get_allocated_buffer_block(ptr); // it's OK for the buffer_block to not exist yet return buffer_block && (buffer_block->heap->pool->usage & UsageFlags::SHARED); } @@ -487,9 +487,9 @@ ssize_t MPSHeapAllocatorImpl::getUnalignedBufferSize(void* ptr) { std::lock_guard lock(m_mutex); - BufferBlock *buffer_block = get_allocated_buffer_block(ptr); + BufferBlock* buffer_block = get_allocated_buffer_block(ptr); if (buffer_block) { - return (ssize_t) buffer_block->requested_size; + return (ssize_t)buffer_block->requested_size; } // -1 indicates the passed buffer pointer wasn't found return -1; @@ -498,7 +498,7 @@ void MPSHeapAllocatorImpl::setBufferShape(void* ptr, const IntArrayRef& shape) { std::lock_guard lock(m_mutex); - BufferBlock *buffer_block = get_allocated_buffer_block(ptr); + BufferBlock* buffer_block = get_allocated_buffer_block(ptr); TORCH_INTERNAL_ASSERT(buffer_block, "failed to find the buffer ", ptr); // note that the IntArrayRef doesn't own the underlying data, and the backing // memory for shape data must persist as long as the buffer is in use. @@ -509,7 +509,7 @@ IntArrayRef MPSHeapAllocatorImpl::getBufferShape(void* ptr) { std::lock_guard lock(m_mutex); - BufferBlock *buffer_block = get_allocated_buffer_block(ptr); + BufferBlock* buffer_block = get_allocated_buffer_block(ptr); if (buffer_block && buffer_block->shape.size() > 0) { return IntArrayRef{buffer_block->shape}; } @@ -517,7 +517,7 @@ } void MPSHeapAllocatorImpl::free(void* ptr) { - BufferBlock *buffer_block = nullptr; + BufferBlock* buffer_block = nullptr; { std::lock_guard lock(m_mutex); @@ -531,7 +531,7 @@ } // we sync the scalar pool manually with completion handler at the time buffer is // freed when the MPSScalar instance goes our of scope - m_stream->addCompletedHandler(^(id ) { + m_stream->addCompletedHandler(^(id) { std::lock_guard lock(m_mutex); free_buffer(buffer_block); }); @@ -555,10 +555,15 @@ std::ostringstream os; os.precision(2); os << std::fixed; - if (size <= 1024UL) { os << size << " bytes"; } - else if (size <= 1048576UL) { os << ((float) size / 1024.0) << " KB"; } - else if (size <= 1073741824UL) { os << ((float) size / 1048576.0) << " MB"; } - else { os << ((float) size / 1073741824.0) << " GB"; } + if (size <= 1024UL) { + os << size << " bytes"; + } else if (size <= 1048576UL) { + os << ((float)size / 1024.0) << " KB"; + } else if (size <= 1073741824UL) { + os << ((float)size / 1048576.0) << " MB"; + } else { + os << ((float)size / 1073741824.0) << " GB"; + } return os.str(); } @@ -574,16 +579,13 @@ // MPS allocator struct to be registered with Pytorch struct TORCH_API MPSAllocator final : public IMPSAllocator { -public: - explicit MPSAllocator(uint32_t Usage) : - m_has_unified_memory(_getAllocImpl().Device().hasUnifiedMemory), m_usage(Usage) - { + public: + explicit MPSAllocator(uint32_t Usage) + : m_has_unified_memory(_getAllocImpl().Device().hasUnifiedMemory), m_usage(Usage) { if (_getAllocImpl().getDebugVerbosity()) { if (!(m_usage & HeapAllocator::UsageFlags::SHARED) || m_has_unified_memory) { - std::cerr << "Initializing " - << ((m_usage & HeapAllocator::UsageFlags::SHARED) ? "shared" : "private") - << " heap allocator on " - << (m_has_unified_memory ? "unified" : "discrete") + std::cerr << "Initializing " << ((m_usage & HeapAllocator::UsageFlags::SHARED) ? "shared" : "private") + << " heap allocator on " << (m_has_unified_memory ? "unified" : "discrete") << " device memory of size " << _getAllocImpl().format_size(_getAllocImpl().Device().recommendedMaxWorkingSetSize) << "\n"; } @@ -593,34 +595,64 @@ explicit MPSAllocator(uint32_t Usage) : ~MPSAllocator() override { _getAllocImpl().emptyCache(); } - DeleterFnPtr raw_deleter() const override { return &Delete; } + DeleterFnPtr raw_deleter() const override { + return &Delete; + } DataPtr allocate(const size_t nbytes) const override { __block id buf = nbytes > 0 ? _getAllocImpl().malloc(nbytes, m_usage) : nullptr; - return { buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)}; + return {buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)}; } // implementation of IMPSAllocator interface - DataPtr allocScalarBufferWithValue(void *value, size_t size) const override { + DataPtr allocScalarBufferWithValue(void* value, size_t size) const override { id buf = _getAllocImpl().allocScalarBufferWithValue(value, size); - return { buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)}; - } - bool isSharedBuffer(void* ptr) const override { return _getAllocImpl().isSharedBuffer(ptr); } - bool isSharedStorageSupported() const override { return m_has_unified_memory; } - void emptyCache() const override { _getAllocImpl().emptyCache(); } - ssize_t getUnalignedBufferSize(void* ptr) const override { return _getAllocImpl().getUnalignedBufferSize(ptr); } - IntArrayRef getBufferShape(void* ptr) const override { return _getAllocImpl().getBufferShape(ptr); } - void setBufferShape(void* ptr, const IntArrayRef& shape) const override { _getAllocImpl().setBufferShape(ptr, shape); } - size_t getTotalAllocatedMemory() const override { return _getAllocImpl().getTotalAllocatedMemory(); } - size_t getCurrentAllocatedMemory() const override { return _getAllocImpl().getCurrentAllocatedMemory(); } - size_t getDriverAllocatedMemory() const override { return _getAllocImpl().getDriverAllocatedMemory(); } - ssize_t getLowWatermarkValue() const override { return _getAllocImpl().getLowWatermarkValue(); } - size_t getLowWatermarkLimit() const override { return _getAllocImpl().getLowWatermarkLimit(); } - size_t getHighWatermarkLimit() const override { return _getAllocImpl().getHighWatermarkLimit(); } - void setLowWatermarkRatio(double ratio) const override { _getAllocImpl().setLowWatermarkRatio(ratio); } - void setHighWatermarkRatio(double ratio) const override { _getAllocImpl().setHighWatermarkRatio(ratio); } - -private: + return {buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)}; + } + bool isSharedBuffer(void* ptr) const override { + return _getAllocImpl().isSharedBuffer(ptr); + } + bool isSharedStorageSupported() const override { + return m_has_unified_memory; + } + void emptyCache() const override { + _getAllocImpl().emptyCache(); + } + ssize_t getUnalignedBufferSize(void* ptr) const override { + return _getAllocImpl().getUnalignedBufferSize(ptr); + } + IntArrayRef getBufferShape(void* ptr) const override { + return _getAllocImpl().getBufferShape(ptr); + } + void setBufferShape(void* ptr, const IntArrayRef& shape) const override { + _getAllocImpl().setBufferShape(ptr, shape); + } + size_t getTotalAllocatedMemory() const override { + return _getAllocImpl().getTotalAllocatedMemory(); + } + size_t getCurrentAllocatedMemory() const override { + return _getAllocImpl().getCurrentAllocatedMemory(); + } + size_t getDriverAllocatedMemory() const override { + return _getAllocImpl().getDriverAllocatedMemory(); + } + ssize_t getLowWatermarkValue() const override { + return _getAllocImpl().getLowWatermarkValue(); + } + size_t getLowWatermarkLimit() const override { + return _getAllocImpl().getLowWatermarkLimit(); + } + size_t getHighWatermarkLimit() const override { + return _getAllocImpl().getHighWatermarkLimit(); + } + void setLowWatermarkRatio(double ratio) const override { + _getAllocImpl().setLowWatermarkRatio(ratio); + } + void setHighWatermarkRatio(double ratio) const override { + _getAllocImpl().setHighWatermarkRatio(ratio); + } + + private: bool m_has_unified_memory; uint32_t m_usage; @@ -662,15 +694,13 @@ static void Delete(void* ptr) { // Pinned memory will be helpful on Apple Silicon Macs with Unified memory as we // will be able to use SharedStorageMode for MTLBuffer allocations. This will // avoid extra copies on DataLoading operations. -bool is_pinned_mps(const Tensor& self, c10::optional device) -{ +bool is_pinned_mps(const Tensor& self, c10::optional device) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_mps()); return at::mps::_getSharedAllocator().isSharedBuffer(self.storage().data()); } // torch.pin_memory() implementation -Tensor _pin_memory_mps(const Tensor& self, c10::optional device) -{ +Tensor _pin_memory_mps(const Tensor& self, c10::optional device) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_mps()); auto* shared_allocator = at::mps::getIMPSAllocator(true); TORCH_CHECK(shared_allocator, "unable to pin memory on a non-unified memory device"); diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm index ecc54e90ae46f6..211530a3802a7c 100644 --- a/aten/src/ATen/mps/MPSDevice.mm +++ b/aten/src/ATen/mps/MPSDevice.mm @@ -2,10 +2,10 @@ #include +#include +#include #include #include -#include -#include namespace at { namespace mps { @@ -23,9 +23,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de } MPSDevice* MPSDevice::getInstance() { - c10::call_once(mpsdev_init, [] { - mps_device = std::unique_ptr(new MPSDevice()); - }); + c10::call_once(mpsdev_init, [] { mps_device = std::unique_ptr(new MPSDevice()); }); return mps_device.get(); } @@ -33,25 +31,31 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device); NSError* error = nil; if (!_mtl_indexing_library) { - MTLCompileOptions *options = [MTLCompileOptions new]; - [options setLanguageVersion: getMetalLanguageVersion(_mtl_device)]; - [options setFastMathEnabled: YES]; - _mtl_indexing_library = [_mtl_device newLibraryWithSource: [NSString stringWithCString: mps::indexing_metal_shaders encoding:NSASCIIStringEncoding] - options: options - error: &error]; + MTLCompileOptions* options = [MTLCompileOptions new]; + [options setLanguageVersion:getMetalLanguageVersion(_mtl_device)]; + [options setFastMathEnabled:YES]; + _mtl_indexing_library = [_mtl_device newLibraryWithSource:[NSString stringWithCString:mps::indexing_metal_shaders + encoding:NSASCIIStringEncoding] + options:options + error:&error]; TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]); } id indexFunction = nil; if (constantValues) { - indexFunction = [[_mtl_indexing_library newFunctionWithName: [NSString stringWithUTF8String: kernel.c_str()] - constantValues: constantValues - error: &error] autorelease]; + indexFunction = [[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()] + constantValues:constantValues + error:&error] autorelease]; } else { - indexFunction = [[_mtl_indexing_library newFunctionWithName: [NSString stringWithUTF8String: kernel.c_str()]] autorelease]; + indexFunction = + [[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease]; } - TORCH_CHECK(indexFunction, "Failed to create specialized function state object: ", kernel, ", error: ", [[error description] UTF8String]); + TORCH_CHECK(indexFunction, + "Failed to create specialized function state object: ", + kernel, + ", error: ", + [[error description] UTF8String]); return indexFunction; } @@ -63,49 +67,52 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de _mtl_indexing_library = nil; } -MPSDevice::MPSDevice(): _mtl_device(nil), _mtl_indexing_library(nil) { +MPSDevice::MPSDevice() : _mtl_device(nil), _mtl_indexing_library(nil) { // Check that MacOS 12.3+ version of MPS framework is available // Create the MPSGraph and check method introduced in 12.3+ // which is used by MPS backend. id mpsCD = NSClassFromString(@"MPSGraph"); - if ([mpsCD instancesRespondToSelector:@selector(LSTMWithSourceTensor: - recurrentWeight: - inputWeight: - bias: - initState: - initCell: - descriptor: - name:)] == NO) { + if ([mpsCD instancesRespondToSelector:@selector + (LSTMWithSourceTensor:recurrentWeight:inputWeight:bias:initState:initCell:descriptor:name:)] == NO) { return; } NSArray* devices = [MTLCopyAllDevices() autorelease]; - for (unsigned long i = 0 ; i < [devices count] ; i++) { - id device = devices[i]; - if(![device isLowPower]) { // exclude Intel GPUs + for (unsigned long i = 0; i < [devices count]; i++) { + id device = devices[i]; + if (![device isLowPower]) { // exclude Intel GPUs _mtl_device = [device retain]; break; } } TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device); - } bool MPSDevice::isMacOS13Plus(MacOSVersion version) const { id mpsCD = NSClassFromString(@"MPSGraph"); - static bool _macos_13_0_plus = [mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == YES; - static bool _macos_13_1_plus = [mpsCD instancesRespondToSelector:@selector( - sampleGridWithSourceTensor:coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode:samplingMode:constantValue:name:)] == YES; - static bool _macos_13_2_plus = [mpsCD instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)] == YES; + static bool _macos_13_0_plus = [mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor: + axis:name:)] == YES; + static bool _macos_13_1_plus = + [mpsCD instancesRespondToSelector:@selector + (sampleGridWithSourceTensor: + coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode + :samplingMode:constantValue:name:)] == YES; + static bool _macos_13_2_plus = + [mpsCD instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)] == YES; static bool _macos_13_3_plus = [_mtl_device respondsToSelector:@selector(maximumConcurrentCompilationTaskCount)]; switch (version) { - case MacOSVersion::MACOS_VER_13_0_PLUS: return _macos_13_0_plus; - case MacOSVersion::MACOS_VER_13_1_PLUS: return _macos_13_1_plus; - case MacOSVersion::MACOS_VER_13_2_PLUS: return _macos_13_2_plus; - case MacOSVersion::MACOS_VER_13_3_PLUS: return _macos_13_3_plus; - default: return false; + case MacOSVersion::MACOS_VER_13_0_PLUS: + return _macos_13_0_plus; + case MacOSVersion::MACOS_VER_13_1_PLUS: + return _macos_13_1_plus; + case MacOSVersion::MACOS_VER_13_2_PLUS: + return _macos_13_2_plus; + case MacOSVersion::MACOS_VER_13_3_PLUS: + return _macos_13_3_plus; + default: + return false; } } diff --git a/aten/src/ATen/mps/MPSFallback.mm b/aten/src/ATen/mps/MPSFallback.mm index 1d51a26b18f279..baf2f185562d40 100644 --- a/aten/src/ATen/mps/MPSFallback.mm +++ b/aten/src/ATen/mps/MPSFallback.mm @@ -4,40 +4,46 @@ namespace at { -void mps_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) -{ - TORCH_WARN_ONCE("The operator '", op.schema().operator_name(), "' is not currently supported ", +void mps_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + TORCH_WARN_ONCE("The operator '", + op.schema().operator_name(), + "' is not currently supported ", "on the MPS backend and will fall back to run on the CPU.", " This may have performance implications."); native::cpu_fallback(op, stack); } -void mps_error_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) -{ - TORCH_CHECK_NOT_IMPLEMENTED(false, "The operator '", op.schema().operator_name(), "' is not currently implemented ", +void mps_error_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack){TORCH_CHECK_NOT_IMPLEMENTED( + false, + "The operator '", + op.schema().operator_name(), + "' is not currently implemented ", "for the MPS device. If you want this op to be added in priority during the prototype ", "phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. ", "As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` ", "to use the CPU as a fallback for this op. WARNING: this will be slower than running natively ", - "on MPS.") -} - + "on MPS.")} // This dispatch should never be called for tensor on MPS but is frequently called // If one of them are on CPU -Tensor slow_conv2d_forward_mps( - const Tensor &self, - const Tensor &weight, - IntArrayRef kernel_size, - const c10::optional &bias, - IntArrayRef stride, - IntArrayRef padding) { - TORCH_CHECK(self.device() == weight.device(), __func__, ": input(device='", self.device(), "') and weight(device=", weight.device(), "') must be on the same device"); - TORCH_INTERNAL_ASSERT(false, __func__, " should not be called for both tensors on MPS device"); +Tensor slow_conv2d_forward_mps(const Tensor& self, + const Tensor& weight, + IntArrayRef kernel_size, + const c10::optional& bias, + IntArrayRef stride, + IntArrayRef padding) { + TORCH_CHECK(self.device() == weight.device(), + __func__, + ": input(device='", + self.device(), + "') and weight(device=", + weight.device(), + "') must be on the same device"); + TORCH_INTERNAL_ASSERT(false, __func__, " should not be called for both tensors on MPS device"); } TORCH_LIBRARY_IMPL(_, MPS, m) { - static const char *enable_mps_fallback = getenv("PYTORCH_ENABLE_MPS_FALLBACK"); + static const char* enable_mps_fallback = getenv("PYTORCH_ENABLE_MPS_FALLBACK"); if (!enable_mps_fallback || std::stoi(enable_mps_fallback) == 0) { m.fallback(torch::CppFunction::makeFromBoxedFunction<&mps_error_fallback>()); } else { diff --git a/aten/src/ATen/mps/MPSGeneratorImpl.mm b/aten/src/ATen/mps/MPSGeneratorImpl.mm index 7eb6b7d987826e..ed7be96c8c743d 100644 --- a/aten/src/ATen/mps/MPSGeneratorImpl.mm +++ b/aten/src/ATen/mps/MPSGeneratorImpl.mm @@ -23,8 +23,9 @@ Generator createMPSGenerator(uint64_t seed_val) { } // namespace mps MPSGeneratorImpl::MPSGeneratorImpl(uint64_t seed_in) - : c10::GeneratorImpl{Device(DeviceType::MPS), DispatchKeySet(c10::DispatchKey::MPS)}, - data_({.seed = seed_in}), engine_(seed_in, 0, 0) { } + : c10::GeneratorImpl{Device(DeviceType::MPS), DispatchKeySet(c10::DispatchKey::MPS)}, + data_({.seed = seed_in}), + engine_(seed_in, 0, 0) {} void MPSGeneratorImpl::set_current_seed(uint64_t seed) { data_.seed = seed; @@ -60,7 +61,8 @@ Generator createMPSGenerator(uint64_t seed_val) { static const size_t seed_size = sizeof(uint64_t); static const size_t total_size = states_size + seed_size; - auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); + auto state_tensor = at::detail::empty_cpu( + {(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); auto rng_state = state_tensor.data_ptr(); auto current_seed = this->current_seed(); memcpy(rng_state, this->data_.state.data(), states_size); diff --git a/aten/src/ATen/mps/MPSGuardImpl.mm b/aten/src/ATen/mps/MPSGuardImpl.mm index 787ef4cae7cd2c..9c204718e50bdf 100644 --- a/aten/src/ATen/mps/MPSGuardImpl.mm +++ b/aten/src/ATen/mps/MPSGuardImpl.mm @@ -1,57 +1,47 @@ // Copyright © 2022 Apple Inc. -#include #include +#include namespace at { namespace mps { - void MPSGuardImpl::createEvent( - mpsEvent_t* event, - const EventFlag flag) const { - } - - void MPSGuardImpl::destroyEvent( - void* event, - const DeviceIndex device_index) const noexcept { - if (!event) return; - auto mps_event = static_cast(event); - mps_event->~MPSEvent(); - - } - - void MPSGuardImpl::record( - void** event, - const Stream& stream, - const DeviceIndex device_index, - const EventFlag flag) const { - - TORCH_CHECK(device_index == -1 || device_index == stream.device_index(), - "Event device index ", - device_index, - " does not match recording stream's device index ", - stream.device_index(), - "."); - - auto mps_event = static_cast(*event); - MPSStream mps_stream{stream}; - mps_event->recordEvent(true); - } - - void MPSGuardImpl::block( - void* event, - const Stream& stream) const { - - auto mps_event = static_cast(event); - MPSStream mps_stream{stream}; - - mps_event->waitForEvent(true); - } - - bool MPSGuardImpl::queryEvent(void* event) const { - auto mps_event = static_cast(event); - return mps_event->queryEvent(); - } +void MPSGuardImpl::createEvent(mpsEvent_t* event, const EventFlag flag) const {} + +void MPSGuardImpl::destroyEvent(void* event, const DeviceIndex device_index) const noexcept { + if (!event) + return; + auto mps_event = static_cast(event); + mps_event->~MPSEvent(); +} + +void MPSGuardImpl::record(void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const { + TORCH_CHECK(device_index == -1 || device_index == stream.device_index(), + "Event device index ", + device_index, + " does not match recording stream's device index ", + stream.device_index(), + "."); + + auto mps_event = static_cast(*event); + MPSStream mps_stream{stream}; + mps_event->recordEvent(true); +} + +void MPSGuardImpl::block(void* event, const Stream& stream) const { + auto mps_event = static_cast(event); + MPSStream mps_stream{stream}; + + mps_event->waitForEvent(true); +} + +bool MPSGuardImpl::queryEvent(void* event) const { + auto mps_event = static_cast(event); + return mps_event->queryEvent(); +} } } diff --git a/aten/src/ATen/mps/MPSStream.mm b/aten/src/ATen/mps/MPSStream.mm index f1f2d47cf1e6a6..1787bceca982bc 100644 --- a/aten/src/ATen/mps/MPSStream.mm +++ b/aten/src/ATen/mps/MPSStream.mm @@ -1,7 +1,7 @@ // Copyright © 2022 Apple Inc. -#include #include +#include namespace at { namespace mps { @@ -17,9 +17,9 @@ TORCH_CHECK(_stream.device_type() == DeviceType::MPS); _serialQueue = dispatch_queue_create("metal gpu stream", nullptr); _executionDescriptor = [MPSGraphExecutionDescriptor new]; - _executionDescriptor.completionHandler = ^(NSDictionary * resultsDictionary, - NSError * _Nullable error) { }; + _executionDescriptor.completionHandler = + ^(NSDictionary* resultsDictionary, NSError* _Nullable error) { + }; } MPSStream::~MPSStream() { @@ -41,7 +41,7 @@ void MPSStream::synchronize(SyncType syncType) { if (!_commandBuffer) return; - switch(syncType) { + switch (syncType) { case SyncType::NONE: // typically in GPU to GPU copies we won't commit explicitly break; @@ -108,32 +108,34 @@ } void MPSStream::addCompletedHandler(MTLCommandBufferHandler block) { - dispatch_sync(_serialQueue, ^() { + dispatch_sync(_serialQueue, ^() { @autoreleasepool { [commandBuffer() addCompletedHandler:block]; } }); } -void MPSStream::fill(id buffer, uint8_t value, size_t length, size_t offset, SyncType syncType) -{ +void MPSStream::fill(id buffer, uint8_t value, size_t length, size_t offset, SyncType syncType) { TORCH_INTERNAL_ASSERT(length >= offset); - if (length == 0) return; + if (length == 0) + return; dispatch_sync(_serialQueue, ^() { @autoreleasepool { id blitEncoder = [commandBuffer() blitCommandEncoder]; - [blitEncoder fillBuffer:buffer - range:NSMakeRange(offset, length) - value:value]; + [blitEncoder fillBuffer:buffer range:NSMakeRange(offset, length) value:value]; [blitEncoder endEncoding]; synchronize(syncType); } }); } -void MPSStream::copy(id srcBuffer, id dstBuffer, - size_t length, size_t srcOffset, size_t dstOffset, SyncType syncType) { +void MPSStream::copy(id srcBuffer, + id dstBuffer, + size_t length, + size_t srcOffset, + size_t dstOffset, + SyncType syncType) { dispatch_sync(_serialQueue, ^() { @autoreleasepool { id blitEncoder = [commandBuffer() blitCommandEncoder]; @@ -149,10 +151,14 @@ }); } -void MPSStream::copy_and_sync(id srcBuffer, id dstBuffer, size_t length, - size_t srcOffset, size_t dstOffset, bool non_blocking) { - copy(srcBuffer, dstBuffer, length, srcOffset, dstOffset, - !non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT); +void MPSStream::copy_and_sync(id srcBuffer, + id dstBuffer, + size_t length, + size_t srcOffset, + size_t dstOffset, + bool non_blocking) { + copy( + srcBuffer, dstBuffer, length, srcOffset, dstOffset, !non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT); } void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType) { @@ -173,7 +179,7 @@ resultsDictionary:results executionDescriptor:_executionDescriptor]; #endif - }); + }); } //----------------------------------------------------------------- @@ -184,8 +190,7 @@ MPSStream* MPSStreamImpl::getInstance() { if (_stream == nullptr) { - _stream = - new MPSStream(Stream(Stream::UNSAFE, c10::Device(DeviceType::MPS), 0)); + _stream = new MPSStream(Stream(Stream::UNSAFE, c10::Device(DeviceType::MPS), 0)); } return _stream; } @@ -204,8 +209,8 @@ // MPSEvent //----------------------------------------------------------------- -MPSEvent::MPSEvent(bool deferInitialization) : - is_initialized(false), _signalCounter(0), _stream(nil), _event(nil), _listener(nil) { +MPSEvent::MPSEvent(bool deferInitialization) + : is_initialized(false), _signalCounter(0), _stream(nil), _event(nil), _listener(nil) { if (!deferInitialization) { initialize(); } @@ -256,8 +261,7 @@ }); } -void MPSEvent::notifyEvent(MTLSharedEventNotificationBlock block) -{ +void MPSEvent::notifyEvent(MTLSharedEventNotificationBlock block) { if (!is_initialized) initialize(); dispatch_sync(_stream->queue(), ^() { diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 32084d4b4a3025..c7d4a35b11e3e4 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -1,7 +1,7 @@ // Copyright © 2022 Apple Inc. -#include #include +#include namespace at::native::mps { @@ -28,27 +28,31 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { case ScalarType::Bool: return MPSDataTypeBool; case ScalarType::Double: - TORCH_CHECK_TYPE(false, "Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. " + TORCH_CHECK_TYPE(false, + "Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. " "Please use float32 instead.") default: - TORCH_CHECK_TYPE(false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.") + TORCH_CHECK_TYPE( + false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.") } } // #issue 104398441 sortWithTensor and argsortWithTensor has support of // Int32, Half and Float32 types. These utilities are to help cast to these // types. -MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64) { +MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, + MPSGraphTensor* inputTensor, + const Tensor& input, + bool includesInt64) { MPSDataType dataType = getMPSDataType(input.scalar_type()); - bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16); + bool condition = + (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16); if (includesInt64) { condition = condition && (dataType != MPSDataTypeInt64); } if (condition) { dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32; - return [mpsGraph castTensor:inputTensor - toType:dataType - name:@"castInputTensor"]; + return [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"]; } return inputTensor; } @@ -56,16 +60,18 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { // #issue 104398441 sortWithTensor and argsortWithTensor has support of // Int32, Half and Float32 types. These utilities are to help cast from these // types. -MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64) { +MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, + MPSGraphTensor* inputTensor, + const Tensor& input, + bool includesInt64) { MPSDataType dataType = getMPSDataType(input.scalar_type()); - bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16); + bool condition = + (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16); if (includesInt64) { condition = condition && (dataType != MPSDataTypeInt64); } if (condition) { - inputTensor = [mpsGraph castTensor:inputTensor - toType:dataType - name:@"castInputTensor"]; + inputTensor = [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"]; } return inputTensor; } @@ -92,7 +98,8 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { case ScalarType::Bool: return MPSDataTypeBool; default: - TORCH_CHECK_TYPE(false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.") + TORCH_CHECK_TYPE( + false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.") } } @@ -148,7 +155,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { NSArray* getTensorAxes(const Tensor& t) { int64_t ndim = t.dim(); auto axes = [NSMutableArray arrayWithCapacity:ndim]; - for (const auto i: c10::irange(ndim)) { + for (const auto i : c10::irange(ndim)) { axes[i] = [NSNumber numberWithInteger:i]; } return axes; @@ -159,7 +166,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { IntArrayRef dimValues = dim.value(); int ndim = dimValues.size(); auto axes = [NSMutableArray arrayWithCapacity:ndim]; - for (const auto i: c10::irange(ndim)) { + for (const auto i : c10::irange(ndim)) { axes[i] = [NSNumber numberWithInteger:dimValues[i]]; } @@ -170,11 +177,11 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { } std::string getMPSShapeString(MPSShape* shape) { - std::string str; - for(NSNumber *elem in shape) { - str += std::to_string(elem.unsignedLongValue) + ","; - } - return str; + std::string str; + for (NSNumber* elem in shape) { + str += std::to_string(elem.unsignedLongValue) + ","; + } + return str; } std::string getArrayRefString(const IntArrayRef s) { @@ -184,25 +191,25 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { } std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype) { - std::string str; - // The key format per tensor would look like ":Float32[1,1,1,10]:" - for (const Tensor& tensor: tensors) { - str += ":"; - if (tensor.defined()) { - str += getMPSTypeString(tensor.scalar_type(), short_dtype) + "["; - // if tensor is a scalar - if (tensor.dim() == 0) { - str += "Scalar"; - } else { - const NSString* ns_shape_key = [[getMPSShape(tensor) valueForKey:@"description"] componentsJoinedByString:@","]; - str += std::string(ns_shape_key.UTF8String); - } - str += "]"; + std::string str; + // The key format per tensor would look like ":Float32[1,1,1,10]:" + for (const Tensor& tensor : tensors) { + str += ":"; + if (tensor.defined()) { + str += getMPSTypeString(tensor.scalar_type(), short_dtype) + "["; + // if tensor is a scalar + if (tensor.dim() == 0) { + str += "Scalar"; } else { - str += "Undefined"; + const NSString* ns_shape_key = [[getMPSShape(tensor) valueForKey:@"description"] componentsJoinedByString:@","]; + str += std::string(ns_shape_key.UTF8String); } + str += "]"; + } else { + str += "Undefined"; } - return str; + } + return str; } MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format) { @@ -216,7 +223,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { const NSUInteger C = sizes[1]; const NSUInteger H = sizes[2]; const NSUInteger W = sizes[3]; - return @[@(N), @(H), @(W), @(C)]; + return @[ @(N), @(H), @(W), @(C) ]; } const int sz = sizes.size(); const int sz_ = (sz > 0) ? sz : 1; @@ -232,27 +239,27 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { } void printTensorNDArray(const Tensor& t) { - if (!t.is_mps()) return; - if(t.numel() == 0) return; + if (!t.is_mps()) + return; + if (t.numel() == 0) + return; // Get shape and data type auto selfShape = getMPSShape(t); auto selfDType = getMPSDataType(t.scalar_type()); // Initialize data id selfBuf = getMTLBufferStorage(t); - MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf - shape:selfShape - dataType:selfDType] autorelease]; + MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf shape:selfShape + dataType:selfDType] autorelease]; C10_CLANG_DIAGNOSTIC_PUSH() - #if C10_CLANG_HAS_WARNING("-Wobjc-method-access") +#if C10_CLANG_HAS_WARNING("-Wobjc-method-access") C10_CLANG_DIAGNOSTIC_IGNORE("-Wobjc-method-access") - #endif +#endif [tdata printNDArray]; C10_CLANG_DIAGNOSTIC_POP() } -MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType mpsType) -{ +MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape* shape, MPSDataType mpsType) { id buffer = getMTLBufferStorage(tensor); MPSGraphTensorData* tmpGraphTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buffer shape:shape @@ -261,16 +268,19 @@ void printTensorNDArray(const Tensor& t) { return [tmpGraphTensorData mpsndarray]; } -Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSShape *mpsShape, - bool gatherTensorData, MPSDataType dataType) : _tensor(src) -{ +Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, + const Tensor& src, + MPSShape* mpsShape, + bool gatherTensorData, + MPSDataType dataType) + : _tensor(src) { TORCH_CHECK(src.is_mps(), "Placeholder storage has not been allocated on MPS device!"); // extract the pointer to MTLBuffer from the Tensor's storage id srcBuf = getMTLBufferStorage(src); bool sliceViewTensor = canSliceViewTensor(src, mpsShape); // a view tensor could be contiguous (e.g., slice ops) or non-contiguous (e.g., transpose()) if ((!src.is_contiguous() || (src.storage_offset() && !sliceViewTensor)) && gatherTensorData) { - Tensor emptyShell = Tensor(); + Tensor emptyShell = Tensor(); // use "_tensor" from Placeholder to retain view's output during its usage in other ops _tensor = gatherViewTensor(src, emptyShell); if (!_tensor.has_storage()) { @@ -285,8 +295,9 @@ void printTensorNDArray(const Tensor& t) { // if buffer size is zero in here, it's not a user error. It could be a missing check for // tensor.numel() == 0 in our internal implementations of ops. TORCH_INTERNAL_ASSERT([srcBuf length] > 0, "Placeholder tensor is empty!"); - const MPSDataType mpsDataType = dataType != MPSDataTypeInvalid ? dataType : - _tensor.dim() == 0 ? getMPSScalarType(_tensor.scalar_type()) : getMPSDataType(_tensor.scalar_type()); + const MPSDataType mpsDataType = dataType != MPSDataTypeInvalid ? dataType + : _tensor.dim() == 0 ? getMPSScalarType(_tensor.scalar_type()) + : getMPSDataType(_tensor.scalar_type()); if (src.is_contiguous() && src.storage_offset() && sliceViewTensor) { _value = getMPSGraphTensorDataForView(src, mpsShape, mpsDataType); @@ -295,34 +306,25 @@ void printTensorNDArray(const Tensor& t) { mpsShape = getMPSShape(_tensor); } - _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf - shape:mpsShape - dataType:mpsDataType] autorelease]; + _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf shape:mpsShape dataType:mpsDataType] autorelease]; } TORCH_INTERNAL_ASSERT(_value); _placeholder = mpsGraphTensor; } -MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph, - MPSStream* mpsStream, - const Tensor& tensor) { +MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor) { auto mpsShape = getMPSShape(tensor); auto dataType = getMPSDataType(tensor.scalar_type()); - MPSGraphTensorData *result = nil; + MPSGraphTensorData* result = nil; if (tensor.numel() > 0) { id buf = getMTLBufferStorage(tensor); - result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buf - shape:mpsShape - dataType:dataType] - autorelease]; + result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buf shape:mpsShape dataType:dataType] autorelease]; } else { // create empty NDArray - MPSNDArrayDescriptor *desc = [MPSNDArrayDescriptor descriptorWithDataType:dataType - shape:mpsShape]; - MPSNDArray *emptyArray = [[[MPSNDArray alloc] - initWithDevice:mpsStream->device() descriptor:desc] autorelease]; + MPSNDArrayDescriptor* desc = [MPSNDArrayDescriptor descriptorWithDataType:dataType shape:mpsShape]; + MPSNDArray* emptyArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device() descriptor:desc] autorelease]; result = [[[MPSGraphTensorData alloc] initWithMPSNDArray:emptyArray] autorelease]; } assert(result); @@ -332,30 +334,40 @@ void printTensorNDArray(const Tensor& t) { MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type) { switch (type) { case ScalarType::Double: - case ScalarType::Float: return {.value.f = scalar.to() , .size = sizeof(float) , .type = type}; - case ScalarType::Half: return {.value.h = scalar.to(), .size = sizeof(short) , .type = type}; - case ScalarType::Long: return {.value.i = scalar.to() , .size = sizeof(int64_t), .type = type}; - case ScalarType::Int: return {.value.i = scalar.to() , .size = sizeof(int32_t), .type = type}; - case ScalarType::Short: return {.value.i = scalar.to() , .size = sizeof(int16_t), .type = type}; - case ScalarType::Char: return {.value.i = scalar.to() , .size = sizeof(int8_t) , .type = type}; - case ScalarType::Byte: return {.value.i = scalar.to() , .size = sizeof(uint8_t), .type = type}; - case ScalarType::Bool: return {.value.b = scalar.to() , .size = sizeof(bool) , .type = type}; + case ScalarType::Float: + return {.value.f = scalar.to(), .size = sizeof(float), .type = type}; + case ScalarType::Half: + return {.value.h = scalar.to(), .size = sizeof(short), .type = type}; + case ScalarType::Long: + return {.value.i = scalar.to(), .size = sizeof(int64_t), .type = type}; + case ScalarType::Int: + return {.value.i = scalar.to(), .size = sizeof(int32_t), .type = type}; + case ScalarType::Short: + return {.value.i = scalar.to(), .size = sizeof(int16_t), .type = type}; + case ScalarType::Char: + return {.value.i = scalar.to(), .size = sizeof(int8_t), .type = type}; + case ScalarType::Byte: + return {.value.i = scalar.to(), .size = sizeof(uint8_t), .type = type}; + case ScalarType::Bool: + return {.value.b = scalar.to(), .size = sizeof(bool), .type = type}; default: TORCH_INTERNAL_ASSERT(false, "Unsupported scalar type '", type, "' on MPS backend."); } } MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar) { - MPSGraphTensorData *result = nullptr; + MPSGraphTensorData* result = nullptr; // Scalar pools are only supported on devices with unified memory if (mpsStream->device().hasUnifiedMemory) { scalar.buffer = getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size); - result = [[[MPSGraphTensorData alloc] initWithMTLBuffer: scalar.getMTLBuffer() - shape: @[@1] - dataType: getMPSScalarType(scalar.type)] autorelease]; + result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:scalar.getMTLBuffer() + shape:@[ @1 ] + dataType:getMPSScalarType(scalar.type)] autorelease]; } else { - MPSNDArrayDescriptor *tensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:getMPSScalarType(scalar.type) shape:@[@1]]; - MPSNDArray *tensorNDArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device() descriptor:tensorDesc] autorelease]; + MPSNDArrayDescriptor* tensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:getMPSScalarType(scalar.type) + shape:@[ @1 ]]; + MPSNDArray* tensorNDArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device() + descriptor:tensorDesc] autorelease]; [tensorNDArray writeBytes:&scalar.value strideBytes:nil]; result = [[[MPSGraphTensorData alloc] initWithMPSNDArray:tensorNDArray] autorelease]; } @@ -371,58 +383,50 @@ void resize_tensor(Tensor* output) { return mpsGraph; } -MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType) { - return [mpsGraph placeholderWithShape:nil - dataType:dataType - name:nil]; +MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType) { + return [mpsGraph placeholderWithShape:nil dataType:dataType name:nil]; } -MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape) { - return [mpsGraph placeholderWithShape:mpsShape - dataType:dataType - name:nil]; +MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape) { + return [mpsGraph placeholderWithShape:mpsShape dataType:dataType name:nil]; } -MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor) { - return [mpsGraph placeholderWithShape:getMPSShape(tensor) - dataType:getMPSScalarType(tensor.scalar_type()) - name:nil]; +MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const Tensor& tensor) { + return [mpsGraph placeholderWithShape:getMPSShape(tensor) dataType:getMPSScalarType(tensor.scalar_type()) name:nil]; } -MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType) { - return [mpsGraph placeholderWithShape:@[@1] - dataType:dataType - name:nil]; +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType) { + return [mpsGraph placeholderWithShape:@[ @1 ] dataType:dataType name:nil]; } -MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar) { - return [mpsGraph placeholderWithShape:@[@1] - dataType:getMPSScalarType(scalar.type()) - name:nil]; +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, const Scalar& scalar) { + return [mpsGraph placeholderWithShape:@[ @1 ] dataType:getMPSScalarType(scalar.type()) name:nil]; } // this is meant to suppress the availability warning on castTensor // we pass ScalarType instead of MPSDataType to handle MPSDataTypeBoolean's availability too -MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) { +MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) { if ([tensor dataType] == toType) { return tensor; } return [mpsGraph castTensor:tensor toType:toType name:@"castTensor"]; } -MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType) { +MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, ScalarType toType) { return [mpsGraph castTensor:tensor toType:getMPSScalarType(toType) name:@"castTensor"]; } -MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor) { +MPSGraphTensor* convertNHWCtoNCHW(MPSGraph* mpsGraph, MPSGraphTensor* tensor) { TORCH_INTERNAL_ASSERT(tensor.shape.count == 4, "Tensor must have 4 dimensions!"); return [mpsGraph transposeTensor:[mpsGraph transposeTensor:tensor dimension:3 withDimension:2 name:nil] - dimension:2 withDimension:1 name: nil]; + dimension:2 + withDimension:1 + name:nil]; } string get_mem_format_string(c10::MemoryFormat memory_format) { string mem_format_key; - switch(memory_format) { + switch (memory_format) { case at::MemoryFormat::Contiguous: mem_format_key = "Contiguous"; break; @@ -439,11 +443,12 @@ string get_mem_format_string(c10::MemoryFormat memory_format) { MPSGraphCache* MPSGraphCache::_instance_cache = nullptr; class MPSGraphCacheCallback : public IMpsAllocatorCallback { -public: - MPSGraphCacheCallback() : graph_cache(MPSGraphCache::getInstance()) { } + public: + MPSGraphCacheCallback() : graph_cache(MPSGraphCache::getInstance()) {} + + void executeMPSAllocatorCallback(void* ptr, EventType event) override {} - void executeMPSAllocatorCallback(void* ptr, EventType event) override { } -private: + private: MPSGraphCache* graph_cache; }; diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm index bd0ff1a99dc018..d18f0b87a12c8e 100644 --- a/aten/src/ATen/native/mps/operations/Activation.mm +++ b/aten/src/ATen/native/mps/operations/Activation.mm @@ -22,18 +22,17 @@ Tensor relu_mps(const Tensor& self) { MPSStream* stream = getCurrentMPSStream(); - bool executeGatherOp = !(self.is_contiguous(MemoryFormat::Contiguous) || - self.is_contiguous(MemoryFormat::ChannelsLast) || - self.is_contiguous(MemoryFormat::ChannelsLast3d)); + bool executeGatherOp = + !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) || + self.is_contiguous(MemoryFormat::ChannelsLast3d)); Tensor output = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve); @autoreleasepool { string key = "relu" + getTensorsStringKey({self}); CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -41,44 +40,40 @@ Tensor relu_mps(const Tensor& self) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); // passing selector of reLUWithTensor on the mpsGraph object - MPSGraphTensor* outputTensor = [mpsGraph reLUWithTensor:inputTensor - name:nil]; + MPSGraphTensor* outputTensor = [mpsGraph reLUWithTensor:inputTensor name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, nil, false); // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } return output; } -Tensor & relu_mps_(Tensor & self) { +Tensor& relu_mps_(Tensor& self) { using namespace mps; using CachedGraph = MPSUnaryCachedGraph; // Inplace relu - Tensor &output = self; - bool executeGatherOp = !(self.is_contiguous(MemoryFormat::Contiguous) || - self.is_contiguous(MemoryFormat::ChannelsLast) || - self.is_contiguous(MemoryFormat::ChannelsLast3d)); + Tensor& output = self; + bool executeGatherOp = + !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) || + self.is_contiguous(MemoryFormat::ChannelsLast3d)); Tensor out; if (executeGatherOp) { out = at::empty_like(self, MemoryFormat::Contiguous); @@ -91,10 +86,9 @@ Tensor relu_mps(const Tensor& self) { @autoreleasepool { string key = "relu_" + getTensorsStringKey({self}); CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -102,28 +96,25 @@ Tensor relu_mps(const Tensor& self) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); // passing selector of reLUWithTensor on the mpsGraph object - MPSGraphTensor* outputTensor = [mpsGraph reLUWithTensor:inputTensor - name:nil]; + MPSGraphTensor* outputTensor = [mpsGraph reLUWithTensor:inputTensor name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, executeGatherOp ? out : output, nil, false); // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); if (executeGatherOp) { @@ -134,25 +125,22 @@ Tensor relu_mps(const Tensor& self) { return output; } -TORCH_IMPL_FUNC(leaky_relu_out_mps) ( - const Tensor& self, const Scalar& negative_slope, const Tensor& output) { +TORCH_IMPL_FUNC(leaky_relu_out_mps)(const Tensor& self, const Scalar& negative_slope, const Tensor& output) { using namespace mps; using CachedGraph = MPSUnaryCachedGraph; TORCH_CHECK(output.is_mps()); MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - MPSStream *stream = getCurrentMPSStream(); + MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + to_string(negative_slope.to()); CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -161,7 +149,7 @@ Tensor relu_mps(const Tensor& self) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); MPSGraphTensor* negSlopeTensor = [mpsGraph constantWithScalar:negative_slope.to() - shape:@[@1] + shape:@[ @1 ] dataType:getMPSDataType(self)]; MPSGraphTensor* negSlopeMulXTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor secondaryTensor:negSlopeTensor @@ -182,33 +170,27 @@ Tensor relu_mps(const Tensor& self) { Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } - } -TORCH_IMPL_FUNC(leaky_relu_backward_out_mps) ( - const Tensor& grad_output, - const Tensor& self, - const Scalar& negative_slope, - bool self_is_result, - const Tensor& output ) { - +TORCH_IMPL_FUNC(leaky_relu_backward_out_mps) +(const Tensor& grad_output, + const Tensor& self, + const Scalar& negative_slope, + bool self_is_result, + const Tensor& output) { using namespace mps; TORCH_CHECK(output.is_mps()); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* gradOutputTensor_ = nil; MPSGraphTensor* gradInputTensor_ = nil; @@ -216,17 +198,16 @@ Tensor relu_mps(const Tensor& self) { MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - MPSStream *stream = getCurrentMPSStream(); + MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { + string key = + "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + to_string(negative_slope.to()); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - string key = "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + to_string(negative_slope.to()); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -236,11 +217,9 @@ Tensor relu_mps(const Tensor& self) { MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); MPSGraphTensor* negSlopeTensor = [mpsGraph constantWithScalar:negative_slope.to() - shape:@[@1] + shape:@[ @1 ] dataType:getMPSScalarType(self)]; - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f - shape:@[@1] - dataType:getMPSScalarType(self)]; + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f shape:@[ @1 ] dataType:getMPSScalarType(self)]; MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor secondaryTensor:zeroTensor name:nil]; @@ -258,7 +237,7 @@ Tensor relu_mps(const Tensor& self) { } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); @@ -271,19 +250,15 @@ Tensor relu_mps(const Tensor& self) { selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } } -TORCH_IMPL_FUNC(log_softmax_mps_out) ( - const Tensor &self, - const int64_t dim, - const bool half_to_float, - const Tensor &out) { +TORCH_IMPL_FUNC(log_softmax_mps_out) +(const Tensor& self, const int64_t dim, const bool half_to_float, const Tensor& out) { using namespace mps; using CachedGraph = MPSUnaryCachedGraph; @@ -299,9 +274,8 @@ Tensor relu_mps(const Tensor& self) { string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + to_string(dim); CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { CachedGraph* newCachedGraph = nil; @autoreleasepool { @@ -310,66 +284,52 @@ Tensor relu_mps(const Tensor& self) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* maximumsTensor = [mpsGraph reductionMaximumWithTensor:inputTensor - axis:dim - name:nil]; + MPSGraphTensor* maximumsTensor = [mpsGraph reductionMaximumWithTensor:inputTensor axis:dim name:nil]; MPSGraphTensor* inputTensorSubMax = [mpsGraph subtractionWithPrimaryTensor:inputTensor secondaryTensor:maximumsTensor name:nil]; - MPSGraphTensor* exponentTensor = [mpsGraph exponentWithTensor:inputTensorSubMax - name:nil]; + MPSGraphTensor* exponentTensor = [mpsGraph exponentWithTensor:inputTensorSubMax name:nil]; - MPSGraphTensor* exponentTensorReduced = [mpsGraph reductionSumWithTensor:exponentTensor - axis:dim - name:nil]; + MPSGraphTensor* exponentTensorReduced = [mpsGraph reductionSumWithTensor:exponentTensor axis:dim name:nil]; - MPSGraphTensor* logSumExpTensor = [mpsGraph logarithmWithTensor:exponentTensorReduced - name:nil]; + MPSGraphTensor* logSumExpTensor = [mpsGraph logarithmWithTensor:exponentTensorReduced name:nil]; MPSGraphTensor* outputTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensorSubMax - secondaryTensor:logSumExpTensor - name:nil]; + secondaryTensor:logSumExpTensor + name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out); // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } - } -TORCH_IMPL_FUNC(log_softmax_backward_mps_out) ( - const Tensor& grad_output, - const Tensor& output, - int64_t dim, - ScalarType input_dtype, - const Tensor& out) { +TORCH_IMPL_FUNC(log_softmax_backward_mps_out) +(const Tensor& grad_output, const Tensor& output, int64_t dim, ScalarType input_dtype, const Tensor& out) { using namespace mps; if (output.numel() == 0) { return; } - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* outputTensor_ = nil; MPSGraphTensor* gradOutputTensor_ = nil; MPSGraphTensor* gradInputTensor_ = nil; @@ -381,11 +341,10 @@ Tensor relu_mps(const Tensor& self) { @autoreleasepool { string key = "log_softmax_backward_mps_out:" + getMPSTypeString(grad_output) + ":" + to_string(dim); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { CachedGraph* newCachedGraph = nil; @autoreleasepool { @@ -395,11 +354,8 @@ Tensor relu_mps(const Tensor& self) { MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output)); MPSGraphTensor* outputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(output)); - MPSGraphTensor* expTensor = [mpsGraph exponentWithTensor:outputTensor - name:nil]; - MPSGraphTensor* sumTensor = [mpsGraph reductionSumWithTensor:gradOutputTensor - axis:dim - name:nil]; + MPSGraphTensor* expTensor = [mpsGraph exponentWithTensor:outputTensor name:nil]; + MPSGraphTensor* sumTensor = [mpsGraph reductionSumWithTensor:gradOutputTensor axis:dim name:nil]; MPSGraphTensor* multiplicationTensor = [mpsGraph multiplicationWithPrimaryTensor:expTensor secondaryTensor:sumTensor name:nil]; @@ -413,10 +369,10 @@ Tensor relu_mps(const Tensor& self) { } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } - Placeholder gradPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); + Placeholder gradPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); Placeholder resultPlaceholder = Placeholder(cachedGraph->gradInputTensor_, out); @@ -426,13 +382,11 @@ Tensor relu_mps(const Tensor& self) { outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } - } std::tuple log_sigmoid_forward_out_mps(const Tensor& self, Tensor& output, Tensor& buffer) { @@ -450,62 +404,52 @@ Tensor relu_mps(const Tensor& self) { MPSStream* stream = getCurrentMPSStream(); - bool executeGatherOp = !(self.is_contiguous(MemoryFormat::Contiguous) || - self.is_contiguous(MemoryFormat::ChannelsLast) || - self.is_contiguous(MemoryFormat::ChannelsLast3d)); + bool executeGatherOp = + !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) || + self.is_contiguous(MemoryFormat::ChannelsLast3d)); Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve); @autoreleasepool { - string key = "log_sigmoid_forward_out:" + getTensorsStringKey({self}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 - shape:@[@1] - dataType:inputTensor.dataType]; + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:inputTensor.dataType]; MPSGraphTensor* minTensor = [mpsGraph minimumWithPrimaryTensor:inputTensor secondaryTensor:zeroTensor name:nil]; - MPSGraphTensor* absInputTensor = [mpsGraph absoluteWithTensor:inputTensor - name:nil]; - MPSGraphTensor* negAbsInputTensor = [mpsGraph negativeWithTensor:absInputTensor - name:nil]; - MPSGraphTensor* expNegAbsInputTensor = [mpsGraph exponentWithTensor:negAbsInputTensor - name:nil]; + MPSGraphTensor* absInputTensor = [mpsGraph absoluteWithTensor:inputTensor name:nil]; + MPSGraphTensor* negAbsInputTensor = [mpsGraph negativeWithTensor:absInputTensor name:nil]; + MPSGraphTensor* expNegAbsInputTensor = [mpsGraph exponentWithTensor:negAbsInputTensor name:nil]; MPSGraphTensor* outputTensor = at::native::mps::log1p(mpsGraph, expNegAbsInputTensor); - outputTensor = [mpsGraph subtractionWithPrimaryTensor:minTensor - secondaryTensor:outputTensor - name:nil]; + outputTensor = [mpsGraph subtractionWithPrimaryTensor:minTensor secondaryTensor:outputTensor name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, executeGatherOp ? output_ : output, nil, false); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); + Placeholder outputPlaceholder = + Placeholder(cachedGraph->outputTensor_, executeGatherOp ? output_ : output, nil, false); // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } @@ -524,9 +468,9 @@ Tensor relu_mps(const Tensor& self) { } Tensor& log_sigmoid_backward_mps_out(const Tensor& grad_output, - const Tensor& self, - const Tensor& buffer, - Tensor& grad_input) { + const Tensor& self, + const Tensor& buffer, + Tensor& grad_input) { // NOTE: buffer is only used by CPU dispatch, we just ignore it here using namespace mps; @@ -536,9 +480,8 @@ Tensor relu_mps(const Tensor& self) { grad_input.resize_as_(self); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* gradOutputTensor_ = nil; MPSGraphTensor* gradInputTensor_ = nil; @@ -548,20 +491,18 @@ Tensor relu_mps(const Tensor& self) { MPSStream* stream = getCurrentMPSStream(); - bool executeGatherOp = !(self.is_contiguous(MemoryFormat::Contiguous) || - self.is_contiguous(MemoryFormat::ChannelsLast) || - self.is_contiguous(MemoryFormat::ChannelsLast3d)); + bool executeGatherOp = + !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) || + self.is_contiguous(MemoryFormat::ChannelsLast3d)); Tensor grad_input_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve); @autoreleasepool { - string key = "log_sigmoid_backward_out:" + getTensorsStringKey({self, grad_output}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -569,15 +510,9 @@ Tensor relu_mps(const Tensor& self) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 - shape:@[@1] - dataType:inputTensor.dataType]; - MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 - shape:@[@1] - dataType:inputTensor.dataType]; - MPSGraphTensor* negOneTensor = [mpsGraph constantWithScalar:-1.0 - shape:@[@1] - dataType:inputTensor.dataType]; + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:inputTensor.dataType]; + MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:inputTensor.dataType]; + MPSGraphTensor* negOneTensor = [mpsGraph constantWithScalar:-1.0 shape:@[ @1 ] dataType:inputTensor.dataType]; MPSGraphTensor* inputNegPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor secondaryTensor:zeroTensor name:nil]; @@ -589,21 +524,16 @@ Tensor relu_mps(const Tensor& self) { truePredicateTensor:oneTensor falsePredicateTensor:negOneTensor name:nil]; - MPSGraphTensor* absInputTensor = [mpsGraph absoluteWithTensor:inputTensor - name:nil]; - MPSGraphTensor* negAbsInputTensor = [mpsGraph negativeWithTensor:absInputTensor - name:nil]; - MPSGraphTensor* expNegAbsInputTensor = [mpsGraph exponentWithTensor:negAbsInputTensor - name:nil]; + MPSGraphTensor* absInputTensor = [mpsGraph absoluteWithTensor:inputTensor name:nil]; + MPSGraphTensor* negAbsInputTensor = [mpsGraph negativeWithTensor:absInputTensor name:nil]; + MPSGraphTensor* expNegAbsInputTensor = [mpsGraph exponentWithTensor:negAbsInputTensor name:nil]; MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor:expNegAbsInputTensor secondaryTensor:oneTensor name:nil]; outputTensor = [mpsGraph divisionWithPrimaryTensor:expNegAbsInputTensor secondaryTensor:outputTensor name:nil]; - outputTensor = [mpsGraph multiplicationWithPrimaryTensor:signTensor - secondaryTensor:outputTensor - name:nil]; + outputTensor = [mpsGraph multiplicationWithPrimaryTensor:signTensor secondaryTensor:outputTensor name:nil]; outputTensor = [mpsGraph subtractionWithPrimaryTensor:maxDerivativeTensor secondaryTensor:outputTensor name:nil]; @@ -617,12 +547,13 @@ Tensor relu_mps(const Tensor& self) { } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); - Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, executeGatherOp ? grad_input_ : grad_input, nil, false); + Placeholder outputPlaceholder = + Placeholder(cachedGraph->gradInputTensor_, executeGatherOp ? grad_input_ : grad_input, nil, false); // Create dictionary of inputs and outputs NSDictionary* feeds = @{ @@ -630,9 +561,8 @@ Tensor relu_mps(const Tensor& self) { selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } @@ -649,22 +579,18 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c return grad_input; } -TORCH_IMPL_FUNC(sigmoid_backward_out_mps)( - const Tensor& grad_output, - const Tensor& output, - const Tensor& grad_input) { +TORCH_IMPL_FUNC(sigmoid_backward_out_mps)(const Tensor& grad_output, const Tensor& output, const Tensor& grad_input) { using namespace mps; TORCH_CHECK(grad_input.is_mps()); if (grad_output.numel() == 0) { return; } - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - MPSGraphTensor *gradInputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + MPSGraphTensor* gradInputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -673,11 +599,10 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c @autoreleasepool { string key = "sigmoid_backward_out_mps:" + getMPSTypeString(grad_output); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -687,14 +612,14 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c MPSGraphTensor* outputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(output)); MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 - shape:@[@1] + shape:@[ @1 ] dataType:getMPSDataType(grad_output)]; MPSGraphTensor* oneMinusSigmoidTensor = [mpsGraph subtractionWithPrimaryTensor:unitTensor secondaryTensor:outputTensor name:nil]; MPSGraphTensor* timesTensor = [mpsGraph multiplicationWithPrimaryTensor:oneMinusSigmoidTensor - secondaryTensor:outputTensor - name:nil]; + secondaryTensor:outputTensor + name:nil]; MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:gradOutputTensor secondaryTensor:timesTensor name:nil]; @@ -705,12 +630,12 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); + Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input); + Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input); // Create dictionary of inputs and outputs NSDictionary* feeds = @{ @@ -718,32 +643,25 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } - } -TORCH_IMPL_FUNC(tanh_backward_out_mps)( - const Tensor& grad_output, - const Tensor& output, - const Tensor& grad_input) { +TORCH_IMPL_FUNC(tanh_backward_out_mps)(const Tensor& grad_output, const Tensor& output, const Tensor& grad_input) { using namespace mps; TORCH_CHECK(grad_input.is_mps()); if (grad_output.numel() == 0) { return; } - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - MPSGraphTensor *gradInputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + MPSGraphTensor* gradInputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -752,11 +670,10 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c @autoreleasepool { string key = "tanh_backward_out_mps:" + getMPSTypeString(grad_output); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -766,10 +683,9 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c MPSGraphTensor* outputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(output)); MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 - shape:@[@1] + shape:@[ @1 ] dataType:getMPSDataType(grad_output)]; - MPSGraphTensor* tanh2Tensor = [mpsGraph squareWithTensor:outputTensor - name:nil]; + MPSGraphTensor* tanh2Tensor = [mpsGraph squareWithTensor:outputTensor name:nil]; MPSGraphTensor* oneMinusTanh2Tensor = [mpsGraph subtractionWithPrimaryTensor:unitTensor secondaryTensor:tanh2Tensor name:nil]; @@ -783,12 +699,12 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); + Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input); + Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input); // Create dictionary of inputs and outputs NSDictionary* feeds = @{ @@ -796,20 +712,15 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } } -TORCH_IMPL_FUNC(threshold_out_mps)( - const Tensor& self, - const Scalar& threshold, - const Scalar& value, - const Tensor& result) { +TORCH_IMPL_FUNC(threshold_out_mps) +(const Tensor& self, const Scalar& threshold, const Scalar& value, const Tensor& result) { using namespace mps; using CachedGraph = MPSUnaryCachedGraph; TORCH_CHECK(self.is_mps()); @@ -819,80 +730,72 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + - to_string(threshold.to()) + ":" + - to_string(value.to()); + string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + to_string(threshold.to()) + ":" + + to_string(value.to()); CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor *thresholdTensor = [mpsGraph constantWithScalar: threshold.to() - shape: @[@1] - dataType: getMPSDataType(self)]; + MPSGraphTensor* thresholdTensor = [mpsGraph constantWithScalar:threshold.to() + shape:@[ @1 ] + dataType:getMPSDataType(self)]; - MPSGraphTensor *valueTensor = [mpsGraph constantWithScalar: value.to() - shape: @[@1] - dataType: getMPSDataType(self)]; + MPSGraphTensor* valueTensor = [mpsGraph constantWithScalar:value.to() + shape:@[ @1 ] + dataType:getMPSDataType(self)]; // x > threshold - MPSGraphTensor *predicateTensor = [mpsGraph greaterThanWithPrimaryTensor: inputTensor - secondaryTensor: thresholdTensor - name: nil]; + MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor + secondaryTensor:thresholdTensor + name:nil]; // result = (self > threshold) ? self : value - MPSGraphTensor *outputTensor = [mpsGraph selectWithPredicateTensor: predicateTensor - truePredicateTensor: inputTensor - falsePredicateTensor: valueTensor - name: nil]; + MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:predicateTensor + truePredicateTensor:inputTensor + falsePredicateTensor:valueTensor + name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } } -TORCH_IMPL_FUNC(threshold_backward_out_mps)( - const Tensor& grad, - const Tensor& self, - const Scalar& threshold, - const Tensor& gradInput) { +TORCH_IMPL_FUNC(threshold_backward_out_mps) +(const Tensor& grad, const Tensor& self, const Scalar& threshold, const Tensor& gradInput) { using namespace mps; TORCH_CHECK(self.is_mps()); TORCH_CHECK(grad.is_mps()); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *gradTensor_ = nil; - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *gradInputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradTensor_ = nil; + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* gradInputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -900,39 +803,37 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + - to_string(threshold.to()); + string key = + "threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + to_string(threshold.to()); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor *gradTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* gradTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad); - MPSGraphTensor *thresholdTensor = [mpsGraph constantWithScalar: threshold.to() - shape: @[@1] - dataType: getMPSDataType(self)]; + MPSGraphTensor* thresholdTensor = [mpsGraph constantWithScalar:threshold.to() + shape:@[ @1 ] + dataType:getMPSDataType(self)]; - MPSGraphTensor *zeroTensor = [mpsGraph constantWithScalar: 0.0 - dataType: inputTensor.dataType]; + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType]; // x > threshold - MPSGraphTensor *predicateTensor = [mpsGraph greaterThanWithPrimaryTensor: inputTensor - secondaryTensor: thresholdTensor - name: nil]; + MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor + secondaryTensor:thresholdTensor + name:nil]; // result = (self > threshold) ? grad : zeroTensor - MPSGraphTensor *gradInputTensor = [mpsGraph selectWithPredicateTensor: predicateTensor - truePredicateTensor: gradTensor - falsePredicateTensor: zeroTensor - name: nil]; + MPSGraphTensor* gradInputTensor = [mpsGraph selectWithPredicateTensor:predicateTensor + truePredicateTensor:gradTensor + falsePredicateTensor:zeroTensor + name:nil]; newCachedGraph->gradTensor_ = gradTensor; newCachedGraph->inputTensor_ = inputTensor; @@ -940,7 +841,7 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); @@ -953,102 +854,65 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } } -MPSGraphTensor* normcdf (MPSGraph* mpsGraph, MPSGraphTensor *inputTensor) { - // (1.0f + erf(x*SQRT1_2)) * 0.5f * x; - auto dataType = [inputTensor dataType]; - const float SQRT1_2 = 0.707106781186547524400844362104849039f; - MPSGraphTensor *sqrt1_2 = [mpsGraph constantWithScalar: SQRT1_2 - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *onef = [mpsGraph constantWithScalar: 1.0f - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *halff = [mpsGraph constantWithScalar: 0.5f - shape: @[@1] - dataType: dataType]; - - MPSGraphTensor *erfTensor = [mpsGraph multiplicationWithPrimaryTensor: inputTensor - secondaryTensor: sqrt1_2 - name : nil]; - erfTensor = [mpsGraph erfWithTensor: erfTensor name : nil]; - erfTensor = [mpsGraph additionWithPrimaryTensor: erfTensor - secondaryTensor: onef - name : nil]; - erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor - secondaryTensor: halff - name : nil]; - - return erfTensor; +MPSGraphTensor* normcdf(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + // (1.0f + erf(x*SQRT1_2)) * 0.5f * x; + auto dataType = [inputTensor dataType]; + const float SQRT1_2 = 0.707106781186547524400844362104849039f; + MPSGraphTensor* sqrt1_2 = [mpsGraph constantWithScalar:SQRT1_2 shape:@[ @1 ] dataType:dataType]; + MPSGraphTensor* onef = [mpsGraph constantWithScalar:1.0f shape:@[ @1 ] dataType:dataType]; + MPSGraphTensor* halff = [mpsGraph constantWithScalar:0.5f shape:@[ @1 ] dataType:dataType]; + + MPSGraphTensor* erfTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor secondaryTensor:sqrt1_2 name:nil]; + erfTensor = [mpsGraph erfWithTensor:erfTensor name:nil]; + erfTensor = [mpsGraph additionWithPrimaryTensor:erfTensor secondaryTensor:onef name:nil]; + erfTensor = [mpsGraph multiplicationWithPrimaryTensor:erfTensor secondaryTensor:halff name:nil]; + + return erfTensor; } -MPSGraphTensor* tanh (MPSGraph* mpsGraph, MPSGraphTensor *inputTensor) { - // 0.5 * x * (1 + text{Tanh}(sqrt(2 / pi) * (x + 0.044715 * x^3))) - auto dataType = [inputTensor dataType]; - constexpr float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5; - constexpr float kKappa = 0.044715f; - MPSGraphTensor *betaf = [mpsGraph constantWithScalar: kBeta - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *kappaf = [mpsGraph constantWithScalar: kKappa - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *onef = [mpsGraph constantWithScalar: 1.0f - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *halff = [mpsGraph constantWithScalar: 0.5f - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *erfTensor = [mpsGraph multiplicationWithPrimaryTensor: inputTensor - secondaryTensor: inputTensor - name : nil]; - erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor - secondaryTensor: inputTensor - name : nil]; - erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor - secondaryTensor: kappaf - name : nil]; - erfTensor = [mpsGraph additionWithPrimaryTensor: erfTensor - secondaryTensor: inputTensor - name : nil]; - erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor - secondaryTensor: betaf - name : nil]; - erfTensor = [mpsGraph tanhWithTensor: erfTensor - name : nil]; - erfTensor = [mpsGraph additionWithPrimaryTensor: erfTensor - secondaryTensor: onef - name : nil]; - erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor - secondaryTensor: halff - name : nil]; - - return erfTensor; +MPSGraphTensor* tanh(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + // 0.5 * x * (1 + text{Tanh}(sqrt(2 / pi) * (x + 0.044715 * x^3))) + auto dataType = [inputTensor dataType]; + constexpr float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5; + constexpr float kKappa = 0.044715f; + MPSGraphTensor* betaf = [mpsGraph constantWithScalar:kBeta shape:@[ @1 ] dataType:dataType]; + MPSGraphTensor* kappaf = [mpsGraph constantWithScalar:kKappa shape:@[ @1 ] dataType:dataType]; + MPSGraphTensor* onef = [mpsGraph constantWithScalar:1.0f shape:@[ @1 ] dataType:dataType]; + MPSGraphTensor* halff = [mpsGraph constantWithScalar:0.5f shape:@[ @1 ] dataType:dataType]; + MPSGraphTensor* erfTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor + secondaryTensor:inputTensor + name:nil]; + erfTensor = [mpsGraph multiplicationWithPrimaryTensor:erfTensor secondaryTensor:inputTensor name:nil]; + erfTensor = [mpsGraph multiplicationWithPrimaryTensor:erfTensor secondaryTensor:kappaf name:nil]; + erfTensor = [mpsGraph additionWithPrimaryTensor:erfTensor secondaryTensor:inputTensor name:nil]; + erfTensor = [mpsGraph multiplicationWithPrimaryTensor:erfTensor secondaryTensor:betaf name:nil]; + erfTensor = [mpsGraph tanhWithTensor:erfTensor name:nil]; + erfTensor = [mpsGraph additionWithPrimaryTensor:erfTensor secondaryTensor:onef name:nil]; + erfTensor = [mpsGraph multiplicationWithPrimaryTensor:erfTensor secondaryTensor:halff name:nil]; + + return erfTensor; } -TORCH_IMPL_FUNC(gelu_out_mps) ( - const Tensor& self, c10::string_view approximate, const Tensor& output - ) { +TORCH_IMPL_FUNC(gelu_out_mps)(const Tensor& self, c10::string_view approximate, const Tensor& output) { using namespace mps; TORCH_CHECK(output.is_mps()); TORCH_CHECK(c10::isFloatingType(self.scalar_type()), "GELU is only implemented for floating types"); // Empty output - if(output.numel() == 0) + if (output.numel() == 0) return; - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -1057,69 +921,58 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c @autoreleasepool { string key = "gelu_out_mps" + getTensorsStringKey({self}) + ":" + c10::str(approximate); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, - getMPSDataType(self), - getMPSShape(self)); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self)); MPSGraphTensor* outputTensor = nil; - if(approximate == "tanh") { + if (approximate == "tanh") { outputTensor = tanh(mpsGraph, inputTensor); } else { outputTensor = normcdf(mpsGraph, inputTensor); } - outputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor - secondaryTensor:inputTensor - name:nil]; + outputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor secondaryTensor:inputTensor name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } - } -TORCH_IMPL_FUNC(gelu_backward_out_mps) ( - const Tensor& grad, const Tensor& self, c10::string_view approximate, const Tensor& grad_input - ) { +TORCH_IMPL_FUNC(gelu_backward_out_mps) +(const Tensor& grad, const Tensor& self, c10::string_view approximate, const Tensor& grad_input) { using namespace mps; // Empty output - if(grad_input.numel() == 0) + if (grad_input.numel() == 0) return; - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *gradTensor_ = nil; - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradTensor_ = nil; + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -1128,126 +981,80 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c @autoreleasepool { string key = "gelu_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + c10::str(approximate); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { auto dataType = getMPSDataType(self); MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* gradTensor = mpsGraphRankedPlaceHolder(mpsGraph, - getMPSDataType(grad), - getMPSShape(grad)); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, - dataType, - getMPSShape(self)); + MPSGraphTensor* gradTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad), getMPSShape(grad)); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(self)); MPSGraphTensor* outputTensor = nil; - if(approximate == "tanh") { + if (approximate == "tanh") { constexpr float kBeta = M_SQRT2 * M_2_SQRTPI * (0.5f); constexpr float kKappa = 0.044715f; - MPSGraphTensor *betaf = [mpsGraph constantWithScalar: kBeta - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *kappaf = [mpsGraph constantWithScalar: kKappa - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *halff = [mpsGraph constantWithScalar: 0.5f - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *onef = [mpsGraph constantWithScalar: 1.0f - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *threef = [mpsGraph constantWithScalar: 3.0f - shape: @[@1] - dataType: dataType]; - MPSGraphTensor* x_sq = [mpsGraph multiplicationWithPrimaryTensor: inputTensor - secondaryTensor: inputTensor - name: nil]; - MPSGraphTensor *x_cube = [mpsGraph multiplicationWithPrimaryTensor: x_sq - secondaryTensor: inputTensor - name: nil]; - MPSGraphTensor *inner = [mpsGraph multiplicationWithPrimaryTensor: kappaf - secondaryTensor: x_cube - name: nil]; - inner = [mpsGraph additionWithPrimaryTensor: inner - secondaryTensor: inputTensor - name: nil]; - inner = [mpsGraph multiplicationWithPrimaryTensor: betaf - secondaryTensor: inner - name: nil]; - MPSGraphTensor *tanhInner = [mpsGraph tanhWithTensor: inner - name: nil]; - MPSGraphTensor *left = [mpsGraph multiplicationWithPrimaryTensor: halff - secondaryTensor: inputTensor - name: nil]; - MPSGraphTensor *right = [mpsGraph additionWithPrimaryTensor: onef - secondaryTensor: tanhInner - name: nil]; - MPSGraphTensor *left_derivative = [mpsGraph multiplicationWithPrimaryTensor: halff - secondaryTensor: right - name: nil]; - MPSGraphTensor *tanh_derivative = [mpsGraph multiplicationWithPrimaryTensor: tanhInner - secondaryTensor: tanhInner - name: nil]; - tanh_derivative = [mpsGraph subtractionWithPrimaryTensor: onef - secondaryTensor: tanh_derivative - name: nil]; - MPSGraphTensor *inner_derivative = [mpsGraph multiplicationWithPrimaryTensor: threef - secondaryTensor: kappaf - name: nil]; - inner_derivative = [mpsGraph multiplicationWithPrimaryTensor: inner_derivative - secondaryTensor: x_sq - name: nil]; - inner_derivative = [mpsGraph additionWithPrimaryTensor: inner_derivative - secondaryTensor: onef - name: nil]; - inner_derivative = [mpsGraph multiplicationWithPrimaryTensor: betaf - secondaryTensor: inner_derivative - name: nil]; - MPSGraphTensor *right_derivative = [mpsGraph multiplicationWithPrimaryTensor: left - secondaryTensor: tanh_derivative - name: nil]; - right_derivative = [mpsGraph multiplicationWithPrimaryTensor: right_derivative - secondaryTensor: inner_derivative - name: nil]; - outputTensor = [mpsGraph additionWithPrimaryTensor: left_derivative - secondaryTensor: right_derivative - name: nil]; - outputTensor = [mpsGraph multiplicationWithPrimaryTensor: gradTensor - secondaryTensor: outputTensor - name: nil]; + MPSGraphTensor* betaf = [mpsGraph constantWithScalar:kBeta shape:@[ @1 ] dataType:dataType]; + MPSGraphTensor* kappaf = [mpsGraph constantWithScalar:kKappa shape:@[ @1 ] dataType:dataType]; + MPSGraphTensor* halff = [mpsGraph constantWithScalar:0.5f shape:@[ @1 ] dataType:dataType]; + MPSGraphTensor* onef = [mpsGraph constantWithScalar:1.0f shape:@[ @1 ] dataType:dataType]; + MPSGraphTensor* threef = [mpsGraph constantWithScalar:3.0f shape:@[ @1 ] dataType:dataType]; + MPSGraphTensor* x_sq = [mpsGraph multiplicationWithPrimaryTensor:inputTensor + secondaryTensor:inputTensor + name:nil]; + MPSGraphTensor* x_cube = [mpsGraph multiplicationWithPrimaryTensor:x_sq + secondaryTensor:inputTensor + name:nil]; + MPSGraphTensor* inner = [mpsGraph multiplicationWithPrimaryTensor:kappaf secondaryTensor:x_cube name:nil]; + inner = [mpsGraph additionWithPrimaryTensor:inner secondaryTensor:inputTensor name:nil]; + inner = [mpsGraph multiplicationWithPrimaryTensor:betaf secondaryTensor:inner name:nil]; + MPSGraphTensor* tanhInner = [mpsGraph tanhWithTensor:inner name:nil]; + MPSGraphTensor* left = [mpsGraph multiplicationWithPrimaryTensor:halff + secondaryTensor:inputTensor + name:nil]; + MPSGraphTensor* right = [mpsGraph additionWithPrimaryTensor:onef secondaryTensor:tanhInner name:nil]; + MPSGraphTensor* left_derivative = [mpsGraph multiplicationWithPrimaryTensor:halff + secondaryTensor:right + name:nil]; + MPSGraphTensor* tanh_derivative = [mpsGraph multiplicationWithPrimaryTensor:tanhInner + secondaryTensor:tanhInner + name:nil]; + tanh_derivative = [mpsGraph subtractionWithPrimaryTensor:onef secondaryTensor:tanh_derivative name:nil]; + MPSGraphTensor* inner_derivative = [mpsGraph multiplicationWithPrimaryTensor:threef + secondaryTensor:kappaf + name:nil]; + inner_derivative = [mpsGraph multiplicationWithPrimaryTensor:inner_derivative + secondaryTensor:x_sq + name:nil]; + inner_derivative = [mpsGraph additionWithPrimaryTensor:inner_derivative secondaryTensor:onef name:nil]; + inner_derivative = [mpsGraph multiplicationWithPrimaryTensor:betaf + secondaryTensor:inner_derivative + name:nil]; + MPSGraphTensor* right_derivative = [mpsGraph multiplicationWithPrimaryTensor:left + secondaryTensor:tanh_derivative + name:nil]; + right_derivative = [mpsGraph multiplicationWithPrimaryTensor:right_derivative + secondaryTensor:inner_derivative + name:nil]; + outputTensor = [mpsGraph additionWithPrimaryTensor:left_derivative + secondaryTensor:right_derivative + name:nil]; + outputTensor = [mpsGraph multiplicationWithPrimaryTensor:gradTensor secondaryTensor:outputTensor name:nil]; } else { constexpr float kBeta = M_2_SQRTPI * M_SQRT1_2 * (0.5); - MPSGraphTensor *halff = [mpsGraph constantWithScalar: -0.5f - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *betaf = [mpsGraph constantWithScalar: kBeta - shape: @[@1] - dataType: dataType]; + MPSGraphTensor* halff = [mpsGraph constantWithScalar:-0.5f shape:@[ @1 ] dataType:dataType]; + MPSGraphTensor* betaf = [mpsGraph constantWithScalar:kBeta shape:@[ @1 ] dataType:dataType]; MPSGraphTensor* cdf = normcdf(mpsGraph, inputTensor); - MPSGraphTensor *pdfMul = [mpsGraph squareWithTensor: inputTensor - name: nil]; - pdfMul = [mpsGraph multiplicationWithPrimaryTensor: pdfMul - secondaryTensor: halff - name: nil]; - pdfMul = [mpsGraph exponentWithTensor: pdfMul - name: nil]; - MPSGraphTensor* pdf = [mpsGraph multiplicationWithPrimaryTensor: pdfMul - secondaryTensor: betaf - name: nil]; - pdf = [mpsGraph multiplicationWithPrimaryTensor: inputTensor - secondaryTensor: pdf - name: nil]; - pdf = [mpsGraph additionWithPrimaryTensor: pdf - secondaryTensor: cdf - name: nil]; - outputTensor = [mpsGraph multiplicationWithPrimaryTensor: gradTensor - secondaryTensor: pdf - name: nil]; + MPSGraphTensor* pdfMul = [mpsGraph squareWithTensor:inputTensor name:nil]; + pdfMul = [mpsGraph multiplicationWithPrimaryTensor:pdfMul secondaryTensor:halff name:nil]; + pdfMul = [mpsGraph exponentWithTensor:pdfMul name:nil]; + MPSGraphTensor* pdf = [mpsGraph multiplicationWithPrimaryTensor:pdfMul secondaryTensor:betaf name:nil]; + pdf = [mpsGraph multiplicationWithPrimaryTensor:inputTensor secondaryTensor:pdf name:nil]; + pdf = [mpsGraph additionWithPrimaryTensor:pdf secondaryTensor:cdf name:nil]; + outputTensor = [mpsGraph multiplicationWithPrimaryTensor:gradTensor secondaryTensor:pdf name:nil]; } newCachedGraph->gradTensor_ = gradTensor; @@ -1256,11 +1063,11 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } - Placeholder gradPlaceholder = Placeholder(cachedGraph->gradTensor_, grad); - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder gradPlaceholder = Placeholder(cachedGraph->gradTensor_, grad); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, grad_input); // Create dictionary of inputs and outputs @@ -1269,24 +1076,18 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } - - } -void elu_variants_out_mps ( - const Tensor& self, - const Scalar& alpha, - const Scalar& scale, - const Scalar& input_scale, - const Tensor& result, - string func_name) { - +void elu_variants_out_mps(const Tensor& self, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale, + const Tensor& result, + string func_name) { using namespace mps; auto resultMemFormat = result.suggest_memory_format(); bool executeGatherOp = !(self.is_contiguous(resultMemFormat) && result.is_contiguous(resultMemFormat)); @@ -1296,15 +1097,14 @@ void elu_variants_out_mps ( } // Empty output - if(result.numel() == 0) { + if (result.numel() == 0) { return; } - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -1312,48 +1112,40 @@ void elu_variants_out_mps ( MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = func_name + ":" + getTensorsStringKey({self}) + ":" + - to_string(alpha.to()) + ":" + - to_string(scale.to()) + ":" + - to_string(input_scale.to()); - - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + string key = func_name + ":" + getTensorsStringKey({self}) + ":" + to_string(alpha.to()) + ":" + + to_string(scale.to()) + ":" + to_string(input_scale.to()); - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); // scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1) )) MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.to() - shape:@[@1] + shape:@[ @1 ] dataType:getMPSDataType(self)]; MPSGraphTensor* inputScaleTensor = [mpsGraph constantWithScalar:input_scale.to() - shape:@[@1] + shape:@[ @1 ] dataType:getMPSDataType(self)]; MPSGraphTensor* scaleTensor = [mpsGraph constantWithScalar:scale.to() - shape:@[@1] + shape:@[ @1 ] dataType:getMPSDataType(self)]; - MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0f - shape:@[@1] - dataType:getMPSDataType(self)]; - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f - shape:@[@1] - dataType:getMPSDataType(self)]; + MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0f shape:@[ @1 ] dataType:getMPSDataType(self)]; + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f shape:@[ @1 ] dataType:getMPSDataType(self)]; MPSGraphTensor* scaledInputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor secondaryTensor:inputScaleTensor name:nil]; - MPSGraphTensor* exponentTensor = [mpsGraph exponentWithTensor:scaledInputTensor - name:nil]; + MPSGraphTensor* exponentTensor = [mpsGraph exponentWithTensor:scaledInputTensor name:nil]; MPSGraphTensor* exponentMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:exponentTensor secondaryTensor:unitTensor name:nil]; @@ -1376,20 +1168,19 @@ void elu_variants_out_mps ( } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out.has_storage() ? out : result, nil, false); + Placeholder outputPlaceholder = + Placeholder(cachedGraph->outputTensor_, out.has_storage() ? out : result, nil, false); // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); if (out.has_storage()) { @@ -1399,44 +1190,38 @@ void elu_variants_out_mps ( } // scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1) )) -TORCH_IMPL_FUNC(elu_out_mps) ( - const Tensor& self, - const Scalar& alpha, - const Scalar& scale, - const Scalar& input_scale, - const Tensor& result) { - +TORCH_IMPL_FUNC(elu_out_mps) +(const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, const Tensor& result) { elu_variants_out_mps(self, alpha, scale, input_scale, result, "elu_out_mps"); } -TORCH_IMPL_FUNC(elu_backward_out_mps) ( - const Tensor& grad_output, - const Scalar& alpha, - const Scalar& scale, - const Scalar& input_scale, - bool is_result, - const Tensor& self_or_result, - const Tensor& grad_input -) { +TORCH_IMPL_FUNC(elu_backward_out_mps) +(const Tensor& grad_output, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale, + bool is_result, + const Tensor& self_or_result, + const Tensor& grad_input) { using namespace mps; auto gradMemFormat = grad_input.suggest_memory_format(); - bool executeGatherOp = !(grad_output.is_contiguous(gradMemFormat) && self_or_result.is_contiguous(gradMemFormat) && grad_input.is_contiguous(gradMemFormat)); + bool executeGatherOp = !(grad_output.is_contiguous(gradMemFormat) && self_or_result.is_contiguous(gradMemFormat) && + grad_input.is_contiguous(gradMemFormat)); Tensor out; if (executeGatherOp && gradMemFormat == MemoryFormat::ChannelsLast) { out = at::empty_like(grad_input, MemoryFormat::Contiguous); } // Empty output - if(grad_input.numel() == 0) { + if (grad_input.numel() == 0) { return; } - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *selfOrResultTensor_ = nil; - MPSGraphTensor *gradInputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* selfOrResultTensor_ = nil; + MPSGraphTensor* gradInputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -1445,16 +1230,13 @@ void elu_variants_out_mps ( @autoreleasepool { string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" + - to_string(alpha.to()) + ":" + - to_string(scale.to()) + ":" + - to_string(input_scale.to()) + ":" + - to_string(is_result); - - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + to_string(alpha.to()) + ":" + to_string(scale.to()) + ":" + + to_string(input_scale.to()) + ":" + to_string(is_result); - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -1464,33 +1246,31 @@ void elu_variants_out_mps ( MPSGraphTensor* selfOrResultTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_or_result); MPSGraphTensor* lessThanZeroGradTensor = nil; - if(is_result) { + if (is_result) { MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.to() - shape:@[@1] - dataType:getMPSDataType(grad_output)]; + shape:@[ @1 ] + dataType:getMPSDataType(grad_output)]; MPSGraphTensor* resultPlusAlphaTensor = [mpsGraph additionWithPrimaryTensor:selfOrResultTensor secondaryTensor:alphaTensor name:nil]; auto constMul = scale.to() * input_scale.to(); MPSGraphTensor* constMulTensor = [mpsGraph constantWithScalar:constMul - shape:@[@1] + shape:@[ @1 ] dataType:getMPSDataType(grad_output)]; lessThanZeroGradTensor = [mpsGraph multiplicationWithPrimaryTensor:resultPlusAlphaTensor secondaryTensor:constMulTensor name:nil]; - } - else { + } else { MPSGraphTensor* inputScaleTensor = [mpsGraph constantWithScalar:input_scale.to() - shape:@[@1] - dataType:getMPSDataType(grad_output)]; + shape:@[ @1 ] + dataType:getMPSDataType(grad_output)]; MPSGraphTensor* scaledInputTensor = [mpsGraph multiplicationWithPrimaryTensor:selfOrResultTensor secondaryTensor:inputScaleTensor name:nil]; - MPSGraphTensor* expTensor = [mpsGraph exponentWithTensor:scaledInputTensor - name:nil]; + MPSGraphTensor* expTensor = [mpsGraph exponentWithTensor:scaledInputTensor name:nil]; auto constMul = scale.to() * input_scale.to() * alpha.to(); MPSGraphTensor* constMulTensor = [mpsGraph constantWithScalar:constMul - shape:@[@1] + shape:@[ @1 ] dataType:getMPSDataType(grad_output)]; lessThanZeroGradTensor = [mpsGraph multiplicationWithPrimaryTensor:expTensor secondaryTensor:constMulTensor @@ -1498,10 +1278,10 @@ void elu_variants_out_mps ( } MPSGraphTensor* scaleTensor = [mpsGraph constantWithScalar:scale.to() - shape:@[@1] + shape:@[ @1 ] dataType:getMPSDataType(grad_output)]; MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f - shape:@[@1] + shape:@[ @1 ] dataType:getMPSDataType(grad_output)]; MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:selfOrResultTensor secondaryTensor:zeroTensor @@ -1520,21 +1300,22 @@ void elu_variants_out_mps ( } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output, nil, executeGatherOp); - Placeholder selfOrResultPlaceholder = Placeholder(cachedGraph->selfOrResultTensor_, self_or_result, nil, executeGatherOp); - Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, out.has_storage() ? out : grad_input, nil, false); + Placeholder selfOrResultPlaceholder = + Placeholder(cachedGraph->selfOrResultTensor_, self_or_result, nil, executeGatherOp); + Placeholder gradInputPlaceholder = + Placeholder(cachedGraph->gradInputTensor_, out.has_storage() ? out : grad_input, nil, false); // Create dictionary of inputs and outputs NSDictionary* feeds = @{ gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), selfOrResultPlaceholder.getMPSGraphTensor() : selfOrResultPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); if (out.has_storage()) { @@ -1543,14 +1324,12 @@ void elu_variants_out_mps ( } } -TORCH_IMPL_FUNC(glu_out_mps) ( - const Tensor& self, const int64_t dim, const Tensor& output - ) { +TORCH_IMPL_FUNC(glu_out_mps)(const Tensor& self, const int64_t dim, const Tensor& output) { using namespace mps; TORCH_CHECK(output.is_mps()); // Empty output - if(output.numel() == 0) + if (output.numel() == 0) return; // this can't pass anyway because a 0-dimensional tensor has "size" 1, which @@ -1558,14 +1337,12 @@ void elu_variants_out_mps ( TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors"); auto wrap_dim = maybe_wrap_dim(dim, self.dim()); const int64_t nIn = self.size(wrap_dim); - TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ", - wrap_dim, " is size ", nIn); - - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ", wrap_dim, " is size ", nIn); + + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -1573,62 +1350,54 @@ void elu_variants_out_mps ( MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + to_string(dim);; - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + to_string(dim); + ; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, - getMPSDataType(self), - getMPSShape(self)); - NSArray * outputTensorsArray = [mpsGraph splitTensor:inputTensor - numSplits:2 - axis:wrap_dim - name:nil]; + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self)); + NSArray* outputTensorsArray = [mpsGraph splitTensor:inputTensor + numSplits:2 + axis:wrap_dim + name:nil]; MPSGraphTensor* firstHalf = outputTensorsArray[0]; - MPSGraphTensor* secondHalf = [mpsGraph sigmoidWithTensor:outputTensorsArray[1] - name:nil]; + MPSGraphTensor* secondHalf = [mpsGraph sigmoidWithTensor:outputTensorsArray[1] name:nil]; MPSGraphTensor* outputTensor = [mpsGraph multiplicationWithPrimaryTensor:firstHalf - secondaryTensor:secondHalf - name:nil]; + secondaryTensor:secondHalf + name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } } -Tensor& glu_backward_mps_out ( - const Tensor& grad_output, const Tensor& self, const int64_t dim, Tensor& grad_input - ) { +Tensor& glu_backward_mps_out(const Tensor& grad_output, const Tensor& self, const int64_t dim, Tensor& grad_input) { using namespace mps; // Empty output - if(grad_input.numel() == 0) + if (grad_input.numel() == 0) return grad_input; // this can't pass anyway because a 0-dimensional tensor has "size" 1, which @@ -1636,15 +1405,13 @@ void elu_variants_out_mps ( TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors"); auto wrap_dim = maybe_wrap_dim(dim, self.dim()); const int64_t nIn = self.size(wrap_dim); - TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ", - wrap_dim, " is size ", nIn); - - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *gradInputTensor_ = nil; + TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ", wrap_dim, " is size ", nIn); + + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* gradInputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -1653,67 +1420,60 @@ void elu_variants_out_mps ( @autoreleasepool { string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + to_string(dim); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, - getMPSDataType(self), - getMPSShape(self)); - MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, - getMPSDataType(grad_output), - getMPSShape(grad_output)); - NSArray * inputTensorsArray = [mpsGraph splitTensor:inputTensor - numSplits:2 - axis:wrap_dim - name:nil]; + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self)); + MPSGraphTensor* gradOutputTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_output), getMPSShape(grad_output)); + NSArray* inputTensorsArray = [mpsGraph splitTensor:inputTensor + numSplits:2 + axis:wrap_dim + name:nil]; // first half - MPSGraphTensor* sigmoidOutputTensor = [mpsGraph sigmoidWithTensor:inputTensorsArray[1] - name:nil]; - MPSGraphTensor* firstHalfOutputTensor = [mpsGraph multiplicationWithPrimaryTensor : sigmoidOutputTensor - secondaryTensor : gradOutputTensor - name : nil]; + MPSGraphTensor* sigmoidOutputTensor = [mpsGraph sigmoidWithTensor:inputTensorsArray[1] name:nil]; + MPSGraphTensor* firstHalfOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:sigmoidOutputTensor + secondaryTensor:gradOutputTensor + name:nil]; // second half - MPSGraphTensor* one_val = [mpsGraph constantWithScalar:1.0 - shape:@[@1] - dataType:getMPSDataType(self)]; - - MPSGraphTensor* secondHalfOutputTensor = [mpsGraph subtractionWithPrimaryTensor : one_val - secondaryTensor : sigmoidOutputTensor - name : nil]; - secondHalfOutputTensor = [mpsGraph multiplicationWithPrimaryTensor : secondHalfOutputTensor - secondaryTensor : sigmoidOutputTensor - name : nil]; - secondHalfOutputTensor = [mpsGraph multiplicationWithPrimaryTensor : secondHalfOutputTensor - secondaryTensor : inputTensorsArray[0] - name : nil]; - secondHalfOutputTensor = [mpsGraph multiplicationWithPrimaryTensor : secondHalfOutputTensor - secondaryTensor : gradOutputTensor - name : nil]; - - MPSGraphTensor* outputTensor = [mpsGraph concatTensor : firstHalfOutputTensor - withTensor : secondHalfOutputTensor - dimension : wrap_dim - name : nil]; + MPSGraphTensor* one_val = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; + + MPSGraphTensor* secondHalfOutputTensor = [mpsGraph subtractionWithPrimaryTensor:one_val + secondaryTensor:sigmoidOutputTensor + name:nil]; + secondHalfOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:secondHalfOutputTensor + secondaryTensor:sigmoidOutputTensor + name:nil]; + secondHalfOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:secondHalfOutputTensor + secondaryTensor:inputTensorsArray[0] + name:nil]; + secondHalfOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:secondHalfOutputTensor + secondaryTensor:gradOutputTensor + name:nil]; + + MPSGraphTensor* outputTensor = [mpsGraph concatTensor:firstHalfOutputTensor + withTensor:secondHalfOutputTensor + dimension:wrap_dim + name:nil]; newCachedGraph->gradInputTensor_ = outputTensor; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->gradOutputTensor_ = gradOutputTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } - Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input); - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); // Create dictionary of inputs and outputs @@ -1726,285 +1486,246 @@ void elu_variants_out_mps ( gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData(), }; runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } return grad_input; - } -Tensor glu_backward_mps (const Tensor& grad_output, - const Tensor& self, - const int64_t dim) { - - Tensor grad_input = at::native::empty_mps( - self.sizes(), - self.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); +Tensor glu_backward_mps(const Tensor& grad_output, const Tensor& self, const int64_t dim) { + Tensor grad_input = + at::native::empty_mps(self.sizes(), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); grad_input = glu_backward_mps_out(grad_output, self, dim, grad_input); return grad_input; } +TORCH_IMPL_FUNC(softplus_out_mps) +(const Tensor& self, const Scalar& beta, const Scalar& threshold, const Tensor& result) { + using namespace mps; + TORCH_CHECK(self.is_mps()); + // Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} * + // \log(1 + \exp(\beta * x))` element-wise. + // For numerical stability the implementation reverts to the linear function + // when :math:`input \times \beta > threshold`. -TORCH_IMPL_FUNC(softplus_out_mps) ( - const Tensor& self, - const Scalar& beta, - const Scalar& threshold, - const Tensor& result) { - using namespace mps; - TORCH_CHECK(self.is_mps()); - // Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} * - // \log(1 + \exp(\beta * x))` element-wise. - // For numerical stability the implementation reverts to the linear function - // when :math:`input \times \beta > threshold`. - - // Empty output - if(result.numel() == 0) - return; - - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *betaTensor_ = nil; - MPSGraphTensor *thresholdTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; + // Empty output + if (result.numel() == 0) + return; - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* betaTensor_ = nil; + MPSGraphTensor* thresholdTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + }; - MPSStream* stream = getCurrentMPSStream(); - MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float); - MPSScalar threshold_scalar = getMPSScalar(threshold, ScalarType::Float); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - @autoreleasepool { - string key = "softplus_out_mps:" + getTensorsStringKey({self}) + ":" + - std::to_string(beta.to()) + ":" + std::to_string(threshold.to()); + MPSStream* stream = getCurrentMPSStream(); + MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float); + MPSScalar threshold_scalar = getMPSScalar(threshold, ScalarType::Float); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + @autoreleasepool { + string key = "softplus_out_mps:" + getTensorsStringKey({self}) + ":" + std::to_string(beta.to()) + ":" + + std::to_string(threshold.to()); - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(ScalarType::Float)); + MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(ScalarType::Float)); - MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(ScalarType::Float)); + MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(ScalarType::Float)); - MPSGraphTensor* reluTensor = [mpsGraph reLUWithTensor:inputTensor - name:nil]; + MPSGraphTensor* reluTensor = [mpsGraph reLUWithTensor:inputTensor name:nil]; - MPSGraphTensor* reciprocalBetaTensor = [mpsGraph reciprocalWithTensor:betaTensor - name:nil]; - MPSGraphTensor* bxTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor - secondaryTensor:betaTensor - name:nil]; - MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:bxTensor - secondaryTensor:thresholdTensor - name:nil]; - MPSGraphTensor* expTensor = [mpsGraph exponentWithTensor:bxTensor - name:nil]; - MPSGraphTensor* log1pTensor = at::native::mps::log1p(mpsGraph, expTensor); - MPSGraphTensor* softplusTensor = [mpsGraph multiplicationWithPrimaryTensor:log1pTensor - secondaryTensor:reciprocalBetaTensor - name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:predicateTensor - truePredicateTensor:reluTensor - falsePredicateTensor:softplusTensor - name:nil]; + MPSGraphTensor* reciprocalBetaTensor = [mpsGraph reciprocalWithTensor:betaTensor name:nil]; + MPSGraphTensor* bxTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor + secondaryTensor:betaTensor + name:nil]; + MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:bxTensor + secondaryTensor:thresholdTensor + name:nil]; + MPSGraphTensor* expTensor = [mpsGraph exponentWithTensor:bxTensor name:nil]; + MPSGraphTensor* log1pTensor = at::native::mps::log1p(mpsGraph, expTensor); + MPSGraphTensor* softplusTensor = [mpsGraph multiplicationWithPrimaryTensor:log1pTensor + secondaryTensor:reciprocalBetaTensor + name:nil]; + MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:predicateTensor + truePredicateTensor:reluTensor + falsePredicateTensor:softplusTensor + name:nil]; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->betaTensor_ = betaTensor; - newCachedGraph->thresholdTensor_ = thresholdTensor; - newCachedGraph->outputTensor_ = outputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->betaTensor_ = betaTensor; + newCachedGraph->thresholdTensor_ = thresholdTensor; + newCachedGraph->outputTensor_ = outputTensor; } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); - - // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - cachedGraph->betaTensor_ : getMPSGraphTensorFromScalar(stream, beta_scalar), - cachedGraph->thresholdTensor_ : getMPSGraphTensorFromScalar(stream, threshold_scalar), - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } -} + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); -TORCH_IMPL_FUNC(softplus_backward_out_mps) ( - const Tensor& grad_output, - const Tensor& self, - const Scalar& beta, - const Scalar& threshold, - const Tensor& grad_input -) { - using namespace mps; - TORCH_CHECK(self.is_mps()); + // Create dictionary of inputs and outputs + NSDictionary* feeds = @{ + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), + cachedGraph->betaTensor_ : getMPSGraphTensorFromScalar(stream, beta_scalar), + cachedGraph->thresholdTensor_ : getMPSGraphTensorFromScalar(stream, threshold_scalar), + }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } +} - // Empty output - if(grad_input.numel() == 0) - return; +TORCH_IMPL_FUNC(softplus_backward_out_mps) +(const Tensor& grad_output, const Tensor& self, const Scalar& beta, const Scalar& threshold, const Tensor& grad_input) { + using namespace mps; + TORCH_CHECK(self.is_mps()); - MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float); - MPSScalar threshold_scalar = getMPSScalar(threshold, ScalarType::Float); + // Empty output + if (grad_input.numel() == 0) + return; - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *betaTensor_ = nil; - MPSGraphTensor *thresholdTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; + MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float); + MPSScalar threshold_scalar = getMPSScalar(threshold, ScalarType::Float); - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* betaTensor_ = nil; + MPSGraphTensor* thresholdTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + }; - MPSStream* stream = getCurrentMPSStream(); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - @autoreleasepool { - string key = "softplus_backward_out_mps:" + getTensorsStringKey({grad_output, self}) + ":" + - std::to_string(beta.to()) + ":" + std::to_string(threshold.to()); + MPSStream* stream = getCurrentMPSStream(); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + @autoreleasepool { + string key = "softplus_backward_out_mps:" + getTensorsStringKey({grad_output, self}) + ":" + + std::to_string(beta.to()) + ":" + std::to_string(threshold.to()); - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSScalarType(ScalarType::Float)); + MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSScalarType(ScalarType::Float)); - MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSScalarType(ScalarType::Float)); + MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSScalarType(ScalarType::Float)); - MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 - shape:@[@1] - dataType:getMPSDataType(self)]; - MPSGraphTensor* bxTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor - secondaryTensor:betaTensor - name:nil]; - MPSGraphTensor* expBxTensor = [mpsGraph exponentWithTensor:bxTensor - name:nil]; - MPSGraphTensor* unitExpBxTensor = [mpsGraph additionWithPrimaryTensor:expBxTensor - secondaryTensor:unitTensor - name:nil]; - MPSGraphTensor* rTensor = [mpsGraph multiplicationWithPrimaryTensor:gradOutputTensor - secondaryTensor:expBxTensor - name:nil]; - rTensor = [mpsGraph divisionWithPrimaryTensor:rTensor - secondaryTensor:unitExpBxTensor - name:nil]; - MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:bxTensor - secondaryTensor:thresholdTensor - name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:predicateTensor - truePredicateTensor:gradOutputTensor - falsePredicateTensor:rTensor - name:nil]; + MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; + MPSGraphTensor* bxTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor + secondaryTensor:betaTensor + name:nil]; + MPSGraphTensor* expBxTensor = [mpsGraph exponentWithTensor:bxTensor name:nil]; + MPSGraphTensor* unitExpBxTensor = [mpsGraph additionWithPrimaryTensor:expBxTensor + secondaryTensor:unitTensor + name:nil]; + MPSGraphTensor* rTensor = [mpsGraph multiplicationWithPrimaryTensor:gradOutputTensor + secondaryTensor:expBxTensor + name:nil]; + rTensor = [mpsGraph divisionWithPrimaryTensor:rTensor secondaryTensor:unitExpBxTensor name:nil]; + MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:bxTensor + secondaryTensor:thresholdTensor + name:nil]; + MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:predicateTensor + truePredicateTensor:gradOutputTensor + falsePredicateTensor:rTensor + name:nil]; - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->betaTensor_ = betaTensor; - newCachedGraph->thresholdTensor_ = thresholdTensor; - newCachedGraph->outputTensor_ = outputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); + newCachedGraph->gradOutputTensor_ = gradOutputTensor; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->betaTensor_ = betaTensor; + newCachedGraph->thresholdTensor_ = thresholdTensor; + newCachedGraph->outputTensor_ = outputTensor; } - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder gradInputPlaceholder = Placeholder(cachedGraph->outputTensor_, grad_input); - - // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - cachedGraph->betaTensor_ : getMPSGraphTensorFromScalar(stream, beta_scalar), - cachedGraph->thresholdTensor_ : getMPSGraphTensorFromScalar(stream, threshold_scalar), - }; - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; - runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } -} + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder gradInputPlaceholder = Placeholder(cachedGraph->outputTensor_, grad_input); + // Create dictionary of inputs and outputs + NSDictionary* feeds = @{ + gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), + cachedGraph->betaTensor_ : getMPSGraphTensorFromScalar(stream, beta_scalar), + cachedGraph->thresholdTensor_ : getMPSGraphTensorFromScalar(stream, threshold_scalar), + }; + NSDictionary* results = + @{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()}; + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } +} Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { - using namespace mps; - - Tensor result = at::empty_like(self, self.suggest_memory_format()); - TORCH_INTERNAL_ASSERT(weight_.defined()); + using namespace mps; - if (result.numel() == 0){ - return result; - } + Tensor result = at::empty_like(self, self.suggest_memory_format()); + TORCH_INTERNAL_ASSERT(weight_.defined()); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *weightTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; + if (result.numel() == 0) { + return result; + } - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* weightTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + }; - MPSStream* stream = getCurrentMPSStream(); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - @autoreleasepool { - string key = "prelu_mps:" + getTensorsStringKey({self, weight_}); + MPSStream* stream = getCurrentMPSStream(); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + @autoreleasepool { + string key = "prelu_mps:" + getTensorsStringKey({self, weight_}); - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor *weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_); + MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_); - MPSGraphTensor *zeroTensor = [mpsGraph constantWithScalar:0.0 - shape:@[@1] - dataType:getMPSDataType(self)]; - MPSGraphTensor *reluTensor = [mpsGraph reLUWithTensor:inputTensor - name:nil]; - MPSGraphTensor *predicateTensor = [mpsGraph lessThanWithPrimaryTensor: inputTensor - secondaryTensor: zeroTensor - name: nil]; - MPSGraphTensor *weightedTensor = [mpsGraph selectWithPredicateTensor: predicateTensor - truePredicateTensor: inputTensor - falsePredicateTensor: zeroTensor - name: nil]; + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; + MPSGraphTensor* reluTensor = [mpsGraph reLUWithTensor:inputTensor name:nil]; + MPSGraphTensor* predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor + secondaryTensor:zeroTensor + name:nil]; + MPSGraphTensor* weightedTensor = [mpsGraph selectWithPredicateTensor:predicateTensor + truePredicateTensor:inputTensor + falsePredicateTensor:zeroTensor + name:nil]; weightedTensor = [mpsGraph multiplicationWithPrimaryTensor:weightedTensor secondaryTensor:weightTensor name:nil]; - MPSGraphTensor *outputTensor = [mpsGraph additionWithPrimaryTensor:reluTensor + MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor:reluTensor secondaryTensor:weightedTensor name:nil]; @@ -2013,135 +1734,126 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); - - // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - weightPlaceholder.getMPSGraphTensor() : weightPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - runMPSGraph(stream, cachedGraph->graph(), feeds, results); + }); + cachedGraph = static_cast(tmpCachedGraph); } + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); + + // Create dictionary of inputs and outputs + NSDictionary* feeds = @{ + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), + weightPlaceholder.getMPSGraphTensor() : weightPlaceholder.getMPSGraphTensorData() + }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } return result; } std::tuple prelu_backward_mps(const Tensor& grad_output, const Tensor& self, const Tensor& weight_) { - using namespace mps; - - Tensor grad_input = at::empty_like(self, self.suggest_memory_format()); - Tensor weight_grad = at::empty_like(self, at::MemoryFormat::Contiguous); - if (grad_output.numel() == 0) { - return std::tuple{grad_input, weight_grad}; - } + using namespace mps; - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *weightTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - MPSGraphTensor *weightedGradTensor_ = nil; - }; + Tensor grad_input = at::empty_like(self, self.suggest_memory_format()); + Tensor weight_grad = at::empty_like(self, at::MemoryFormat::Contiguous); + if (grad_output.numel() == 0) { + return std::tuple{grad_input, weight_grad}; + } - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* weightTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + MPSGraphTensor* weightedGradTensor_ = nil; + }; - MPSStream* stream = getCurrentMPSStream(); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - @autoreleasepool { - string key = "prelu_backward_mps:" + getTensorsStringKey({grad_output, self, weight_}); + MPSStream* stream = getCurrentMPSStream(); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + @autoreleasepool { + string key = "prelu_backward_mps:" + getTensorsStringKey({grad_output, self, weight_}); - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); + MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor *weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_); + MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_); - MPSGraphTensor *zeroTensor = [mpsGraph constantWithScalar: 0.0 - shape:@[@1] - dataType: inputTensor.dataType]; - MPSGraphTensor* weightedGradOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:weightTensor - secondaryTensor:gradOutputTensor - name:nil]; - MPSGraphTensor* inputGradOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor - secondaryTensor:gradOutputTensor - name:nil]; - MPSGraphTensor *predicateTensor = [mpsGraph greaterThanWithPrimaryTensor: inputTensor - secondaryTensor: zeroTensor - name: nil]; - MPSGraphTensor *outputTensor = [mpsGraph selectWithPredicateTensor: predicateTensor - truePredicateTensor: gradOutputTensor - falsePredicateTensor: weightedGradOutputTensor - name: nil]; - MPSGraphTensor *weightedGradTensor = [mpsGraph selectWithPredicateTensor: predicateTensor - truePredicateTensor: zeroTensor - falsePredicateTensor: inputGradOutputTensor - name: nil]; - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->weightTensor_ = weightTensor; - newCachedGraph->outputTensor_ = outputTensor; - newCachedGraph->weightedGradTensor_ = weightedGradTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_); - Placeholder gradInputPlaceholder = Placeholder(cachedGraph->outputTensor_, grad_input); - Placeholder weightedGradPlaceholder = Placeholder(cachedGraph->weightedGradTensor_, weight_grad); - - // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - weightPlaceholder.getMPSGraphTensor() : weightPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData(), - weightedGradPlaceholder.getMPSGraphTensor() : weightedGradPlaceholder.getMPSGraphTensorData() - }; - runMPSGraph(stream, cachedGraph->graph(), feeds, results); + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:inputTensor.dataType]; + MPSGraphTensor* weightedGradOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:weightTensor + secondaryTensor:gradOutputTensor + name:nil]; + MPSGraphTensor* inputGradOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor + secondaryTensor:gradOutputTensor + name:nil]; + MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor + secondaryTensor:zeroTensor + name:nil]; + MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:predicateTensor + truePredicateTensor:gradOutputTensor + falsePredicateTensor:weightedGradOutputTensor + name:nil]; + MPSGraphTensor* weightedGradTensor = [mpsGraph selectWithPredicateTensor:predicateTensor + truePredicateTensor:zeroTensor + falsePredicateTensor:inputGradOutputTensor + name:nil]; + newCachedGraph->gradOutputTensor_ = gradOutputTensor; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->weightTensor_ = weightTensor; + newCachedGraph->outputTensor_ = outputTensor; + newCachedGraph->weightedGradTensor_ = weightedGradTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); } + Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_); + Placeholder gradInputPlaceholder = Placeholder(cachedGraph->outputTensor_, grad_input); + Placeholder weightedGradPlaceholder = Placeholder(cachedGraph->weightedGradTensor_, weight_grad); + + // Create dictionary of inputs and outputs + NSDictionary* feeds = @{ + gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), + weightPlaceholder.getMPSGraphTensor() : weightPlaceholder.getMPSGraphTensorData() + }; + NSDictionary* results = @{ + gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData(), + weightedGradPlaceholder.getMPSGraphTensor() : weightedGradPlaceholder.getMPSGraphTensorData() + }; + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } return std::tuple{grad_input, weight_grad}; } -TORCH_IMPL_FUNC(silu_out_mps) ( - const Tensor& self, - const Tensor& result) { - +TORCH_IMPL_FUNC(silu_out_mps)(const Tensor& self, const Tensor& result) { using namespace mps; TORCH_CHECK(self.is_mps()); // Empty output - if(result.numel() == 0) + if (result.numel() == 0) return; - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -2151,25 +1863,20 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { @autoreleasepool { string key = "silu_out_mps:" + getTensorsStringKey({self}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 - shape:@[@1] - dataType:getMPSDataType(self)]; - MPSGraphTensor* negativeInput = [mpsGraph negativeWithTensor:inputTensor - name:nil]; - MPSGraphTensor* expNegativeTensor = [mpsGraph exponentWithTensor:negativeInput - name:nil]; + MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; + MPSGraphTensor* negativeInput = [mpsGraph negativeWithTensor:inputTensor name:nil]; + MPSGraphTensor* expNegativeTensor = [mpsGraph exponentWithTensor:negativeInput name:nil]; MPSGraphTensor* expPlusOneTensor = [mpsGraph additionWithPrimaryTensor:expNegativeTensor secondaryTensor:unitTensor name:nil]; @@ -2182,44 +1889,36 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } - } -TORCH_IMPL_FUNC(silu_backward_out_mps) ( - const Tensor& grad_output, - const Tensor& self, - const Tensor& grad_input) { - +TORCH_IMPL_FUNC(silu_backward_out_mps)(const Tensor& grad_output, const Tensor& self, const Tensor& grad_input) { using namespace mps; TORCH_CHECK(grad_output.is_mps()); // Empty output - if(grad_input.numel() == 0) + if (grad_input.numel() == 0) return; - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *gradInputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* gradInputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -2229,31 +1928,27 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { @autoreleasepool { string key = "silu_out_backward_mps:" + getTensorsStringKey({grad_output}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 - shape:@[@1] + shape:@[ @1 ] dataType:getMPSDataType(grad_output)]; - MPSGraphTensor* negativeInput = [mpsGraph negativeWithTensor:inputTensor - name:nil]; - MPSGraphTensor* expNegativeTensor = [mpsGraph exponentWithTensor:negativeInput - name:nil]; + MPSGraphTensor* negativeInput = [mpsGraph negativeWithTensor:inputTensor name:nil]; + MPSGraphTensor* expNegativeTensor = [mpsGraph exponentWithTensor:negativeInput name:nil]; MPSGraphTensor* expPlusOneTensor = [mpsGraph additionWithPrimaryTensor:expNegativeTensor secondaryTensor:unitTensor name:nil]; - MPSGraphTensor* sigmoidTensor = [mpsGraph reciprocalWithTensor:expPlusOneTensor - name:nil]; + MPSGraphTensor* sigmoidTensor = [mpsGraph reciprocalWithTensor:expPlusOneTensor name:nil]; MPSGraphTensor* oneMinusSigmoid = [mpsGraph subtractionWithPrimaryTensor:unitTensor secondaryTensor:sigmoidTensor name:nil]; @@ -2276,7 +1971,7 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); @@ -2289,29 +1984,25 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } - } - -TORCH_IMPL_FUNC(hardsigmoid_out_mps) (const Tensor& self, const Tensor& result) { +TORCH_IMPL_FUNC(hardsigmoid_out_mps)(const Tensor& self, const Tensor& result) { using namespace mps; TORCH_CHECK(self.is_mps()); // Empty output - if(result.numel() == 0) + if (result.numel() == 0) return; - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -2321,26 +2012,19 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { @autoreleasepool { string key = "hardsigmoid_out_mps:" + getTensorsStringKey({self}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 - shape:@[@1] - dataType:getMPSDataType(self)]; - MPSGraphTensor* threeTensor = [mpsGraph constantWithScalar:3.0 - shape:@[@1] - dataType:getMPSDataType(self)]; - MPSGraphTensor* sixTensor = [mpsGraph constantWithScalar:6.0 - shape:@[@1] - dataType:getMPSDataType(self)]; + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; + MPSGraphTensor* threeTensor = [mpsGraph constantWithScalar:3.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; + MPSGraphTensor* sixTensor = [mpsGraph constantWithScalar:6.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; MPSGraphTensor* inputPlusThreeTensor = [mpsGraph additionWithPrimaryTensor:inputTensor secondaryTensor:threeTensor name:nil]; @@ -2349,49 +2033,42 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { minValueTensor:zeroTensor maxValueTensor:sixTensor name:nil]; - outputTensor = [mpsGraph divisionWithPrimaryTensor:outputTensor - secondaryTensor:sixTensor - name:nil]; + outputTensor = [mpsGraph divisionWithPrimaryTensor:outputTensor secondaryTensor:sixTensor name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } } -TORCH_IMPL_FUNC(hardsigmoid_backward_out_mps) ( - const Tensor& grad_output, const Tensor& self, const Tensor& grad_input -) { +TORCH_IMPL_FUNC(hardsigmoid_backward_out_mps)(const Tensor& grad_output, const Tensor& self, const Tensor& grad_input) { using namespace mps; TORCH_CHECK(self.is_mps()); // Empty output - if(grad_input.numel() == 0) + if (grad_input.numel() == 0) return; - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *gradInputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* gradInputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -2401,34 +2078,27 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { @autoreleasepool { string key = "hardsigmoid_backward_out_mps:" + getTensorsStringKey({self}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 - shape:@[@1] - dataType:getMPSDataType(self)]; - MPSGraphTensor* highTensor = [mpsGraph constantWithScalar:3.0 - shape:@[@1] - dataType:getMPSDataType(self)]; - MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:-3.0 - shape:@[@1] - dataType:getMPSDataType(self)]; - MPSGraphTensor* oneSixTensor = [mpsGraph constantWithScalar:1.0/6.0 - shape:@[@1] - dataType:getMPSDataType(self)]; - MPSGraphTensor *inputLessThanHighPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; + MPSGraphTensor* highTensor = [mpsGraph constantWithScalar:3.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; + MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:-3.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; + MPSGraphTensor* oneSixTensor = [mpsGraph constantWithScalar:1.0 / 6.0 + shape:@[ @1 ] + dataType:getMPSDataType(self)]; + MPSGraphTensor* inputLessThanHighPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor secondaryTensor:highTensor name:nil]; - MPSGraphTensor *inputGreaterThanLowPredicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor + MPSGraphTensor* inputGreaterThanLowPredicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor secondaryTensor:lowTensor name:nil]; MPSGraphTensor* inIntervalTensor = [mpsGraph logicalANDWithPrimaryTensor:inputLessThanHighPredicateTensor @@ -2448,7 +2118,7 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); @@ -2461,9 +2131,8 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } @@ -2472,44 +2141,31 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { // ------------------------------------------------- // Hardtanh backward -Tensor hardtanh_backward_mps - (const Tensor& grad_output, - const Tensor& self, - const Scalar& min, - const Scalar& max) { - +Tensor hardtanh_backward_mps(const Tensor& grad_output, const Tensor& self, const Scalar& min, const Scalar& max) { Tensor grad_input = at::native::empty_mps( - grad_output.sizes(), - grad_output.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + grad_output.sizes(), grad_output.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); grad_input = hardtanh_backward_out_mps(grad_output, self, min, max, grad_input); return grad_input; } // Hardtanh backward -Tensor& hardtanh_backward_out_mps - (const Tensor& grad_output, - const Tensor& self, - const Scalar& min, - const Scalar& max, - Tensor& grad_input) { - +Tensor& hardtanh_backward_out_mps(const Tensor& grad_output, + const Tensor& self, + const Scalar& min, + const Scalar& max, + Tensor& grad_input) { using namespace mps; TORCH_CHECK(grad_output.is_mps()); // Empty output - if(grad_input.numel() == 0) + if (grad_input.numel() == 0) return grad_input; - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *gradInputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* gradInputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -2517,15 +2173,13 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" + - to_string(min.to()) + ":" + - to_string(max.to()); - - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" + to_string(min.to()) + + ":" + to_string(max.to()); - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -2536,23 +2190,23 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { // TODO: Compute gradient MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0f - shape:@[@1] + shape:@[ @1 ] dataType:getMPSDataType(grad_output)]; MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f - shape:@[@1] + shape:@[ @1 ] dataType:getMPSDataType(grad_output)]; MPSGraphTensor* minTensor = [mpsGraph constantWithScalar:min.to() - shape:@[@1] + shape:@[ @1 ] dataType:getMPSDataType(grad_output)]; MPSGraphTensor* maxTensor = [mpsGraph constantWithScalar:max.to() - shape:@[@1] + shape:@[ @1 ] dataType:getMPSDataType(grad_output)]; MPSGraphTensor* greaterThanMaxPredicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor secondaryTensor:maxTensor name:nil]; MPSGraphTensor* lesserThanMinPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor - secondaryTensor:minTensor - name:nil]; + secondaryTensor:minTensor + name:nil]; MPSGraphTensor* greaterThanMaxGradTensor = [mpsGraph selectWithPredicateTensor:greaterThanMaxPredicateTensor truePredicateTensor:zeroTensor falsePredicateTensor:unitTensor @@ -2574,7 +2228,7 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); @@ -2587,9 +2241,8 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } @@ -2620,76 +2273,57 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { string key = "hardswish_out_mps" + getTensorsStringKey({self}); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if (!cachedGraph) { - MPSCachedGraph* tmpCachedGraph = - cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { - CachedGraph* newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = - mpsGraphRankedPlaceHolder(mpsGraph, self); - - MPSGraphTensor* zeroTensor = [mpsGraph - constantWithScalar:0.0f - shape:@[ @1 ] - dataType:getMPSDataType(self)]; - - MPSGraphTensor* threeTensor = [mpsGraph - constantWithScalar:3.0f - shape:@[ @1 ] - dataType:getMPSDataType(self)]; - - MPSGraphTensor* negativeThreeTensor = [mpsGraph - constantWithScalar:-3.0f - shape:@[ @1 ] - dataType:getMPSDataType(self)]; - - MPSGraphTensor* sixTensor = [mpsGraph - constantWithScalar:6.0f - shape:@[ @1 ] - dataType:getMPSDataType(self)]; - - MPSGraphTensor* lessThanMinPredicateTensor = [mpsGraph - lessThanOrEqualToWithPrimaryTensor:inputTensor - secondaryTensor:negativeThreeTensor - name:nil]; - - MPSGraphTensor* lessThanMaxPredicateTensor = - [mpsGraph lessThanWithPrimaryTensor:inputTensor - secondaryTensor:threeTensor - name:nil]; - - MPSGraphTensor* inputPlusThreeTensor = - [mpsGraph additionWithPrimaryTensor:inputTensor - secondaryTensor:threeTensor - name:nil]; - - MPSGraphTensor* inputDivSixTensor = - [mpsGraph divisionWithPrimaryTensor:inputPlusThreeTensor - secondaryTensor:sixTensor - name:nil]; - - MPSGraphTensor* weightedTensor = - [mpsGraph multiplicationWithPrimaryTensor:inputTensor - secondaryTensor:inputDivSixTensor - name:nil]; - - MPSGraphTensor* tempTensor = - [mpsGraph selectWithPredicateTensor:lessThanMaxPredicateTensor - truePredicateTensor:weightedTensor - falsePredicateTensor:inputTensor - name:nil]; - - MPSGraphTensor* outputTensor = - [mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor - truePredicateTensor:zeroTensor - falsePredicateTensor:tempTensor - name:nil]; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - } - return newCachedGraph; - }); + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f shape:@[ @1 ] dataType:getMPSDataType(self)]; + + MPSGraphTensor* threeTensor = [mpsGraph constantWithScalar:3.0f shape:@[ @1 ] dataType:getMPSDataType(self)]; + + MPSGraphTensor* negativeThreeTensor = [mpsGraph constantWithScalar:-3.0f + shape:@[ @1 ] + dataType:getMPSDataType(self)]; + + MPSGraphTensor* sixTensor = [mpsGraph constantWithScalar:6.0f shape:@[ @1 ] dataType:getMPSDataType(self)]; + + MPSGraphTensor* lessThanMinPredicateTensor = [mpsGraph lessThanOrEqualToWithPrimaryTensor:inputTensor + secondaryTensor:negativeThreeTensor + name:nil]; + + MPSGraphTensor* lessThanMaxPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor + secondaryTensor:threeTensor + name:nil]; + + MPSGraphTensor* inputPlusThreeTensor = [mpsGraph additionWithPrimaryTensor:inputTensor + secondaryTensor:threeTensor + name:nil]; + + MPSGraphTensor* inputDivSixTensor = [mpsGraph divisionWithPrimaryTensor:inputPlusThreeTensor + secondaryTensor:sixTensor + name:nil]; + + MPSGraphTensor* weightedTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor + secondaryTensor:inputDivSixTensor + name:nil]; + + MPSGraphTensor* tempTensor = [mpsGraph selectWithPredicateTensor:lessThanMaxPredicateTensor + truePredicateTensor:weightedTensor + falsePredicateTensor:inputTensor + name:nil]; + + MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor + truePredicateTensor:zeroTensor + falsePredicateTensor:tempTensor + name:nil]; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->outputTensor_ = outputTensor; + } + return newCachedGraph; + }); cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); @@ -2697,15 +2331,11 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { Placeholder(cachedGraph->outputTensor_, out.has_storage() ? out : output, nil, false); // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : - selfPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : - outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); if (out.has_storage()) { @@ -2758,66 +2388,54 @@ Tensor hardswish_backward_mps(const Tensor& grad_output, const Tensor& self) { MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* zeroTensor = [mpsGraph - constantWithScalar:0.0f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - - MPSGraphTensor* unitTensor = [mpsGraph - constantWithScalar:1.0f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - - MPSGraphTensor* threeTensor = [mpsGraph - constantWithScalar:3.0f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - - MPSGraphTensor* negativeThreeTensor = [mpsGraph - constantWithScalar:-3.0f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - - MPSGraphTensor* halfTensor = [mpsGraph - constantWithScalar:0.5f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - - MPSGraphTensor* tempTensor = - [mpsGraph divisionWithPrimaryTensor:inputTensor - secondaryTensor:threeTensor - name:nil]; - - MPSGraphTensor* weightedTensor = - [mpsGraph additionWithPrimaryTensor:tempTensor - secondaryTensor:halfTensor - name:nil]; - - MPSGraphTensor* lessThanMinPredicateTensor = [mpsGraph - lessThanOrEqualToWithPrimaryTensor:inputTensor - secondaryTensor:negativeThreeTensor - name:nil]; - - MPSGraphTensor* lessThanMaxPredicateTensor = - [mpsGraph lessThanWithPrimaryTensor:inputTensor - secondaryTensor:threeTensor - name:nil]; - - MPSGraphTensor* lessThanMaxGradTensor = - [mpsGraph selectWithPredicateTensor:lessThanMaxPredicateTensor - truePredicateTensor:weightedTensor - falsePredicateTensor:unitTensor - name:nil]; - - MPSGraphTensor* gradTensor = - [mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor - truePredicateTensor:zeroTensor - falsePredicateTensor:lessThanMaxGradTensor - name:nil]; - MPSGraphTensor* gradInputTensor = - [mpsGraph multiplicationWithPrimaryTensor:gradTensor - secondaryTensor:gradOutputTensor - name:nil]; + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f + shape:@[ @1 ] + dataType:getMPSDataType(grad_output)]; + + MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0f + shape:@[ @1 ] + dataType:getMPSDataType(grad_output)]; + + MPSGraphTensor* threeTensor = [mpsGraph constantWithScalar:3.0f + shape:@[ @1 ] + dataType:getMPSDataType(grad_output)]; + + MPSGraphTensor* negativeThreeTensor = [mpsGraph constantWithScalar:-3.0f + shape:@[ @1 ] + dataType:getMPSDataType(grad_output)]; + + MPSGraphTensor* halfTensor = [mpsGraph constantWithScalar:0.5f + shape:@[ @1 ] + dataType:getMPSDataType(grad_output)]; + + MPSGraphTensor* tempTensor = [mpsGraph divisionWithPrimaryTensor:inputTensor + secondaryTensor:threeTensor + name:nil]; + + MPSGraphTensor* weightedTensor = [mpsGraph additionWithPrimaryTensor:tempTensor + secondaryTensor:halfTensor + name:nil]; + + MPSGraphTensor* lessThanMinPredicateTensor = [mpsGraph lessThanOrEqualToWithPrimaryTensor:inputTensor + secondaryTensor:negativeThreeTensor + name:nil]; + + MPSGraphTensor* lessThanMaxPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor + secondaryTensor:threeTensor + name:nil]; + + MPSGraphTensor* lessThanMaxGradTensor = [mpsGraph selectWithPredicateTensor:lessThanMaxPredicateTensor + truePredicateTensor:weightedTensor + falsePredicateTensor:unitTensor + name:nil]; + + MPSGraphTensor* gradTensor = [mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor + truePredicateTensor:zeroTensor + falsePredicateTensor:lessThanMaxGradTensor + name:nil]; + MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:gradTensor + secondaryTensor:gradOutputTensor + name:nil]; newCachedGraph->gradOutputTensor_ = gradOutputTensor; newCachedGraph->inputTensor_ = inputTensor; @@ -2837,9 +2455,8 @@ Tensor hardswish_backward_mps(const Tensor& grad_output, const Tensor& self) { selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); } diff --git a/aten/src/ATen/native/mps/operations/AdaptivePooling.mm b/aten/src/ATen/native/mps/operations/AdaptivePooling.mm index d90545147e3941..2ca5b66c07e157 100644 --- a/aten/src/ATen/native/mps/operations/AdaptivePooling.mm +++ b/aten/src/ATen/native/mps/operations/AdaptivePooling.mm @@ -1,52 +1,54 @@ // Copyright © 2022 Apple Inc. -#include #include +#include namespace at::native { -void set_kernel_params - (int64_t isizeH, int64_t isizeW, - int64_t osizeH, int64_t osizeW, - int64_t &strideH, int64_t &strideW, - int64_t &kernel_sizeH, int64_t &kernel_sizeW, - bool check_avg_pooling = false) { - +void set_kernel_params(int64_t isizeH, + int64_t isizeW, + int64_t osizeH, + int64_t osizeW, + int64_t& strideH, + int64_t& strideW, + int64_t& kernel_sizeH, + int64_t& kernel_sizeW, + bool check_avg_pooling = false) { TORCH_CHECK((isizeH >= osizeH && isizeW >= osizeW) || (isizeH <= osizeH && isizeW <= osizeW), "Adaptive pool MPS: Input height and width must both be greater than, " "or equal to, or lesser than output height and width") - if(isizeH >= osizeH) { + if (isizeH >= osizeH) { if (check_avg_pooling) { TORCH_CHECK((isizeH % osizeH == 0 && isizeW % osizeW == 0), - "Adaptive pool MPS: input sizes must be divisible by output sizes."); + "Adaptive pool MPS: input sizes must be divisible by output sizes."); } - strideH = (int64_t) (isizeH / osizeH); - strideW = (int64_t) (isizeW / osizeW); - kernel_sizeH = isizeH - (osizeH-1) * strideH; - kernel_sizeW = isizeW - (osizeW-1) * strideW; + strideH = (int64_t)(isizeH / osizeH); + strideW = (int64_t)(isizeW / osizeW); + kernel_sizeH = isizeH - (osizeH - 1) * strideH; + kernel_sizeW = isizeW - (osizeW - 1) * strideW; } else { if (check_avg_pooling) { TORCH_CHECK((osizeH % isizeH == 0 && osizeW % isizeW == 0), "Adaptive pool MPS: output sizes must be divisible by input sizes."); } - strideH = (int64_t) (osizeH / isizeH); - strideW = (int64_t) (osizeW / isizeW); - kernel_sizeH = osizeH - (isizeH-1) * strideH; - kernel_sizeW = osizeW - (isizeW-1) * strideW; + strideH = (int64_t)(osizeH / isizeH); + strideW = (int64_t)(osizeW / isizeW); + kernel_sizeH = osizeH - (isizeH - 1) * strideH; + kernel_sizeW = osizeW - (isizeW - 1) * strideW; } } // Adaptive average pooling -Tensor& adaptive_avg_pool2d_out_mps - (const Tensor& input, - IntArrayRef output_size, - Tensor& output) { - +Tensor& adaptive_avg_pool2d_out_mps(const Tensor& input, IntArrayRef output_size, Tensor& output) { for (int64_t i = 1; i < input.ndimension(); i++) { TORCH_CHECK(input.size(i) > 0, - "adaptive_avg_pool2d(): Expected input to have non-zero size for non-batch dimensions, " - "but input has sizes ", input.sizes(), " with dimension ", i, " being empty"); + "adaptive_avg_pool2d(): Expected input to have non-zero size for non-batch dimensions, " + "but input has sizes ", + input.sizes(), + " with dimension ", + i, + " being empty"); } int64_t isizeH = input.size(-2); @@ -57,45 +59,39 @@ int64_t strideH = 0, strideW = 0; int64_t kernel_sizeH = 0, kernel_sizeW = 0; - set_kernel_params(isizeH, isizeW, - osizeH, osizeW, - strideH, strideW, - kernel_sizeH, kernel_sizeW, true); - - if(isizeH >= osizeH) { - output = at::avg_pool2d(input, - IntArrayRef({kernel_sizeH, kernel_sizeW}), - IntArrayRef({strideH, strideW}), - IntArrayRef({0, 0}), - false, - true, - c10::nullopt); - } else { + set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW, true); + + if (isizeH >= osizeH) { + output = at::avg_pool2d(input, + IntArrayRef({kernel_sizeH, kernel_sizeW}), + IntArrayRef({strideH, strideW}), + IntArrayRef({0, 0}), + false, + true, + c10::nullopt); + } else { Tensor phony_grad = at::ones_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); auto input_sizes = input.sizes(); - std::vector phony_shape{input_sizes.begin(), input_sizes.end() -2}; + std::vector phony_shape{input_sizes.begin(), input_sizes.end() - 2}; phony_shape.push_back(output_size[0]); phony_shape.push_back(output_size[1]); phony_grad.resize_(IntArrayRef(phony_shape)); - output = at::avg_pool2d_backward(input, - phony_grad, - IntArrayRef({kernel_sizeH, kernel_sizeW}), - IntArrayRef({strideH, strideW}), - IntArrayRef({0, 0}), - false, - true, - c10::nullopt); + output = at::avg_pool2d_backward(input, + phony_grad, + IntArrayRef({kernel_sizeH, kernel_sizeW}), + IntArrayRef({strideH, strideW}), + IntArrayRef({0, 0}), + false, + true, + c10::nullopt); // Multiply output by kernel size - output = at::mul(output, kernel_sizeH*kernel_sizeW); + output = at::mul(output, kernel_sizeH * kernel_sizeW); } return output; } -Tensor adaptive_avg_pool2d_mps - (at::Tensor const& input, - IntArrayRef output_size) { - +Tensor adaptive_avg_pool2d_mps(at::Tensor const& input, IntArrayRef output_size) { IntArrayRef output_shape; auto osizeH = output_size[0]; @@ -103,7 +99,7 @@ std::vector out_dims = {}; - if(input.ndimension() == 4) { + if (input.ndimension() == 4) { auto sizeB = input.size(0); auto sizeD = input.size(1); @@ -112,8 +108,7 @@ out_dims.push_back(osizeH); out_dims.push_back(osizeW); output_shape = IntArrayRef(out_dims); - } - else { + } else { auto sizeD = input.size(0); out_dims.push_back(sizeD); out_dims.push_back(osizeH); @@ -122,21 +117,12 @@ } const auto memory_format = input.suggest_memory_format(); - Tensor output = at::native::empty_mps( - output_shape, - input.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - memory_format); + Tensor output = + at::native::empty_mps(output_shape, input.scalar_type(), c10::nullopt, kMPS, c10::nullopt, memory_format); return adaptive_avg_pool2d_out_mps(input, output_size, output); - } -Tensor adaptive_avg_pool2d_backward_mps - (const Tensor& gradOutput, - const Tensor& input) { - +Tensor adaptive_avg_pool2d_backward_mps(const Tensor& gradOutput, const Tensor& input) { int64_t isizeH = input.size(-2); int64_t isizeW = input.size(-1); int64_t osizeH = gradOutput.size(-2); @@ -145,14 +131,11 @@ int64_t strideH = 0, strideW = 0; int64_t kernel_sizeH = 0, kernel_sizeW = 0; - set_kernel_params(isizeH, isizeW, - osizeH, osizeW, - strideH, strideW, - kernel_sizeH, kernel_sizeW, true); + set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW, true); auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); if (gradInput.numel() != 0) { - if(isizeH >= osizeH) { + if (isizeH >= osizeH) { gradInput = at::avg_pool2d_backward(gradOutput, input, IntArrayRef({kernel_sizeH, kernel_sizeW}), @@ -163,13 +146,13 @@ c10::nullopt); } else { gradInput = at::avg_pool2d(gradOutput, - IntArrayRef({kernel_sizeH, kernel_sizeW}), - IntArrayRef({strideH, strideW}), - IntArrayRef({0, 0}), - false, - true, - c10::nullopt); - gradInput = at::mul(gradInput, kernel_sizeH*kernel_sizeW); + IntArrayRef({kernel_sizeH, kernel_sizeW}), + IntArrayRef({strideH, strideW}), + IntArrayRef({0, 0}), + false, + true, + c10::nullopt); + gradInput = at::mul(gradInput, kernel_sizeH * kernel_sizeW); } } @@ -178,16 +161,16 @@ // Adaptive max pooling TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps) - (const Tensor& input, - IntArrayRef output_size, - const Tensor& output, - const Tensor& indices) { - +(const Tensor& input, IntArrayRef output_size, const Tensor& output, const Tensor& indices) { for (int64_t i = 1; i < input.ndimension(); i++) { TORCH_CHECK(input.size(i) > 0, - "adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, " - "but input has sizes ", input.sizes(), " with dimension ", i, " being " - "empty"); + "adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, " + "but input has sizes ", + input.sizes(), + " with dimension ", + i, + " being " + "empty"); } int64_t isizeH = input.size(-2); @@ -198,13 +181,11 @@ int64_t strideH = 0, strideW = 0; int64_t kernel_sizeH = 0, kernel_sizeW = 0; - set_kernel_params(isizeH, isizeW, - osizeH, osizeW, - strideH, strideW, - kernel_sizeH, kernel_sizeW); + set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW); at::max_pool2d_with_indices_out(const_cast(output), - const_cast(indices), input, + const_cast(indices), + input, IntArrayRef({kernel_sizeH, kernel_sizeW}), IntArrayRef({strideH, strideW}), IntArrayRef({0, 0}), @@ -213,11 +194,7 @@ } TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_mps) - (const Tensor& gradOutput, - const Tensor& input, - const Tensor& indices, - const Tensor& gradInput) { - +(const Tensor& gradOutput, const Tensor& input, const Tensor& indices, const Tensor& gradInput) { int64_t isizeH = input.size(-2); int64_t isizeW = input.size(-1); int64_t osizeH = gradOutput.size(-2); @@ -226,13 +203,11 @@ int64_t strideH = 0, strideW = 0; int64_t kernel_sizeH = 0, kernel_sizeW = 0; - set_kernel_params(isizeH, isizeW, - osizeH, osizeW, - strideH, strideW, - kernel_sizeH, kernel_sizeW); + set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW); at::max_pool2d_with_indices_backward_out(const_cast(gradInput), - gradOutput, input, + gradOutput, + input, IntArrayRef({kernel_sizeH, kernel_sizeW}), IntArrayRef({strideH, strideW}), IntArrayRef({0, 0}), diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index 3953887735631d..f5078a332f0540 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -1,5 +1,5 @@ -#include #include +#include namespace at::native { namespace mps { @@ -124,10 +124,10 @@ kernel void copysign_integral(constant void * input_ [[buffer(0)]], return binaryLibrary; } - NSError *error = nil; - MTLCompileOptions *options = [[MTLCompileOptions new] autorelease]; - [options setLanguageVersion: MTLLanguageVersion2_3]; - binaryLibrary = [device newLibraryWithSource:[NSString stringWithCString: METAL_BINARY encoding:NSASCIIStringEncoding] + NSError* error = nil; + MTLCompileOptions* options = [[MTLCompileOptions new] autorelease]; + [options setLanguageVersion:MTLLanguageVersion2_3]; + binaryLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_BINARY encoding:NSASCIIStringEncoding] options:options error:&error]; TORCH_CHECK(binaryLibrary, "Failed to create metal binary library, error: ", [[error description] UTF8String]); @@ -159,15 +159,15 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) { Tensor other = iter.input(1); Tensor out = iter.output(); - id inputBuffer = getMTLBufferStorage(input); - id otherBuffer = getMTLBufferStorage(other); + id inputBuffer = getMTLBufferStorage(input); + id otherBuffer = getMTLBufferStorage(other); id outputBuffer = getMTLBufferStorage(out); id device = MPSDevice::getInstance()->device(); MPSStream* mpsStream = getCurrentMPSStream(); const uint32_t nDim = iter.ndim(); constexpr uint32_t nOffsets = 3; const uint32_t numThreads = iter.numel(); - dispatch_sync(mpsStream->queue(), ^(){ + dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { NSError* error = nil; id commandBuffer = mpsStream->commandBuffer(); @@ -177,23 +177,25 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) { std::vector iterShapeData(iterShape.size()); std::vector> strides(nDim); - for (const auto i: c10::irange(iterShape.size())) { + for (const auto i : c10::irange(iterShape.size())) { TORCH_CHECK(i <= UINT32_MAX); iterShapeData[i] = (uint32_t)(iterShape[i]); } - for (const auto i: c10::irange(nDim)) { - for (const auto offset: c10::irange(nOffsets)) { - strides[i][offset] = iter.strides(offset)[i]; + for (const auto i : c10::irange(nDim)) { + for (const auto offset : c10::irange(nOffsets)) { + strides[i][offset] = iter.strides(offset)[i]; } } - id kernelDataOffsetsFunction = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil); - id kernelDataOffsetsPSO = [[device newComputePipelineStateWithFunction: kernelDataOffsetsFunction - error: &error] autorelease]; - id kernelDataOffsets = [[device newBufferWithLength: numThreads * sizeof(simd_uint3) - options: 0] autorelease]; - TORCH_CHECK(kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); + id kernelDataOffsetsFunction = + MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil); + id kernelDataOffsetsPSO = + [[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease]; + id kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3) + options:0] autorelease]; + TORCH_CHECK( + kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); [computeEncoder setComputePipelineState:kernelDataOffsetsPSO]; [computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0]; [computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:1]; @@ -203,28 +205,26 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) { NSUInteger kernelOffsetsTGSize = kernelDataOffsetsPSO.maxTotalThreadsPerThreadgroup; if (kernelOffsetsTGSize > numThreads) - kernelOffsetsTGSize = numThreads; + kernelOffsetsTGSize = numThreads; MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1); - [computeEncoder dispatchThreads: gridSize - threadsPerThreadgroup: kernelOffsetsThreadGroupSize]; + [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:kernelOffsetsThreadGroupSize]; const std::string kernel = func_name + "_" + scalarToMetalTypeString(input.scalar_type()); id binaryPSO = binaryPipelineState(device, kernel); [computeEncoder setComputePipelineState:binaryPSO]; - [computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:0]; - [computeEncoder setBuffer:otherBuffer offset:other.storage_offset() * other.element_size() atIndex:1]; + [computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:0]; + [computeEncoder setBuffer:otherBuffer offset:other.storage_offset() * other.element_size() atIndex:1]; [computeEncoder setBuffer:outputBuffer offset:out.storage_offset() * out.element_size() atIndex:2]; [computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:3]; NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; if (tgSize > numThreads) { - tgSize = numThreads; + tgSize = numThreads; } MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreads: gridSize - threadsPerThreadgroup: threadGroupSize]; + [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize]; [computeEncoder endEncoding]; mpsStream->commit(true); @@ -234,22 +234,22 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) { } // namespace mps void fmax_mps_kernel(TensorIteratorBase& iter) { - if (isFloatingType(iter.common_dtype())) { - mps::binary_mps_impl(iter, "fmax"); - } else { - at::maximum_out(const_cast(iter.output()), iter.input(0), iter.input(1)); - } + if (isFloatingType(iter.common_dtype())) { + mps::binary_mps_impl(iter, "fmax"); + } else { + at::maximum_out(const_cast(iter.output()), iter.input(0), iter.input(1)); + } } void fmin_mps_kernel(TensorIteratorBase& iter) { - if (isFloatingType(iter.common_dtype())) { - mps::binary_mps_impl(iter, "fmin"); - } else { - at::minimum_out(const_cast(iter.output()), iter.input(0), iter.input(1)); - } + if (isFloatingType(iter.common_dtype())) { + mps::binary_mps_impl(iter, "fmin"); + } else { + at::minimum_out(const_cast(iter.output()), iter.input(0), iter.input(1)); + } } void copysign_mps_kernel(TensorIteratorBase& iter) { - mps::binary_mps_impl(iter, "copysign"); + mps::binary_mps_impl(iter, "copysign"); } REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel); diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm index 2ae120e5abbc86..e16e61fe331912 100644 --- a/aten/src/ATen/native/mps/operations/BinaryOps.mm +++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm @@ -4,34 +4,40 @@ #include #include #include +#include #include -#include #include -#include +#include namespace at::native { namespace mps { -struct BinaryOpCachedGraph : public MPSCachedGraph -{ - BinaryOpCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} +struct BinaryOpCachedGraph : public MPSCachedGraph { + BinaryOpCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor *primaryTensor = nil, *secondaryTensor = nil; MPSGraphTensor *alphaTensor = nil, *outputTensor = nil; }; typedef MPSGraphTensor* (^BinaryOpBlock)(BinaryOpCachedGraph*, MPSGraphTensor*, MPSGraphTensor*); -#define BinaryOpFn(graph, primary, secondary) MPSGraphTensor* (mps::BinaryOpCachedGraph* graph, MPSGraphTensor* primary, MPSGraphTensor* secondary) +#define BinaryOpFn(graph, primary, secondary) \ + MPSGraphTensor*(mps::BinaryOpCachedGraph * graph, MPSGraphTensor * primary, MPSGraphTensor * secondary) // alpha is always 1.0 except when this function is called from add_sub_template() -void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha, - const Tensor& output_, std::string op_name, BinaryOpBlock binaryBlock) -{ - TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte ), +void binaryOpTensor(const Tensor& self, + const Tensor& other, + const Scalar& alpha, + const Tensor& output_, + std::string op_name, + BinaryOpBlock binaryBlock) { + TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte), "MPS support binary op with uint8 natively starting from macOS 13.0"); TORCH_CHECK(!(op_name == "power" && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS) && - (self.scalar_type() == ScalarType::Long || - (other.scalar_type() == ScalarType::Long && (self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))), - "MPS: ", op_name, " op with int64 input is supported natively starting from macOS 13.2"); + (self.scalar_type() == ScalarType::Long || + (other.scalar_type() == ScalarType::Long && + (self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))), + "MPS: ", + op_name, + " op with int64 input is supported natively starting from macOS 13.2"); MPSStream* mpsStream = getCurrentMPSStream(); const bool is_self_scalar = self.dim() == 0; @@ -39,7 +45,7 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha auto new_size = at::infer_size(self.sizes(), other.sizes()); if (!output_.sizes().equals(new_size)) { - output_.resize_(new_size); + output_.resize_(new_size); } // it's possible to receive empty tensors here @@ -53,7 +59,7 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha if (!output_.is_contiguous()) { output = output_.contiguous(); needsCopyToOutput = true; - // else, determine if this is an in-place operation on a view output + // else, determine if this is an in-place operation on a view output } else if (output_.is_view() && (self.is_alias_of(output_) || other.is_alias_of(output_))) { output = at::native::empty_mps(output_.sizes(), output_.scalar_type(), c10::nullopt, kMPS); needsCopyToOutput = true; @@ -79,18 +85,20 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { string key = op_name + getTensorsStringKey({self, other, output_}); - BinaryOpCachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + BinaryOpCachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph* () { - BinaryOpCachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + BinaryOpCachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new BinaryOpCachedGraph(mpsGraph); - newCachedGraph->primaryTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(inputDataType), getMPSShape(self)); - newCachedGraph->secondaryTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(otherDataType), getMPSShape(other)); + newCachedGraph->primaryTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(inputDataType), getMPSShape(self)); + newCachedGraph->secondaryTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(otherDataType), getMPSShape(other)); - MPSGraphTensor* primaryCastTensor = newCachedGraph->primaryTensor; + MPSGraphTensor* primaryCastTensor = newCachedGraph->primaryTensor; MPSGraphTensor* secondaryCastTensor = newCachedGraph->secondaryTensor; // this type inference is only required at the time of graph creation @@ -99,10 +107,9 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha // integer inputs must be cast to float, if output is float if (isFloatingType(outputDataType)) { common_dtype = outputDataType; - // in boolean comparison ops with signed vs. unsigned integers, we always cast to the unsigned type + // in boolean comparison ops with signed vs. unsigned integers, we always cast to the unsigned type } else if (outputDataType == ScalarType::Bool && - (inputDataType == ScalarType::Byte || - otherDataType == ScalarType::Byte)) { + (inputDataType == ScalarType::Byte || otherDataType == ScalarType::Byte)) { common_dtype = ScalarType::Byte; } } @@ -113,19 +120,19 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype); } newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor); - // Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to int32 tensor - // Output tensor should have been promoted but it remains an int32 tensor + // Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to + // int32 tensor Output tensor should have been promoted but it remains an int32 tensor if (outputDataType != common_dtype || - [newCachedGraph->outputTensor dataType] != getMPSDataType(outputDataType)) { + [newCachedGraph->outputTensor dataType] != getMPSDataType(outputDataType)) { newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, outputDataType); } } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } - NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; + NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; Placeholder selfPlaceholder; Placeholder otherPlaceholder; MPSScalar self_scalar; @@ -136,16 +143,22 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha self_scalar = getMPSScalar(self.item(), inputDataType); feeds[cachedGraph->primaryTensor] = getMPSGraphTensorFromScalar(mpsStream, self_scalar); } else { - selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self, /*mpsShape*/nil, - /*gatherTensorData=*/true, getMPSScalarType(inputDataType)); + selfPlaceholder = Placeholder(cachedGraph->primaryTensor, + self, + /*mpsShape*/ nil, + /*gatherTensorData=*/true, + getMPSScalarType(inputDataType)); feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData(); } if (is_other_scalar && !other.is_mps()) { other_scalar = getMPSScalar(other.item(), otherDataType); feeds[cachedGraph->secondaryTensor] = getMPSGraphTensorFromScalar(mpsStream, other_scalar); } else { - otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other, /*mpsShape*/nil, - /*gatherTensorData=*/true, getMPSScalarType(otherDataType)); + otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, + other, + /*mpsShape*/ nil, + /*gatherTensorData=*/true, + getMPSScalarType(otherDataType)); feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData(); } @@ -156,9 +169,8 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha } Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, needsCopyToOutput ? output : output_); - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results); if (needsCopyToOutput) { @@ -167,34 +179,35 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha } } -void binaryOpScalar(const Tensor& self, const Scalar& other, const Scalar& alpha, - const Tensor& output, std::string op_name, BinaryOpBlock binaryBlock) -{ +void binaryOpScalar(const Tensor& self, + const Scalar& other, + const Scalar& alpha, + const Tensor& output, + std::string op_name, + BinaryOpBlock binaryBlock) { binaryOpTensor(self, wrapped_scalar_tensor(other), alpha, output, op_name, binaryBlock); } -void div_mode_template(const Tensor& self, const Tensor& other, +void div_mode_template(const Tensor& self, + const Tensor& other, c10::optional rounding_mode, - const Tensor& output, const string op_name) -{ - if(rounding_mode.has_value() && *rounding_mode == "trunc"){ - TORCH_CHECK(self.scalar_type() != ScalarType::Half, - "MPS: does not support trunc_divide op with float16 input"); + const Tensor& output, + const string op_name) { + if (rounding_mode.has_value() && *rounding_mode == "trunc") { + TORCH_CHECK(self.scalar_type() != ScalarType::Half, "MPS: does not support trunc_divide op with float16 input"); } BinaryOpBlock div_mode_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { MPSGraph* mpsGraph = cachedGraph->graph(); bool isFloatInput = ([primaryCastTensor dataType] & MPSDataTypeFloatBit) != 0; - if(!isFloatInput && rounding_mode.has_value() && (*rounding_mode == "floor" || *rounding_mode == "trunc")) { - primaryCastTensor = [mpsGraph castTensor:primaryCastTensor - toType:MPSDataTypeFloat32 - name:@"primaryCastTensor"]; + if (!isFloatInput && rounding_mode.has_value() && (*rounding_mode == "floor" || *rounding_mode == "trunc")) { + primaryCastTensor = [mpsGraph castTensor:primaryCastTensor toType:MPSDataTypeFloat32 name:@"primaryCastTensor"]; secondaryCastTensor = [mpsGraph castTensor:secondaryCastTensor toType:MPSDataTypeFloat32 name:@"secondaryCastTensor"]; } - MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor - secondaryTensor:secondaryCastTensor - name:nil]; + MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor + secondaryTensor:secondaryCastTensor + name:nil]; // Rounding is a no-op for integral types, and also a reasonable workaround // For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library` // See https://github.com/pytorch/pytorch/issues/84995 @@ -202,14 +215,12 @@ void div_mode_template(const Tensor& self, const Tensor& other, if (!rounding_mode.has_value() || !isFloatOutput) { return divTensor; } else if (*rounding_mode == "trunc") { - auto truncTensor = trunc_tensor(mpsGraph, divTensor); + auto truncTensor = trunc_tensor(mpsGraph, divTensor); if (op_name == "fmod_mps_out") { auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:truncTensor secondaryTensor:secondaryCastTensor name:nil]; - return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor - secondaryTensor:mulTensor - name:nil]; + return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:mulTensor name:nil]; } return truncTensor; } else if (*rounding_mode == "floor") { @@ -218,22 +229,28 @@ void div_mode_template(const Tensor& self, const Tensor& other, auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor secondaryTensor:secondaryCastTensor name:nil]; - return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor - secondaryTensor:mulTensor - name:nil]; + return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:mulTensor name:nil]; } return floorTensor; } assert(0 && "Invalid rounding mode\n"); return nullptr; }; - binaryOpTensor(self, other, Scalar(1.0), output, op_name + "_mps:" + (rounding_mode.has_value() ? c10::str(*rounding_mode) : ""), div_mode_op_block); + binaryOpTensor(self, + other, + Scalar(1.0), + output, + op_name + "_mps:" + (rounding_mode.has_value() ? c10::str(*rounding_mode) : ""), + div_mode_op_block); } -void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output, std::string op_name) -{ +void add_sub_template(const Tensor& self, + const Tensor& other, + const Scalar& alpha, + const Tensor& output, + std::string op_name) { if (alpha.toDouble() == 0.0) { - if (!self.is_alias_of(output)) { // if inplace, no-op + if (!self.is_alias_of(output)) { // if inplace, no-op const_cast(output) = self.clone(); } return; @@ -251,60 +268,79 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alp // if alpha is 1.0, then we don't bother adding another multiply to graph if (alpha_has_value) { - cachedGraph->alphaTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(other.scalar_type()), @[@1]); + cachedGraph->alphaTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(other.scalar_type()), @[ @1 ]); secondaryTensor = [mpsGraph multiplicationWithPrimaryTensor:secondaryCastTensor secondaryTensor:cachedGraph->alphaTensor name:nil]; } if (op_name == "add") - return [mpsGraph additionWithPrimaryTensor:primaryCastTensor - secondaryTensor:secondaryTensor - name:nil]; + return [mpsGraph additionWithPrimaryTensor:primaryCastTensor secondaryTensor:secondaryTensor name:nil]; else - return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor - secondaryTensor:secondaryTensor - name:nil]; + return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:secondaryTensor name:nil]; }; // add alpha's type to the key only if multiply was added to graph - binaryOpTensor(self, other, alpha, output, op_name + "_out_mps:" + (alpha_has_value ? getMPSTypeString(alpha.type()) : ""), add_sub_op_block); + binaryOpTensor(self, + other, + alpha, + output, + op_name + "_out_mps:" + (alpha_has_value ? getMPSTypeString(alpha.type()) : ""), + add_sub_op_block); } } // namespace mps -#define CREATE_MPS_BINARY_COMPARISON_OP_FUNC(func_out, func_stub, other_type) \ -Tensor& func_out (const Tensor& self, const other_type& other, Tensor& output) { \ - mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \ - ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \ - MPSGraph* mpsGraph = cachedGraph->graph(); \ - return [mpsGraph func_stub##WithPrimaryTensor:mps::castMPSTensor(mpsGraph, primaryCastTensor, ScalarType::Bool) \ - secondaryTensor:mps::castMPSTensor(mpsGraph, secondaryCastTensor, ScalarType::Bool) \ - name:nil]; }); \ - return output; \ -} +#define CREATE_MPS_BINARY_COMPARISON_OP_FUNC(func_out, func_stub, other_type) \ + Tensor& func_out(const Tensor& self, const other_type& other, Tensor& output) { \ + mps::binaryOp##other_type( \ + self, \ + other, \ + Scalar(1.0), \ + output, \ + #func_stub, \ + ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \ + MPSGraph* mpsGraph = cachedGraph->graph(); \ + return [mpsGraph func_stub## \ + WithPrimaryTensor:mps::castMPSTensor(mpsGraph, primaryCastTensor, ScalarType::Bool) \ + secondaryTensor:mps::castMPSTensor(mpsGraph, secondaryCastTensor, ScalarType::Bool) \ + name:nil]; \ + }); \ + return output; \ + } -#define CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(func_out, func_stub, other_type) \ -TORCH_IMPL_FUNC(func_out) (const Tensor& self, const other_type& other, const Tensor& output) { \ - TORCH_CHECK(!(self.scalar_type() == ScalarType::Long && \ - std::string(#func_stub) == "atan2"), \ - "MPS does not support ", #func_stub, " op with int64 input") \ - mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \ - ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \ - MPSGraph* mpsGraph = cachedGraph->graph(); \ - return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \ - secondaryTensor:secondaryCastTensor \ - name:nil]; }); \ -} +#define CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(func_out, func_stub, other_type) \ + TORCH_IMPL_FUNC(func_out)(const Tensor& self, const other_type& other, const Tensor& output) { \ + TORCH_CHECK(!(self.scalar_type() == ScalarType::Long && std::string(#func_stub) == "atan2"), \ + "MPS does not support ", \ + #func_stub, \ + " op with int64 input") \ + mps::binaryOp##other_type(self, \ + other, \ + Scalar(1.0), \ + output, \ + #func_stub, \ + ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \ + MPSGraph* mpsGraph = cachedGraph->graph(); \ + return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \ + secondaryTensor:secondaryCastTensor \ + name:nil]; \ + }); \ + } // output of Boolean Ops will be cast to "MPSDataTypeBool" at the end of binaryOpTensor() -#define CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(func_out, func_stub, other_type) \ -TORCH_IMPL_FUNC(func_out) (const Tensor& self, const other_type& other, const Tensor& output) { \ - mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \ - ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \ - MPSGraph* mpsGraph = cachedGraph->graph(); \ - return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \ - secondaryTensor:secondaryCastTensor \ - name:nil]; }); \ -} +#define CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(func_out, func_stub, other_type) \ + TORCH_IMPL_FUNC(func_out)(const Tensor& self, const other_type& other, const Tensor& output) { \ + mps::binaryOp##other_type(self, \ + other, \ + Scalar(1.0), \ + output, \ + #func_stub, \ + ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \ + MPSGraph* mpsGraph = cachedGraph->graph(); \ + return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \ + secondaryTensor:secondaryCastTensor \ + name:nil]; \ + }); \ + } // Boolean Binary Ops CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(eq_scalar_out_mps, equal, Scalar); @@ -332,24 +368,24 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alp CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_or_out_mps, logicalOR, Tensor); CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_xor_out_mps, logicalXOR, Tensor); - -TORCH_IMPL_FUNC(div_out_mode_mps) (const Tensor& self, const Tensor& other, c10::optional rounding_mode, const Tensor& output) { +TORCH_IMPL_FUNC(div_out_mode_mps) +(const Tensor& self, const Tensor& other, c10::optional rounding_mode, const Tensor& output) { mps::div_mode_template(self, other, rounding_mode, output, "div_mode_out"); } -TORCH_IMPL_FUNC(div_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) { +TORCH_IMPL_FUNC(div_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { mps::div_mode_template(self, other, c10::nullopt, output, "div_out"); } -TORCH_IMPL_FUNC(add_out_mps) (const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) { +TORCH_IMPL_FUNC(add_out_mps)(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) { mps::add_sub_template(self, other, alpha, output, "add"); } -TORCH_IMPL_FUNC(sub_out_mps) (const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) { +TORCH_IMPL_FUNC(sub_out_mps)(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) { mps::add_sub_template(self, other, alpha, output, "sub"); } -TORCH_IMPL_FUNC(pow_Scalar_out_mps) (const Scalar& base, const Tensor& exp, const Tensor& out) { +TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const Tensor& out) { if (base.equal(1.0)) { out.fill_(1); } else { @@ -386,21 +422,18 @@ Tensor floor_divide_mps(const Tensor& self, const Tensor& other) { return floor_divide_out_mps(self, other, self); } -TORCH_IMPL_FUNC(remainder_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) { +TORCH_IMPL_FUNC(remainder_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { mps::div_mode_template(self, other, "floor", output, "remainder_out_mps"); } -TORCH_IMPL_FUNC(fmod_mps_out) (const Tensor& self, const Tensor& other, const Tensor& output) { +TORCH_IMPL_FUNC(fmod_mps_out)(const Tensor& self, const Tensor& other, const Tensor& output) { mps::div_mode_template(self, other, "trunc", output, "fmod_mps_out"); } -TORCH_IMPL_FUNC(hypot_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) -{ +TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { MPSGraph* mpsGraph = cachedGraph->graph(); - MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0 - shape:@[@1] - dataType:primaryCastTensor.dataType]; + MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType]; MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph powerWithPrimaryTensor:primaryCastTensor secondaryTensor:twoTensor name:nil] @@ -413,46 +446,42 @@ Tensor floor_divide_mps(const Tensor& self, const Tensor& other) { mps::binaryOpTensor(self, other, Scalar(1.0), output, "hypot_out_mps", hypot_op_block); } -TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) -{ +TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { MPSGraph* mpsGraph = cachedGraph->graph(); - MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil] - secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil] - name:nil]; + MPSGraphTensor* sumTensor = + [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil] + secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil] + name:nil]; return [mpsGraph logarithmWithTensor:sumTensor name:nil]; }; mps::binaryOpTensor(self, other, Scalar(1.0), output, "logaddexp_out_mps", logaddexp_op_block); } -TORCH_IMPL_FUNC(logaddexp2_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) -{ - mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { +TORCH_IMPL_FUNC(logaddexp2_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { + mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { MPSGraph* mpsGraph = cachedGraph->graph(); - MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil] - secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil] - name:nil]; + MPSGraphTensor* sumTensor = + [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil] + secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil] + name:nil]; return [mpsGraph logarithmBase2WithTensor:sumTensor name:nil]; }; mps::binaryOpTensor(self, other, Scalar(1.0), output, "logaddexp2_out_mps", logaddexp2_op_block); } -TORCH_IMPL_FUNC(xlogy_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) { +TORCH_IMPL_FUNC(xlogy_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { mps::BinaryOpBlock xlogy_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { MPSGraph* mpsGraph = cachedGraph->graph(); - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 - shape:@[@1] - dataType:primaryCastTensor.dataType]; - MPSGraphTensor* yIsNaNPredicateTensor = [mpsGraph isNaNWithTensor:secondaryCastTensor - name:nil]; - MPSGraphTensor* logyTensor = [mpsGraph logarithmWithTensor:secondaryCastTensor - name:nil]; + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType]; + MPSGraphTensor* yIsNaNPredicateTensor = [mpsGraph isNaNWithTensor:secondaryCastTensor name:nil]; + MPSGraphTensor* logyTensor = [mpsGraph logarithmWithTensor:secondaryCastTensor name:nil]; MPSGraphTensor* xlogyTensor = [mpsGraph multiplicationWithPrimaryTensor:primaryCastTensor secondaryTensor:logyTensor name:nil]; MPSGraphTensor* xEqualZeroPredicateTensor = [mpsGraph equalWithPrimaryTensor:primaryCastTensor - secondaryTensor:zeroTensor - name:nil]; + secondaryTensor:zeroTensor + name:nil]; MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:xEqualZeroPredicateTensor truePredicateTensor:zeroTensor falsePredicateTensor:xlogyTensor diff --git a/aten/src/ATen/native/mps/operations/BitwiseOps.mm b/aten/src/ATen/native/mps/operations/BitwiseOps.mm index 58566b5666d932..868739498366ca 100644 --- a/aten/src/ATen/native/mps/operations/BitwiseOps.mm +++ b/aten/src/ATen/native/mps/operations/BitwiseOps.mm @@ -1,6 +1,6 @@ +#include #include #include -#include #include #include #include @@ -86,17 +86,16 @@ kernel void bitwise_not(constant uint& length [[buffer(0)]], }} )METAL"; - const std::string& getMetalType(const c10::ScalarType& t) { // Mapping from c10::ScalarType to integral type that can be used for bitwise ops // As bitwise ops sign-agnostic map signed/unsigned char and boolean to the same type static std::unordered_map scalar_to_metal_type = { - {c10::ScalarType::Long, "long"}, - {c10::ScalarType::Int, "int"}, - {c10::ScalarType::Short, "short"}, - {c10::ScalarType::Byte, "char"}, - {c10::ScalarType::Char, "char"}, - {c10::ScalarType::Bool, "char"}, + {c10::ScalarType::Long, "long"}, + {c10::ScalarType::Int, "int"}, + {c10::ScalarType::Short, "short"}, + {c10::ScalarType::Byte, "char"}, + {c10::ScalarType::Char, "char"}, + {c10::ScalarType::Bool, "char"}, }; auto it = scalar_to_metal_type.find(t); @@ -112,7 +111,6 @@ kernel void bitwise_not(constant uint& length [[buffer(0)]], return getMetalType(s.type()); } - static id compileBitwiseOpsLibrary(id device, const std::string& t1, const std::string& t2, @@ -123,61 +121,60 @@ kernel void bitwise_not(constant uint& length [[buffer(0)]], if (it != libMap.end()) { return it->second; } - NSError *error = nil; - MTLCompileOptions *options = [[MTLCompileOptions new] autorelease]; - [options setLanguageVersion: MTLLanguageVersion2_3]; - auto rc = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(BITWISE_OPS_TEMPLATE, t1, t2, t3).c_str()] - options:options - error:&error]; - TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]); - libMap[key] = rc; - return rc; + NSError* error = nil; + MTLCompileOptions* options = [[MTLCompileOptions new] autorelease]; + [options setLanguageVersion:MTLLanguageVersion2_3]; + auto rc = + [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(BITWISE_OPS_TEMPLATE, t1, t2, t3).c_str()] + options:options + error:&error]; + TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]); + libMap[key] = rc; + return rc; } - static id getCPLState(id device, - const std::string& t1, - const std::string& t2, - const std::string& t3, - const std::string& fname) { + const std::string& t1, + const std::string& t2, + const std::string& t3, + const std::string& fname) { auto key = t1 + t2 + t3 + fname; static std::unordered_map> cplMap; auto it = cplMap.find(key); if (it != cplMap.end()) { - return it->second; + return it->second; } - NSError *error = nil; + NSError* error = nil; auto library = compileBitwiseOpsLibrary(device, t1, t2, t3); id func = [library newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]]; TORCH_CHECK(func != nil, "Can't get function ", fname); auto rc = [device newComputePipelineStateWithFunction:func error:&error]; - TORCH_CHECK(rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]); - cplMap[key] = rc; + TORCH_CHECK( + rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]); + cplMap[key] = rc; return rc; } -void dispatch1DJob(id commandEncoder, id cplState, uint32_t length) -{ +void dispatch1DJob(id commandEncoder, id cplState, uint32_t length) { uint32_t maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup]; auto size = MTLSizeMake(length, 1, 1); auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1); - [commandEncoder dispatchThreads:size - threadsPerThreadgroup:threadGroupSize]; + [commandEncoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; } -void handle_tensor_tensor_binary_op(const at::Tensor& self, const at::Tensor& other, at::Tensor& output, const std::string& kernel_name) { +void handle_tensor_tensor_binary_op(const at::Tensor& self, + const at::Tensor& other, + at::Tensor& output, + const std::string& kernel_name) { using namespace at::mps; MPSStream* stream = getCurrentMPSStream(); - id cplState = getCPLState(MPSDevice::getInstance()->device(), - getMetalType(output), - getMetalType(self), - getMetalType(other), - kernel_name); + id cplState = getCPLState( + MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(other), kernel_name); uint32_t length = output.numel(); if (length == 0) { return; } - dispatch_sync(stream->queue(), ^(){ + dispatch_sync(stream->queue(), ^() { id buffer = stream->commandBuffer(); id commandEncoder = [buffer computeCommandEncoder]; @@ -188,29 +185,29 @@ void handle_tensor_tensor_binary_op(const at::Tensor& self, const at::Tensor& ot [commandEncoder pushDebugGroup:[NSString stringWithFormat:@"Dispatch %s kernel", kernel_name.c_str()]]; [commandEncoder setComputePipelineState:cplState]; [commandEncoder setBytes:&length length:sizeof(length) atIndex:0]; - [commandEncoder setBuffer:outBuf offset:output.storage_offset()*output.itemsize() atIndex:1]; - [commandEncoder setBuffer:selfBuf offset:self.storage_offset()*self.itemsize() atIndex:2]; - [commandEncoder setBuffer:otherBuf offset:other.storage_offset()*other.itemsize() atIndex:3]; + [commandEncoder setBuffer:outBuf offset:output.storage_offset() * output.itemsize() atIndex:1]; + [commandEncoder setBuffer:selfBuf offset:self.storage_offset() * self.itemsize() atIndex:2]; + [commandEncoder setBuffer:otherBuf offset:other.storage_offset() * other.itemsize() atIndex:3]; dispatch1DJob(commandEncoder, cplState, length); [commandEncoder endEncoding]; stream->commit(true); }); } -void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& other, at::Tensor& output, const std::string& kernel_name) { +void handle_tensor_scalar_binary_op(const at::Tensor& self, + const at::Scalar& other, + at::Tensor& output, + const std::string& kernel_name) { using namespace at::mps; MPSStream* stream = getCurrentMPSStream(); - id cplState = getCPLState(MPSDevice::getInstance()->device(), - getMetalType(output), - getMetalType(self), - getMetalType(other), - kernel_name); + id cplState = getCPLState( + MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(other), kernel_name); uint64_t sval = other.to(); uint32_t length = output.numel(); if (length == 0) { return; } - dispatch_sync(stream->queue(), ^(){ + dispatch_sync(stream->queue(), ^() { id buffer = stream->commandBuffer(); id commandEncoder = [buffer computeCommandEncoder]; @@ -220,8 +217,8 @@ void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& ot [commandEncoder pushDebugGroup:[NSString stringWithFormat:@"Dispatch %s kernel", kernel_name.c_str()]]; [commandEncoder setComputePipelineState:cplState]; [commandEncoder setBytes:&length length:sizeof(length) atIndex:0]; - [commandEncoder setBuffer:outBuf offset:output.storage_offset()*output.itemsize() atIndex:1]; - [commandEncoder setBuffer:selfBuf offset:self.storage_offset()*self.itemsize() atIndex:2]; + [commandEncoder setBuffer:outBuf offset:output.storage_offset() * output.itemsize() atIndex:1]; + [commandEncoder setBuffer:selfBuf offset:self.storage_offset() * self.itemsize() atIndex:2]; [commandEncoder setBytes:&sval length:sizeof(sval) atIndex:3]; dispatch1DJob(commandEncoder, cplState, length); [commandEncoder endEncoding]; @@ -229,7 +226,10 @@ void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& ot }); } -at::Tensor& _bitwise_op_out_mps (const at::Tensor& self, const at::Tensor& other, at::Tensor& output_, const std::string& op_name) { +at::Tensor& _bitwise_op_out_mps(const at::Tensor& self, + const at::Tensor& other, + at::Tensor& output_, + const std::string& op_name) { using namespace at::mps; const bool is_self_scalar = self.dim() == 0; const bool is_other_scalar = other.dim() == 0; @@ -264,24 +264,24 @@ void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& ot fmt::format("bitwise_{}_tensor", op_name)); } if (needs_output_copy) { - output_.copy_(output); + output_.copy_(output); } return output_; } -at::Tensor& bitwise_and_out_mps (const at::Tensor& self, const at::Tensor& other, at::Tensor& output) { - return _bitwise_op_out_mps(self, other, output, "and"); +at::Tensor& bitwise_and_out_mps(const at::Tensor& self, const at::Tensor& other, at::Tensor& output) { + return _bitwise_op_out_mps(self, other, output, "and"); } -at::Tensor& bitwise_or_out_mps (const at::Tensor& self, const at::Tensor& other, at::Tensor& output) { - return _bitwise_op_out_mps(self, other, output, "or"); +at::Tensor& bitwise_or_out_mps(const at::Tensor& self, const at::Tensor& other, at::Tensor& output) { + return _bitwise_op_out_mps(self, other, output, "or"); } -at::Tensor& bitwise_xor_out_mps (const at::Tensor& self, const at::Tensor& other, at::Tensor& output) { - return _bitwise_op_out_mps(self, other, output, "xor"); +at::Tensor& bitwise_xor_out_mps(const at::Tensor& self, const at::Tensor& other, at::Tensor& output) { + return _bitwise_op_out_mps(self, other, output, "xor"); } -at::Tensor& bitwise_not_out_mps (const at::Tensor& self, at::Tensor& output_) { +at::Tensor& bitwise_not_out_mps(const at::Tensor& self, at::Tensor& output_) { // Handle boolean tensor using logical not if (self.scalar_type() == c10::ScalarType::Bool) { return at::native::logical_not_out_mps(self, output_); @@ -310,12 +310,9 @@ void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& ot } using namespace at::mps; MPSStream* stream = getCurrentMPSStream(); - id cplState = getCPLState(MPSDevice::getInstance()->device(), - getMetalType(output), - getMetalType(self), - getMetalType(self), - "bitwise_not"); - dispatch_sync(stream->queue(), ^(){ + id cplState = getCPLState( + MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(self), "bitwise_not"); + dispatch_sync(stream->queue(), ^() { id buffer = stream->commandBuffer(); id commandEncoder = [buffer computeCommandEncoder]; @@ -325,20 +322,18 @@ void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& ot [commandEncoder pushDebugGroup:@"Dispatch bitwise_not kernel"]; [commandEncoder setComputePipelineState:cplState]; [commandEncoder setBytes:&length length:sizeof(length) atIndex:0]; - [commandEncoder setBuffer:outBuf offset:output.storage_offset()*output.itemsize() atIndex:1]; - [commandEncoder setBuffer:selfBuf offset:self.storage_offset()*self.itemsize() atIndex:2]; + [commandEncoder setBuffer:outBuf offset:output.storage_offset() * output.itemsize() atIndex:1]; + [commandEncoder setBuffer:selfBuf offset:self.storage_offset() * self.itemsize() atIndex:2]; dispatch1DJob(commandEncoder, cplState, length); [commandEncoder endEncoding]; stream->commit(true); }); if (needs_output_copy) { - output_.copy_(output); + output_.copy_(output); } return output_; } - - TORCH_LIBRARY_IMPL(aten, MPS, m) { m.impl("bitwise_and.Tensor_out", bitwise_and_out_mps); m.impl("bitwise_or.Tensor_out", bitwise_or_out_mps); diff --git a/aten/src/ATen/native/mps/operations/Blas.mm b/aten/src/ATen/native/mps/operations/Blas.mm index 99a2bf8e6c1304..0f6bbae1c4ff8b 100644 --- a/aten/src/ATen/native/mps/operations/Blas.mm +++ b/aten/src/ATen/native/mps/operations/Blas.mm @@ -12,26 +12,19 @@ #include #endif - namespace at::native { - -Tensor dot_mps( - const Tensor &self, - const Tensor &other) -{ - +Tensor dot_mps(const Tensor& self, const Tensor& other) { TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS: dot op doesn't support int64 input") using namespace mps; auto output = at::native::empty_mps({}, self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* selfTensor_ = nil; - MPSGraphTensor* otherTensor_ = nil; - MPSGraphTensor* outputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* selfTensor_ = nil; + MPSGraphTensor* otherTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -40,45 +33,38 @@ Tensor dot_mps( @autoreleasepool { string key = "dot_mps" + getTensorsStringKey({self, other}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - - mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; - @autoreleasepool{ - MPSGraph *mpsGraph = mps::make_mps_graph(); + @autoreleasepool { + MPSGraph* mpsGraph = mps::make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor *otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other); + MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other); - MPSGraphTensor *castSelf = nil; - MPSGraphTensor *castOther = nil; + MPSGraphTensor* castSelf = nil; + MPSGraphTensor* castOther = nil; - if(self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte - || self.scalar_type() == ScalarType::Char) { - castSelf = [mpsGraph castTensor:selfTensor - toType:MPSDataTypeInt32 - name:@"castSelfTensor"]; - castOther = [mpsGraph castTensor:otherTensor - toType:MPSDataTypeInt32 - name:@"castOtherTensor"]; + if (self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte || + self.scalar_type() == ScalarType::Char) { + castSelf = [mpsGraph castTensor:selfTensor toType:MPSDataTypeInt32 name:@"castSelfTensor"]; + castOther = [mpsGraph castTensor:otherTensor toType:MPSDataTypeInt32 name:@"castOtherTensor"]; } else { castSelf = selfTensor; castOther = otherTensor; } - MPSGraphTensor *dot = [mpsGraph multiplicationWithPrimaryTensor: castSelf - secondaryTensor: castOther - name: @"multiplication"]; + MPSGraphTensor* dot = [mpsGraph multiplicationWithPrimaryTensor:castSelf + secondaryTensor:castOther + name:@"multiplication"]; - MPSGraphTensor *dotProductTensor = [mpsGraph reductionSumWithTensor: dot - axes: nil - name: @"dotProduct"]; + MPSGraphTensor* dotProductTensor = [mpsGraph reductionSumWithTensor:dot axes:nil name:@"dotProduct"]; - if(self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte - || self.scalar_type() == ScalarType::Char) + if (self.scalar_type() == ScalarType::Short || self.scalar_type() == ScalarType::Byte || + self.scalar_type() == ScalarType::Char) dotProductTensor = [mpsGraph castTensor:dotProductTensor toType:getMPSDataType(self) name:@"castDotProductTensor"]; @@ -89,7 +75,7 @@ Tensor dot_mps( } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self); @@ -101,9 +87,8 @@ Tensor dot_mps( otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData(), }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } @@ -111,14 +96,12 @@ Tensor dot_mps( return output; } -Tensor& addmv_out_mps_impl( - const Tensor &self, - const Tensor &mat, - const Tensor &vec, - const Scalar& beta_, - const Scalar& alpha_, - Tensor& result) -{ +Tensor& addmv_out_mps_impl(const Tensor& self, + const Tensor& mat, + const Tensor& vec, + const Scalar& beta_, + const Scalar& alpha_, + Tensor& result) { using namespace mps; TORCH_CHECK(mat.is_mps()); @@ -129,38 +112,35 @@ Tensor dot_mps( c10::MaybeOwned self_ = expand_size(self, {mat.size(0)}); auto betaval = beta_.toComplexDouble(); - struct CachedGraph : public mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *selfTensor_ = nil; - MPSGraphTensor *matMulVecTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + struct CachedGraph : public mps::MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* selfTensor_ = nil; + MPSGraphTensor* matMulVecTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; - mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance(); + mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance(); - MPSStream *stream = at::mps::getCurrentMPSStream(); + MPSStream* stream = at::mps::getCurrentMPSStream(); Tensor matMulVec = mm(mat, vec.unsqueeze(1)).squeeze(1); @autoreleasepool { - string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) - + ":" + to_string(beta_.toDouble()) - + ":" + to_string(alpha_.toDouble()); + string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" + to_string(beta_.toDouble()) + + ":" + to_string(alpha_.toDouble()); CachedGraph* cachedGraph = nil; - if(!cachedGraph) { + if (!cachedGraph) { + mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; - mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; - - @autoreleasepool{ - MPSGraph *mpsGraph = mps::make_mps_graph(); + @autoreleasepool { + MPSGraph* mpsGraph = mps::make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *matMulVecTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, matMulVec); - MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* matMulVecTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, matMulVec); + MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self); // Intermediates for beta and alpha - MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar: alpha_.toDouble() - dataType: getMPSScalarType(mat.scalar_type())]; + MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha_.toDouble() + dataType:getMPSScalarType(mat.scalar_type())]; // Intermediates for multiplying by beta and alpha MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor:matMulVecTensor @@ -168,18 +148,17 @@ Tensor dot_mps( name:@"MM/alpha*(mat@vec)"]; newCachedGraph->outputTensor_ = productTimesAlphaTensor; - if (betaval != 0.0) - { - MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar: beta_.toDouble() - dataType: getMPSScalarType(self.scalar_type())]; + if (betaval != 0.0) { + MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta_.toDouble() + dataType:getMPSScalarType(self.scalar_type())]; - MPSGraphTensor* selfTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor: selfTensor - secondaryTensor: betaTensor - name: @"MM/beta*input"]; + MPSGraphTensor* selfTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:selfTensor + secondaryTensor:betaTensor + name:@"MM/beta*input"]; - MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor: productTimesAlphaTensor - secondaryTensor: selfTimesBetaTensor - name: @"MM/beta*input + alpha*(mat@vec)"]; + MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor:productTimesAlphaTensor + secondaryTensor:selfTimesBetaTensor + name:@"MM/beta*input + alpha*(mat@vec)"]; newCachedGraph->outputTensor_ = outputTensor; } @@ -189,23 +168,21 @@ Tensor dot_mps( } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder matMulVecPlaceholder = Placeholder(cachedGraph->matMulVecTensor_, matMulVec); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); - NSMutableDictionary* feeds =[NSMutableDictionary dictionary]; - feeds[matMulVecPlaceholder.getMPSGraphTensor()] = matMulVecPlaceholder.getMPSGraphTensorData(); - if (betaval != 0.0) - { - Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self); - feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData(); + NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; + feeds[matMulVecPlaceholder.getMPSGraphTensor()] = matMulVecPlaceholder.getMPSGraphTensorData(); + if (betaval != 0.0) { + Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self); + feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData(); } - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); } @@ -213,7 +190,13 @@ Tensor dot_mps( return result; } -TORCH_IMPL_FUNC(addmv_out_mps)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta_, const Scalar& alpha_, const Tensor& result) { +TORCH_IMPL_FUNC(addmv_out_mps) +(const Tensor& self, + const Tensor& mat, + const Tensor& vec, + const Scalar& beta_, + const Scalar& alpha_, + const Tensor& result) { addmv_out_mps_impl(self, mat, vec, beta_, alpha_, const_cast(result)); } diff --git a/aten/src/ATen/native/mps/operations/ConstantOps.mm b/aten/src/ATen/native/mps/operations/ConstantOps.mm index 12e86e14c6357c..dc12d425661ab1 100644 --- a/aten/src/ATen/native/mps/operations/ConstantOps.mm +++ b/aten/src/ATen/native/mps/operations/ConstantOps.mm @@ -18,26 +18,27 @@ } struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* outputTensor_ = nil; }; - MPSGraphCache *cache_ = MPSGraphCache::getInstance(); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + to_string(value.toDouble()); CachedGraph* cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; - @autoreleasepool{ - MPSGraph *mpsGraph = make_mps_graph(); + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); auto isBool = self.scalar_type() == c10::ScalarType::Bool; auto isUInt8 = self.scalar_type() == c10::ScalarType::Byte; - auto dataType = !isUInt8 ? !isBool ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8 : MPSDataTypeUInt32; + auto dataType = + !isUInt8 ? !isBool ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8 : MPSDataTypeUInt32; // constantWithScalar does not work for boolTypes on MacOS-12.[34] // workaround by filing it as int8 tensor and than casting to bool // See https://github.com/pytorch/pytorch/issues/82427 @@ -47,17 +48,12 @@ MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble() shape:getMPSShape(self) dataType:dataType]; - MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor - name:nil]; + MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor name:nil]; if (isBool) { - outputTensor = [mpsGraph castTensor:outputTensor - toType:MPSDataTypeBool - name:@"constWithBool-workaround"]; + outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"constWithBool-workaround"]; } if (isUInt8) { - outputTensor = [mpsGraph castTensor:outputTensor - toType:MPSDataTypeUInt8 - name:@"constWithUInt8-workaround"]; + outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeUInt8 name:@"constWithUInt8-workaround"]; } newCachedGraph->outputTensor_ = outputTensor; @@ -66,13 +62,11 @@ }); } - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, - needsCopyToOutput ? output : self, - nullptr, !needsCopyToOutput); + Placeholder outputPlaceholder = + Placeholder(cachedGraph->outputTensor_, needsCopyToOutput ? output : self, nullptr, !needsCopyToOutput); - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), /*feeds*/ nil, results); @@ -109,7 +103,10 @@ bool fill_mps_tensor_(Tensor& self, uint8_t value) { } Tensor& fill_tensor_mps_(Tensor& self, const Tensor& value) { - TORCH_CHECK(value.dim() == 0, "fill_ only supports 0-dimension value tensor but got tensor with ", value.dim(), " dimensions."); + TORCH_CHECK(value.dim() == 0, + "fill_ only supports 0-dimension value tensor but got tensor with ", + value.dim(), + " dimensions."); Scalar scalar_value = value.item(); if (scalar_value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true) return self; diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm index 008c1781a576e7..3d929b48cf6c42 100644 --- a/aten/src/ATen/native/mps/operations/Convolution.mm +++ b/aten/src/ATen/native/mps/operations/Convolution.mm @@ -2,39 +2,51 @@ #include #include -#include #include +#include #include -#include #include +#include #include namespace at::native { void fill_depthwise_conv_desc(MPSGraphDepthwiseConvolution3DOpDescriptor* descriptor_, - NSUInteger strideInX, NSUInteger strideInY, - NSUInteger dilationRateInX, NSUInteger dilationRateInY, - NSUInteger paddingHorizontal, NSUInteger paddingVertical, - c10::MemoryFormat memory_format, NSUInteger groups) { - descriptor_.strides = @[@1, [[NSNumber alloc] initWithInteger: strideInY], - [[NSNumber alloc] initWithInteger: strideInX]]; - descriptor_.dilationRates = @[@1, [[NSNumber alloc] initWithInteger: dilationRateInY], - [[NSNumber alloc] initWithInteger: dilationRateInX]]; + NSUInteger strideInX, + NSUInteger strideInY, + NSUInteger dilationRateInX, + NSUInteger dilationRateInY, + NSUInteger paddingHorizontal, + NSUInteger paddingVertical, + c10::MemoryFormat memory_format, + NSUInteger groups) { + descriptor_.strides = + @[ @1, [[NSNumber alloc] initWithInteger:strideInY], [[NSNumber alloc] initWithInteger:strideInX] ]; + descriptor_.dilationRates = + @[ @1, [[NSNumber alloc] initWithInteger:dilationRateInY], [[NSNumber alloc] initWithInteger:dilationRateInX] ]; descriptor_.paddingStyle = MPSGraphPaddingStyleExplicit; - descriptor_.paddingValues = @[@0, @0, [[NSNumber alloc] initWithInteger: paddingVertical], [[NSNumber alloc] - initWithInteger: paddingVertical], [[NSNumber alloc] - initWithInteger: paddingHorizontal], [[NSNumber alloc] - initWithInteger: paddingHorizontal]]; + descriptor_.paddingValues = @[ + @0, + @0, + [[NSNumber alloc] initWithInteger:paddingVertical], + [[NSNumber alloc] initWithInteger:paddingVertical], + [[NSNumber alloc] initWithInteger:paddingHorizontal], + [[NSNumber alloc] initWithInteger:paddingHorizontal] + ]; descriptor_.channelDimensionIndex = -3LL; } // Create convolution descriptor void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_, - NSUInteger strideInX, NSUInteger strideInY, - NSUInteger dilationRateInX, NSUInteger dilationRateInY, - NSUInteger paddingHorizontal, NSUInteger paddingVertical, - c10::MemoryFormat memory_format, NSUInteger groups) { + NSUInteger strideInX, + NSUInteger strideInY, + NSUInteger dilationRateInX, + NSUInteger dilationRateInY, + NSUInteger paddingHorizontal, + NSUInteger paddingVertical, + c10::MemoryFormat memory_format, + NSUInteger groups) { descriptor_.strideInX = strideInX; descriptor_.strideInY = strideInY; descriptor_.dilationRateInX = dilationRateInX; @@ -48,64 +60,59 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_, descriptor_.paddingTop = paddingVertical; descriptor_.paddingBottom = paddingVertical; - descriptor_.dataLayout = (memory_format == at::MemoryFormat::Contiguous) ? - MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutNHWC; + descriptor_.dataLayout = (memory_format == at::MemoryFormat::Contiguous) ? MPSGraphTensorNamedDataLayoutNCHW + : MPSGraphTensorNamedDataLayoutNHWC; // PyTorch always uses OIHW memory layout for weights descriptor_.weightsLayout = MPSGraphTensorNamedDataLayoutOIHW; descriptor_.groups = groups; } -Tensor _mps_convolution_impl( - const Tensor& input_t, - const Tensor& weight_t, - const c10::optional& bias_opt, - IntArrayRef padding, - IntArrayRef stride, - IntArrayRef dilation, - int64_t groups, - c10::optional input_shape) { +Tensor _mps_convolution_impl(const Tensor& input_t, + const Tensor& weight_t, + const c10::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + c10::optional input_shape) { TORCH_CHECK(input_t.dim() < 5, "Conv3D is not supported on MPS"); TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types"); namespace native_mps = at::native::mps; CheckedFrom c = "mps_convolution"; - TensorArg input { input_t, "input", 1 }, - weight { weight_t, "weight", 2 }; + TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2}; checkAllSameType(c, {input, weight}); checkAllSameGPU(c, {input, weight}); bool bias_defined; - if(bias_opt == c10::nullopt) + if (bias_opt == c10::nullopt) bias_defined = false; else - bias_defined = bias_opt->defined(); + bias_defined = bias_opt->defined(); auto memory_format = input_t.suggest_memory_format(); bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast); - auto output_t = at::empty( - input_shape.has_value() ? - input_shape.value() : - conv_output_size(input->sizes(), weight->sizes(), - padding, stride, dilation), - input->scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + auto output_t = + at::empty(input_shape.has_value() ? input_shape.value() + : conv_output_size(input->sizes(), weight->sizes(), padding, stride, dilation), + input->scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); if (output_t.numel() == 0) { return output_t; } - TensorArg output{ output_t, "result", 0 }; + TensorArg output{output_t, "result", 0}; convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups); // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public native_mps::MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* biasTensor_ = nil; MPSGraphTensor* weightTensor_ = nil; @@ -117,13 +124,12 @@ Tensor _mps_convolution_impl( auto stream = at::mps::getCurrentMPSStream(); @autoreleasepool { - IntArrayRef bias_shape; - if(bias_defined) + if (bias_defined) bias_shape = bias_opt.value().sizes(); string mem_format_key; - switch(memory_format) { + switch (memory_format) { case at::MemoryFormat::Contiguous: mem_format_key = "Contiguous"; break; @@ -135,76 +141,87 @@ Tensor _mps_convolution_impl( } string bias_shape_key; - if(bias_defined) { + if (bias_defined) { bias_shape_key = to_string(bias_shape[0]); } else { bias_shape_key = "nobias"; } - string key = "mps_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" - + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" - + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" - + to_string(groups) + ":" + mem_format_key - + mps::getTensorsStringKey({input_t, weight_t}) + ":" - + to_string(bias_defined) + ":" + bias_shape_key; - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + string key = "mps_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(dilation[0]) + + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + + to_string(groups) + ":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" + + to_string(bias_defined) + ":" + bias_shape_key; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); MPSShape* inputShape = mps::getMPSShape(input_t, memory_format); - if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = native_mps::make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphConvolution2DOpDescriptor *conv2dDescriptor_ =[[MPSGraphConvolution2DOpDescriptor new] autorelease]; - MPSGraphDepthwiseConvolution3DOpDescriptor *depthWiseConv3dDescriptor_ = [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease]; + MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease]; + MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ = + [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease]; MPSShape* weightShape = mps::getMPSShape(weight_t); - bool isDepthwiseConv = ((groups > 1 && (weightShape[1].intValue == 1)) && - inputShape.count >= 4 && weightShape.count >= 4 && !is_channels_last); - if(isDepthwiseConv) { - fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, stride[1], stride[0], - dilation[1], dilation[0], - padding[1], padding[0], - memory_format, groups); + bool isDepthwiseConv = ((groups > 1 && (weightShape[1].intValue == 1)) && inputShape.count >= 4 && + weightShape.count >= 4 && !is_channels_last); + if (isDepthwiseConv) { + fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, + stride[1], + stride[0], + dilation[1], + dilation[0], + padding[1], + padding[0], + memory_format, + groups); } else { - fill_conv_desc(conv2dDescriptor_, stride[1], stride[0], - dilation[1], dilation[0], - padding[1], padding[0], - memory_format, groups); + fill_conv_desc(conv2dDescriptor_, + stride[1], + stride[0], + dilation[1], + dilation[0], + padding[1], + padding[0], + memory_format, + groups); } - MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape); + MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder( + mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape); MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t); MPSGraphTensor* biasTensor = nil; - if(bias_defined) { - biasTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(bias_opt.value())); + if (bias_defined) { + biasTensor = + native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(bias_opt.value())); } MPSGraphTensor* outputTensor; - if(isDepthwiseConv) { - MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor dimension:-3 withDimension:-4 name:nil]; - outputTensor = [mpsGraph depthwiseConvolution3DWithSourceTensor: inputTensor - weightsTensor: weightTransposeTensor - descriptor: depthWiseConv3dDescriptor_ - name: nil]; + if (isDepthwiseConv) { + MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor + dimension:-3 + withDimension:-4 + name:nil]; + outputTensor = [mpsGraph depthwiseConvolution3DWithSourceTensor:inputTensor + weightsTensor:weightTransposeTensor + descriptor:depthWiseConv3dDescriptor_ + name:nil]; } else { - outputTensor = [mpsGraph convolution2DWithSourceTensor: inputTensor - weightsTensor: weightTensor - descriptor: conv2dDescriptor_ - name: nil]; + outputTensor = [mpsGraph convolution2DWithSourceTensor:inputTensor + weightsTensor:weightTensor + descriptor:conv2dDescriptor_ + name:nil]; } if (is_channels_last) { outputTensor = mps::convertNHWCtoNCHW(mpsGraph, outputTensor); } - if(bias_defined) { - outputTensor = [mpsGraph additionWithPrimaryTensor: outputTensor - secondaryTensor: biasTensor - name: nil]; + if (bias_defined) { + outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor secondaryTensor:biasTensor name:nil]; } newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->weightTensor_ = weightTensor; @@ -213,27 +230,28 @@ Tensor _mps_convolution_impl( } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, inputShape); auto weightsPlaceholder = native_mps::Placeholder(cachedGraph->weightTensor_, weight_t); auto biasPlaceholder = native_mps::Placeholder(); // Reshape the bias to be broadcastable with output of conv2d - if(bias_defined) - biasPlaceholder = native_mps::Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1})); + if (bias_defined) + biasPlaceholder = + native_mps::Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1})); auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, *output); - NSMutableDictionary* feeds = [[[NSMutableDictionary alloc] initWithCapacity: 3] autorelease]; + NSMutableDictionary* feeds = + [[[NSMutableDictionary alloc] initWithCapacity:3] autorelease]; feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); feeds[weightsPlaceholder.getMPSGraphTensor()] = weightsPlaceholder.getMPSGraphTensorData(); - if(bias_defined) { + if (bias_defined) { feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData(); } - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); } @@ -241,40 +259,42 @@ Tensor _mps_convolution_impl( return *output; } -Tensor _mps_convolution( - const Tensor& input_t, - const Tensor& weight_t, - const c10::optional& bias_opt, - IntArrayRef padding, - IntArrayRef stride, - IntArrayRef dilation, - int64_t groups) { - return _mps_convolution_impl(input_t, weight_t, bias_opt, padding, stride, dilation, groups, c10::nullopt); +Tensor _mps_convolution(const Tensor& input_t, + const Tensor& weight_t, + const c10::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups) { + return _mps_convolution_impl(input_t, weight_t, bias_opt, padding, stride, dilation, groups, c10::nullopt); } -Tensor mps_convolution_backward_input( - IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) { +Tensor mps_convolution_backward_input(IntArrayRef input_size, + const Tensor& grad_output_t, + const Tensor& weight_t, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool bias_defined) { namespace native_mps = at::native::mps; using namespace mps; TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types"); CheckedFrom c = "mps_convolution_backward_input"; - TensorArg grad_output{ grad_output_t, "grad_output", 1 }, - weight{ weight_t, "weight", 2 }; + TensorArg grad_output{grad_output_t, "grad_output", 1}, weight{weight_t, "weight", 2}; checkAllSameType(c, {grad_output, weight}); checkAllSameGPU(c, {grad_output, weight}); auto memory_format = grad_output_t.suggest_memory_format(); bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast); - auto grad_input_t = at::empty( input_size, grad_output_t.options(), c10::nullopt); + auto grad_input_t = at::empty(input_size, grad_output_t.options(), c10::nullopt); // Avoid "grad_input" when this is being used as transposed convolution - TensorArg grad_input{ grad_input_t, "result", 0 }; + TensorArg grad_input{grad_input_t, "result", 0}; convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups); // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public native_mps::MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* gradOutputTensor_ = nil; MPSGraphTensor* weightTensor_ = nil; MPSGraphTensor* gradInputTensor_ = nil; @@ -284,11 +304,10 @@ Tensor mps_convolution_backward_input( // Add backward with input @autoreleasepool { - MPSStream* stream = getCurrentMPSStream(); string mem_format_key; - switch(memory_format) { + switch (memory_format) { case at::MemoryFormat::Contiguous: mem_format_key = "Contiguous"; break; @@ -302,64 +321,77 @@ Tensor mps_convolution_backward_input( MPSShape* gradOutputShape = getMPSShape(grad_output_t, memory_format); MPSShape* mps_input_shape = getMPSShape(input_size); NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = "mps_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" - + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" - + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" - + to_string(groups) + ":" + mem_format_key - + getTensorsStringKey({grad_output_t, weight_t}) + ":" - + string([ns_shape_key UTF8String]); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { - + string key = "mps_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + + to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key + + getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + + if (!cachedGraph) { + native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() { CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = native_mps::make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphConvolution2DOpDescriptor *conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease]; - MPSGraphDepthwiseConvolution3DOpDescriptor *depthWiseConv3dDescriptor_ = [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease]; + MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease]; + MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ = + [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease]; MPSShape* weightOutputShape = mps::getMPSShape(weight_t); // Depthwise conv is input feature channels = groups. So I in OIHW has to be 1. - bool isDepthwiseConv = ((groups > 1 && (weightOutputShape[1].intValue == 1)) && - gradOutputShape.count >= 4 && weightOutputShape.count >= 4 && !is_channels_last); - - if(isDepthwiseConv) { - fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, stride[1], stride[0], - dilation[1], dilation[0], - padding[1], padding[0], - at::MemoryFormat::Contiguous, groups); + bool isDepthwiseConv = ((groups > 1 && (weightOutputShape[1].intValue == 1)) && gradOutputShape.count >= 4 && + weightOutputShape.count >= 4 && !is_channels_last); + + if (isDepthwiseConv) { + fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, + stride[1], + stride[0], + dilation[1], + dilation[0], + padding[1], + padding[0], + at::MemoryFormat::Contiguous, + groups); } else { - fill_conv_desc(conv2dDescriptor_, stride[1], stride[0], - dilation[1], dilation[0], - padding[1], padding[0], - at::MemoryFormat::Contiguous, groups); + fill_conv_desc(conv2dDescriptor_, + stride[1], + stride[0], + dilation[1], + dilation[0], + padding[1], + padding[0], + at::MemoryFormat::Contiguous, + groups); } - MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape); + MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder( + mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape); MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t); - MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor; + MPSGraphTensor* gradOutputTensorTranspose = gradOutputTensor; if (is_channels_last) { gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose); } MPSGraphTensor* gradInputTensor; - if(isDepthwiseConv) { - MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor dimension:-3 withDimension:-4 name:nil]; - gradInputTensor = [mpsGraph depthwiseConvolution3DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose - weightsTensor:weightTransposeTensor - outputShape:mps_input_shape - descriptor:depthWiseConv3dDescriptor_ - name:nil]; + if (isDepthwiseConv) { + MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor + dimension:-3 + withDimension:-4 + name:nil]; + gradInputTensor = + [mpsGraph depthwiseConvolution3DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose + weightsTensor:weightTransposeTensor + outputShape:mps_input_shape + descriptor:depthWiseConv3dDescriptor_ + name:nil]; } else { - gradInputTensor = [mpsGraph convolution2DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose - weightsTensor:weightTensor - outputShape:mps_input_shape - forwardConvolutionDescriptor:conv2dDescriptor_ - name:nil]; + gradInputTensor = [mpsGraph convolution2DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose + weightsTensor:weightTensor + outputShape:mps_input_shape + forwardConvolutionDescriptor:conv2dDescriptor_ + name:nil]; } newCachedGraph->gradOutputTensor_ = gradOutputTensor; @@ -368,30 +400,34 @@ Tensor mps_convolution_backward_input( } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape); auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t); auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input); - NSDictionary *feeds = @{ + NSDictionary* feeds = @{ gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), weightsPlaceholder.getMPSGraphTensor() : weightsPlaceholder.getMPSGraphTensorData(), }; - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } return *grad_input; } -Tensor mps_convolution_backward_weights( - IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) { +Tensor mps_convolution_backward_weights(IntArrayRef weight_size, + const Tensor& grad_output_t, + const Tensor& input_t, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool bias_defined) { namespace native_mps = at::native::mps; using namespace mps; TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types"); @@ -403,27 +439,21 @@ Tensor mps_convolution_backward_weights( // For uniformity with everything else, although it seems grad_weight // would be unambiguous too. - TensorArg grad_output{ grad_output_t, "grad_output", 1 }; - TensorArg input{ input_t, "input", 2}; + TensorArg grad_output{grad_output_t, "grad_output", 1}; + TensorArg input{input_t, "input", 2}; checkAllSameType(c, {grad_output, input}); checkAllSameGPU(c, {grad_output, input}); - auto grad_weight_t = at::empty( - weight_size, - grad_output_t.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); - TensorArg grad_weight{ grad_weight_t, "result", 0 }; + auto grad_weight_t = + at::empty(weight_size, grad_output_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); + TensorArg grad_weight{grad_weight_t, "result", 0}; convolution_shape_check(c, input, grad_weight, grad_output, padding, stride, dilation, groups); // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public native_mps::MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* gradOutputTensor_ = nil; MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* gradWeightTensor_ = nil; @@ -432,11 +462,10 @@ Tensor mps_convolution_backward_weights( native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); @autoreleasepool { - MPSStream* stream = getCurrentMPSStream(); string mem_format_key; - switch(memory_format) { + switch (memory_format) { case at::MemoryFormat::Contiguous: mem_format_key = "Contiguous"; break; @@ -448,64 +477,79 @@ Tensor mps_convolution_backward_weights( } MPSShape* mps_weight_shape = getMPSShape(weight_size); NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" - + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" - + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" - + to_string(groups) + ":" + mem_format_key - + getTensorsStringKey({grad_output_t, input_t}) + ":" - + string([ns_shape_key UTF8String]); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { - + string key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + + to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key + + getTensorsStringKey({grad_output_t, input_t}) + ":" + string([ns_shape_key UTF8String]); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + + if (!cachedGraph) { + native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() { CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = native_mps::make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphConvolution2DOpDescriptor *conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease]; - MPSGraphDepthwiseConvolution3DOpDescriptor *depthWiseConv3dDescriptor_ = [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease]; + MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease]; + MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ = + [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease]; MPSShape* inputShape = mps::getMPSShape(input_t); - bool isDepthwiseConv = ((groups > 1 && (mps_weight_shape[1].intValue == 1)) && inputShape.count >= 4 && mps_weight_shape.count >= 4 && !is_channels_last); - - if(isDepthwiseConv) { - fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, stride[1], stride[0], - dilation[1], dilation[0], - padding[1], padding[0], - at::MemoryFormat::Contiguous, groups); + bool isDepthwiseConv = ((groups > 1 && (mps_weight_shape[1].intValue == 1)) && inputShape.count >= 4 && + mps_weight_shape.count >= 4 && !is_channels_last); + + if (isDepthwiseConv) { + fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, + stride[1], + stride[0], + dilation[1], + dilation[0], + padding[1], + padding[0], + at::MemoryFormat::Contiguous, + groups); } else { - fill_conv_desc(conv2dDescriptor_, stride[1], stride[0], - dilation[1], dilation[0], - padding[1], padding[0], - at::MemoryFormat::Contiguous, groups); + fill_conv_desc(conv2dDescriptor_, + stride[1], + stride[0], + dilation[1], + dilation[0], + padding[1], + padding[0], + at::MemoryFormat::Contiguous, + groups); } - MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape); + MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder( + mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape); MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t); - MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor; + MPSGraphTensor* gradOutputTensorTranspose = gradOutputTensor; if (is_channels_last) { gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose); } MPSGraphTensor* gradWeightTensor; - if(isDepthwiseConv) { - NSNumber* outputFeatChannelDim = mps_weight_shape[0]; - MPSShape* weightShapeTranspose = @[@1, outputFeatChannelDim, mps_weight_shape[2], mps_weight_shape[3]]; - MPSGraphTensor* gradWeightTensorTranspose = [mpsGraph depthwiseConvolution3DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose - sourceTensor:inputTensor - outputShape:weightShapeTranspose - descriptor:depthWiseConv3dDescriptor_ - name:nil]; - gradWeightTensor = [mpsGraph transposeTensor:gradWeightTensorTranspose dimension:-3 withDimension:-4 name:nil]; + if (isDepthwiseConv) { + NSNumber* outputFeatChannelDim = mps_weight_shape[0]; + MPSShape* weightShapeTranspose = @[ @1, outputFeatChannelDim, mps_weight_shape[2], mps_weight_shape[3] ]; + MPSGraphTensor* gradWeightTensorTranspose = + [mpsGraph depthwiseConvolution3DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose + sourceTensor:inputTensor + outputShape:weightShapeTranspose + descriptor:depthWiseConv3dDescriptor_ + name:nil]; + gradWeightTensor = [mpsGraph transposeTensor:gradWeightTensorTranspose + dimension:-3 + withDimension:-4 + name:nil]; } else { - gradWeightTensor = [mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose - sourceTensor:inputTensor - outputShape:mps_weight_shape - forwardConvolutionDescriptor:conv2dDescriptor_ - name:nil]; + gradWeightTensor = + [mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose + sourceTensor:inputTensor + outputShape:mps_weight_shape + forwardConvolutionDescriptor:conv2dDescriptor_ + name:nil]; } newCachedGraph->gradOutputTensor_ = gradOutputTensor; newCachedGraph->inputTensor_ = inputTensor; @@ -513,21 +557,20 @@ Tensor mps_convolution_backward_weights( } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape); auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t); auto outputPlaceholder = Placeholder(cachedGraph->gradWeightTensor_, grad_weight_t); - NSDictionary *feeds = @{ + NSDictionary* feeds = @{ gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), }; - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } @@ -535,10 +578,14 @@ Tensor mps_convolution_backward_weights( return grad_weight_t; } -std::tuple mps_convolution_backward( - const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - std::array output_mask) { +std::tuple mps_convolution_backward(const at::Tensor& input, + const at::Tensor& grad_output, + const at::Tensor& weight, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + std::array output_mask) { Tensor grad_input, grad_weight, grad_bias; if (input.numel() == 0) { if (output_mask[0]) { @@ -549,73 +596,85 @@ Tensor mps_convolution_backward_weights( } } else { if (output_mask[0]) { - grad_input = mps_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2]); + grad_input = mps_convolution_backward_input( + input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2]); } if (output_mask[1]) { - grad_weight = mps_convolution_backward_weights(weight.sizes(), grad_output, input, padding, stride, dilation, groups, output_mask[2]); + grad_weight = mps_convolution_backward_weights( + weight.sizes(), grad_output, input, padding, stride, dilation, groups, output_mask[2]); } } - return std::tuple{grad_input, grad_weight, grad_bias}; + return std::tuple{grad_input, grad_weight, grad_bias}; } -Tensor mps_convolution_transpose_forward( - const Tensor& grad_output, const Tensor& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) -{ - auto input_size = conv_input_size(grad_output.sizes(), weight.sizes(), - padding, output_padding, stride, dilation, groups); - return mps_convolution_backward_input(input_size, grad_output, weight, - padding, stride, dilation, groups, false); +Tensor mps_convolution_transpose_forward(const Tensor& grad_output, + const Tensor& weight, + IntArrayRef padding, + IntArrayRef output_padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups) { + auto input_size = + conv_input_size(grad_output.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups); + return mps_convolution_backward_input(input_size, grad_output, weight, padding, stride, dilation, groups, false); } -Tensor _mps_convolution_transpose( - const Tensor& input_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups) { +Tensor _mps_convolution_transpose(const Tensor& input_t, + const Tensor& weight_t, + IntArrayRef padding, + IntArrayRef output_padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups) { TORCH_CHECK(input_t.dim() < 5, "ConvTranspose 3D is not supported on MPS"); - auto output_t = mps_convolution_transpose_forward( - input_t, weight_t, padding, output_padding, stride, dilation, groups); + auto output_t = + mps_convolution_transpose_forward(input_t, weight_t, padding, output_padding, stride, dilation, groups); return output_t; - } -Tensor mps_convolution_transpose_backward_input( - const Tensor& grad_output_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, IntArrayRef input_shape) -{ - return _mps_convolution_impl( - grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups, input_shape); +Tensor mps_convolution_transpose_backward_input(const Tensor& grad_output_t, + const Tensor& weight_t, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + IntArrayRef input_shape) { + return _mps_convolution_impl(grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups, input_shape); } -Tensor mps_convolution_transpose_backward_weight( - IntArrayRef weight_size, - const Tensor& grad_output_t, - const Tensor& input_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) -{ +Tensor mps_convolution_transpose_backward_weight(IntArrayRef weight_size, + const Tensor& grad_output_t, + const Tensor& input_t, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups) { return mps_convolution_backward_weights( - weight_size, input_t, grad_output_t, - padding, stride, dilation, groups, false); + weight_size, input_t, grad_output_t, padding, stride, dilation, groups, false); } - -std::tuple mps_convolution_transpose_backward( - const Tensor& input, const Tensor& grad_output, const Tensor& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - std::array output_mask) { +std::tuple mps_convolution_transpose_backward(const Tensor& input, + const Tensor& grad_output, + const Tensor& weight, + IntArrayRef padding, + IntArrayRef output_padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + std::array output_mask) { Tensor grad_input, grad_weight; if (output_mask[0]) { - grad_input = mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, input.sizes()); + grad_input = + mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, input.sizes()); } if (output_mask[1]) { - grad_weight = mps_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups); + grad_weight = mps_convolution_transpose_backward_weight( + weight.sizes(), grad_output, input, padding, stride, dilation, groups); } - return std::tuple{grad_input, grad_weight}; + return std::tuple{grad_input, grad_weight}; } - } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm index 38b0da31670b9b..b9ebe34c412f10 100644 --- a/aten/src/ATen/native/mps/operations/Copy.mm +++ b/aten/src/ATen/native/mps/operations/Copy.mm @@ -6,10 +6,7 @@ namespace at::native { namespace mps { -void* pageAlignedBlockPtr( - const void* ptr, - NSUInteger size, - NSUInteger* alignedBlockSize) { +void* pageAlignedBlockPtr(const void* ptr, NSUInteger size, NSUInteger* alignedBlockSize) { uintptr_t address = (uintptr_t)ptr; uintptr_t alignedAddress = address & ~(PAGE_SIZE - 1); uintptr_t alignedEnd = ((address + size) + PAGE_SIZE - 1) & ~(PAGE_SIZE - 1); @@ -26,15 +23,15 @@ * Computes number of elements one needs to transfer to preserve all the elements */ size_t compute_strided_size(const at::Tensor& t) { - size_t rc = 1; - if (t.numel() == 0) { - return 0; - } - for(const auto i: c10::irange(t.dim())) { - assert(t.size(i) > 0); - rc += (t.size(i) - 1) * t.stride(i); - } - return rc; + size_t rc = 1; + if (t.numel() == 0) { + return 0; + } + for (const auto i : c10::irange(t.dim())) { + assert(t.size(i) > 0); + rc += (t.size(i) - 1) * t.stride(i); + } + return rc; } bool is_strided_contiguous(const at::Tensor& t) { @@ -43,13 +40,15 @@ bool is_strided_contiguous(const at::Tensor& t) { // Copy sourceBuffer into destBuffer, casting sourceBuffer to src.scalar_type(). // The shapes and dtypes are taken from dst and src, but their storage pointers are not used. -void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, - id destBuffer, id sourceBuffer, bool non_blocking = true) { +void copy_cast_mps(at::Tensor& dst, + const at::Tensor& src, + id destBuffer, + id sourceBuffer, + bool non_blocking = true) { using namespace mps; - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* outputTensor_ = nil; }; @@ -64,11 +63,11 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, @autoreleasepool { string key = "copy_cast_mps" + getTensorsStringKey({src, dst}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if (!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); @@ -85,23 +84,24 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } - MPSGraphTensorData* srcData = [[[MPSGraphTensorData alloc] - initWithMTLBuffer:sourceBuffer shape:srcShape dataType:srcDType] - autorelease]; - MPSGraphTensorData* dstData = [[[MPSGraphTensorData alloc] - initWithMTLBuffer:destBuffer shape:dstShape dataType:dstDType] - autorelease]; - NSDictionary* feeds = @{cachedGraph->inputTensor_: srcData}; - NSDictionary* results = @{cachedGraph->outputTensor_: dstData}; - stream->executeMPSGraph(cachedGraph->graph(), feeds, results, !non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT_ADAPTIVE); + MPSGraphTensorData* srcData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:sourceBuffer + shape:srcShape + dataType:srcDType] autorelease]; + MPSGraphTensorData* dstData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:destBuffer + shape:dstShape + dataType:dstDType] autorelease]; + NSDictionary* feeds = @{cachedGraph->inputTensor_ : srcData}; + NSDictionary* results = @{cachedGraph->outputTensor_ : dstData}; + stream->executeMPSGraph( + cachedGraph->graph(), feeds, results, !non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT_ADAPTIVE); } } -static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) -{ - auto sameMemFormat = src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format()); +static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) { + auto sameMemFormat = + src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format()); id device = MPSDevice::getInstance()->device(); MPSStream* stream = getCurrentMPSStream(); @@ -152,8 +152,8 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, needsBlit = false; tmpBuffer = destBuffer; } else if (src.element_size() < dst.element_size()) { - tmp = at::native::empty_mps(dst.sizes(), dst.scalar_type(), c10::nullopt, kMPS); - tmpBuffer = getMTLBufferStorage(tmp); + tmp = at::native::empty_mps(dst.sizes(), dst.scalar_type(), c10::nullopt, kMPS); + tmpBuffer = getMTLBufferStorage(tmp); } } @@ -181,15 +181,14 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, } // Copies tensor from cpu to mps backed by identical strided-contiguous data -static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bool non_blocking) -{ +static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bool non_blocking) { MPSStream* stream = getCurrentMPSStream(); id device = MPSDevice::getInstance()->device(); auto dst_byte_offset = dst.storage_offset() * dst.itemsize(); auto src_byte_offset = src.storage_offset() * src.itemsize(); id destBuffer = getMTLBufferStorage(dst); const size_t size_to_copy = src.nbytes(); - const void* host_src = static_cast(src.storage().data()) + src_byte_offset; + const void* host_src = static_cast(src.storage().data()) + src_byte_offset; TORCH_INTERNAL_ASSERT(src.dtype() == dst.dtype() && src.strides() == dst.strides() && is_strided_contiguous(src)); @@ -201,17 +200,16 @@ static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bo void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size_to_copy, &alignedLength); sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr); id sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr - length:alignedLength - options:options - deallocator:nil]; + length:alignedLength + options:options + deallocator:nil]; stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking); [sourceBuffer release]; } } -static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) -{ +static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) { // Typecast to dst_ if needed and expand, which is a no-op Tensor src = (src_.dtype() != dst_.dtype() ? src_.to(dst_.dtype()) : src_).expand_as(dst_); @@ -233,7 +231,7 @@ static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bo dst = at::empty_like(src, at::device(at::kMPS)); } copy_to_mps_stride_contig(dst, src, non_blocking && !needs_copy); - return needs_copy? dst_.copy_(dst) : dst_; + return needs_copy ? dst_.copy_(dst) : dst_; } void copy_blit_mps(void* dst, const void* src, size_t size) { @@ -241,8 +239,7 @@ void copy_blit_mps(void* dst, const void* src, size_t size) { stream->copy_and_sync((id)(src), (id)(dst), size, 0, 0, true); } -static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) -{ +static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking) { auto src_byte_offset = src_.storage_offset() * src_.itemsize(); auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize(); @@ -250,7 +247,8 @@ void copy_blit_mps(void* dst, const void* src, size_t size) { // gather into dst. This reduces the overhead of doing an additional blit for most cases bool returnGatherOutput = dst_.is_contiguous(); Tensor src; - auto sameMemFormat = src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format()); + auto sameMemFormat = + src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format()); const bool sameDataType = src_.dtype() == dst_.dtype(); if ((!src_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) || @@ -290,19 +288,18 @@ void copy_blit_mps(void* dst, const void* src, size_t size) { stream->copy(sourceBuffer, destBuffer, src.nbytes(), src_byte_offset, dst_byte_offset); } else { if (dst_byte_offset) { - auto tmp = at::native::empty_mps(dst_.sizes(), dst_.scalar_type(), c10::nullopt, kMPS); - auto tmpBuffer = getMTLBufferStorage(tmp); - copy_cast_mps(tmp, src, tmpBuffer, sourceBuffer); - stream->copy(tmpBuffer, destBuffer, dst_.nbytes(), 0, dst_byte_offset); + auto tmp = at::native::empty_mps(dst_.sizes(), dst_.scalar_type(), c10::nullopt, kMPS); + auto tmpBuffer = getMTLBufferStorage(tmp); + copy_cast_mps(tmp, src, tmpBuffer, sourceBuffer); + stream->copy(tmpBuffer, destBuffer, dst_.nbytes(), 0, dst_byte_offset); } else { - copy_cast_mps(dst_, src, destBuffer, sourceBuffer); + copy_cast_mps(dst_, src, destBuffer, sourceBuffer); } } return dst_; } -at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking) -{ +at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking) { TORCH_CHECK(dst.defined(), "dst is undefined"); TORCH_CHECK(src.defined(), "src is undefined"); @@ -328,20 +325,16 @@ void copy_blit_mps(void* dst, const void* src, size_t size) { if (src.device().type() == at::kMPS && dst.device().type() == at::kMPS) { return copy_kernel_mps(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking); } - TORCH_INTERNAL_ASSERT( - src.device().type() == DeviceType::MPS, - "mps_copy_ is implemented only for *->MPS; MPS->*"); + TORCH_INTERNAL_ASSERT(src.device().type() == DeviceType::MPS, "mps_copy_ is implemented only for *->MPS; MPS->*"); return dst; } } // namespace mps -Tensor _copy_from_and_resize_mps(const at::Tensor& self, const at::Tensor& dst) -{ +Tensor _copy_from_and_resize_mps(const at::Tensor& self, const at::Tensor& dst) { return mps::mps_copy_(const_cast(dst), self, false); } -Tensor _copy_from_mps(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) -{ +Tensor _copy_from_mps(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) { return mps::mps_copy_(const_cast(dst), self, non_blocking); } diff --git a/aten/src/ATen/native/mps/operations/CrossKernel.mm b/aten/src/ATen/native/mps/operations/CrossKernel.mm index f140bcb8f9b6d4..2cfe9ea3eb1da1 100644 --- a/aten/src/ATen/native/mps/operations/CrossKernel.mm +++ b/aten/src/ATen/native/mps/operations/CrossKernel.mm @@ -1,7 +1,7 @@ // Copyright © 2022 Apple Inc. -#include #include +#include namespace at::native { @@ -82,12 +82,12 @@ kernel void cross(constant void * input_ [[buffer(0)]], return crossLibrary; } - NSError *error = nil; - MTLCompileOptions *options = [[MTLCompileOptions new] autorelease]; - [options setLanguageVersion: MTLLanguageVersion2_3]; - crossLibrary = [device newLibraryWithSource:[NSString stringWithCString: METAL_CROSS encoding:NSASCIIStringEncoding] - options:options - error:&error]; + NSError* error = nil; + MTLCompileOptions* options = [[MTLCompileOptions new] autorelease]; + [options setLanguageVersion:MTLLanguageVersion2_3]; + crossLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_CROSS encoding:NSASCIIStringEncoding] + options:options + error:&error]; TORCH_CHECK(crossLibrary, "Failed to create metal cross library, error: ", [[error description] UTF8String]); return crossLibrary; } @@ -115,25 +115,25 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other, TORCH_CHECK(input.dtype() != at::kDouble, "float64 is not supported on MPS"); auto iter = TensorIteratorConfig() - .add_output(out) - .add_input(input) - .add_input(other) - .resize_outputs(false) - .declare_static_shape(out.sizes(), /*squash_dims=*/dim) - .build(); - - id inputBuffer = getMTLBufferStorage(input); - id otherBuffer = getMTLBufferStorage(other); + .add_output(out) + .add_input(input) + .add_input(other) + .resize_outputs(false) + .declare_static_shape(out.sizes(), /*squash_dims=*/dim) + .build(); + + id inputBuffer = getMTLBufferStorage(input); + id otherBuffer = getMTLBufferStorage(other); id outputBuffer = getMTLBufferStorage(out); id device = MPSDevice::getInstance()->device(); MPSStream* mpsStream = getCurrentMPSStream(); - const int64_t out_dim_stride = out.stride(dim); + const int64_t out_dim_stride = out.stride(dim); const int64_t input_dim_stride = input.stride(dim); const int64_t other_dim_stride = other.stride(dim); const uint32_t nDim = iter.ndim(); constexpr uint32_t nOffsets = 3; const uint32_t numThreads = iter.numel(); - dispatch_sync(mpsStream->queue(), ^(){ + dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { NSError* error = nil; id commandBuffer = mpsStream->commandBuffer(); @@ -143,23 +143,25 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other, std::vector iterShapeData(iterShape.size()); std::vector> strides(nDim); - for (const auto i: c10::irange(iterShape.size())) { + for (const auto i : c10::irange(iterShape.size())) { TORCH_CHECK(i <= UINT32_MAX); iterShapeData[i] = (uint32_t)(iterShape[i]); } - for (const auto i: c10::irange(nDim)) { - for (const auto offset: c10::irange(nOffsets)) { - strides[i][offset] = iter.strides(offset)[i]; + for (const auto i : c10::irange(nDim)) { + for (const auto offset : c10::irange(nOffsets)) { + strides[i][offset] = iter.strides(offset)[i]; } } - id kernelDataOffsetsFunction = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil); - id kernelDataOffsetsPSO = [[device newComputePipelineStateWithFunction: kernelDataOffsetsFunction - error: &error] autorelease]; - id kernelDataOffsets = [[device newBufferWithLength: numThreads * sizeof(simd_uint3) - options: 0] autorelease]; - TORCH_CHECK(kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); + id kernelDataOffsetsFunction = + MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil); + id kernelDataOffsetsPSO = + [[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease]; + id kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3) + options:0] autorelease]; + TORCH_CHECK( + kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); [computeEncoder setComputePipelineState:kernelDataOffsetsPSO]; [computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0]; [computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:1]; @@ -169,30 +171,28 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other, NSUInteger kernelOffsetsTGSize = kernelDataOffsetsPSO.maxTotalThreadsPerThreadgroup; if (kernelOffsetsTGSize > numThreads) - kernelOffsetsTGSize = numThreads; + kernelOffsetsTGSize = numThreads; MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1); - [computeEncoder dispatchThreads: gridSize - threadsPerThreadgroup: kernelOffsetsThreadGroupSize]; + [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:kernelOffsetsThreadGroupSize]; id crossPSO = crossPipelineState(device, out.scalar_type()); [computeEncoder setComputePipelineState:crossPSO]; - [computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:0]; - [computeEncoder setBuffer:otherBuffer offset:other.storage_offset() * other.element_size() atIndex:1]; + [computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:0]; + [computeEncoder setBuffer:otherBuffer offset:other.storage_offset() * other.element_size() atIndex:1]; [computeEncoder setBuffer:outputBuffer offset:out.storage_offset() * out.element_size() atIndex:2]; [computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:3]; - [computeEncoder setBytes:&out_dim_stride length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&out_dim_stride length:sizeof(int64_t) atIndex:4]; [computeEncoder setBytes:&input_dim_stride length:sizeof(int64_t) atIndex:5]; [computeEncoder setBytes:&other_dim_stride length:sizeof(int64_t) atIndex:6]; NSUInteger tgSize = crossPSO.maxTotalThreadsPerThreadgroup; if (tgSize > numThreads) { - tgSize = numThreads; + tgSize = numThreads; } MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreads: gridSize - threadsPerThreadgroup: threadGroupSize]; + [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize]; [computeEncoder endEncoding]; mpsStream->commit(true); diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index aed43a29949fec..9635b787891a99 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -1,39 +1,40 @@ // Copyright © 2022 Apple Inc. -#include -#include -#include -#include #include +#include +#include #include +#include +#include namespace at::native { namespace mps { -struct RandomCachedGraph : public MPSCachedGraph -{ - RandomCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) { } +struct RandomCachedGraph : public MPSCachedGraph { + RandomCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} // Only relevant for multinomial - MPSGraphTensor *probTensor = nil; - MPSGraphTensor *resultTensor = nil; - MPSGraphTensor *stateTensor = nil; + MPSGraphTensor* probTensor = nil; + MPSGraphTensor* resultTensor = nil; + MPSGraphTensor* stateTensor = nil; // used for Normal distributions only MPSGraphTensor *meanTensor = nil, *stdTensor = nil; }; typedef MPSGraphTensor* (^RandomOpBlock)(RandomCachedGraph*, MPSGraphTensor*); -#define RandomOpFn(graph, randomTensor) MPSGraphTensor* (mps::RandomCachedGraph* graph, MPSGraphTensor* randomTensor) +#define RandomOpFn(graph, randomTensor) MPSGraphTensor*(mps::RandomCachedGraph * graph, MPSGraphTensor * randomTensor) // for Uniform distributions with scalar from (val1) and to (val2) intervals // for Normal distributions with scalar mean (val1) and std (val2) values -template -Tensor& random_mps_impl(Tensor& self, scalar_t val1, scalar_t val2, +template +Tensor& random_mps_impl(Tensor& self, + scalar_t val1, + scalar_t val2, const c10::optional& mean_opt, const c10::optional& std_opt, MPSGraphRandomDistribution distribution, c10::optional gen, - std::string op_name, RandomOpBlock randomBlock) -{ + std::string op_name, + RandomOpBlock randomBlock) { if (self.numel() == 0) { return self; } @@ -46,13 +47,14 @@ auto cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - RandomCachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + RandomCachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new RandomCachedGraph(mpsGraph); - newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@(at::mps::detail::PHILOX_STATE_N)]); + newCachedGraph->stateTensor = + mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(at::mps::detail::PHILOX_STATE_N) ]); // FP16, FP32 and Int32 are the only data types supported for distributions on MPS backend. const MPSDataType inputDataType = [&] { @@ -64,8 +66,8 @@ }(); const MPSDataType outputDataType = (std::is_same::value) ? MPSDataTypeBool : inputDataType; - MPSGraphRandomOpDescriptor *desc = [MPSGraphRandomOpDescriptor descriptorWithDistribution: distribution - dataType: inputDataType]; + MPSGraphRandomOpDescriptor* desc = [MPSGraphRandomOpDescriptor descriptorWithDistribution:distribution + dataType:inputDataType]; if (distribution == MPSGraphRandomDistributionUniform) { if (inputDataType == MPSDataTypeInt32) { desc.minInteger = static_cast(val1); @@ -81,10 +83,10 @@ // we don't use the output state tensor from the MPSGraph API as it requires reading back from GPU to CPU. // Instead, we keep the Philox state in the MPSGenerator and use the PyTorch's philox_engine to maintain // the counters, and feed them to the graph manually - NSArray *resultTensors = [mpsGraph randomTensorWithShape: getMPSShape(self) - descriptor: desc - stateTensor: newCachedGraph->stateTensor - name: nil]; + NSArray* resultTensors = [mpsGraph randomTensorWithShape:getMPSShape(self) + descriptor:desc + stateTensor:newCachedGraph->stateTensor + name:nil]; newCachedGraph->resultTensor = randomBlock ? randomBlock(newCachedGraph, resultTensors[0]) : resultTensors[0]; // results will be cast if self's scalar type isn't directly supported by MPS backend. if (getMPSDataType(self) != outputDataType) @@ -94,19 +96,20 @@ }); } // feed the updated state values to the graph - MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(at::mps::detail::PHILOX_STATE_N)]]; - MPSNDArray *stateNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: stateDesc] autorelease]; + MPSNDArrayDescriptor* stateDesc = + [MPSNDArrayDescriptor descriptorWithDataType:MPSDataTypeInt32 shape:@[ @(at::mps::detail::PHILOX_STATE_N) ]]; + MPSNDArray* stateNDArray = [[[MPSNDArray alloc] initWithDevice:stream->device() descriptor:stateDesc] autorelease]; { // See Note [Acquire lock when using random generators] std::lock_guard lock(mps_gen->mutex_); // update the Philox state values on each run mps_gen->update_philox_counters(); - [stateNDArray writeBytes: mps_gen->state_data() strideBytes: nil]; + [stateNDArray writeBytes:mps_gen->state_data() strideBytes:nil]; } - MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: stateNDArray] autorelease]; + MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray:stateNDArray] autorelease]; Placeholder meanPlaceholder, stdPlaceholder; - NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; + NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; feeds[cachedGraph->stateTensor] = stateTensorData; if (cachedGraph->stdTensor) { @@ -121,7 +124,7 @@ } Placeholder outputPlaceholder = Placeholder(cachedGraph->resultTensor, self); - NSDictionary *results = @{ + NSDictionary* results = @{ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(), }; @@ -131,13 +134,14 @@ return self; } -Tensor& normal_mps_impl(Tensor& self, double mean_s, double std_s, +Tensor& normal_mps_impl(Tensor& self, + double mean_s, + double std_s, const c10::optional& mean_opt, const c10::optional& std_opt, c10::optional gen, - std::string op_name) -{ - const Tensor& std_t = *(at::borrow_from_optional_tensor(std_opt)); + std::string op_name) { + const Tensor& std_t = *(at::borrow_from_optional_tensor(std_opt)); const Tensor& mean_t = *(at::borrow_from_optional_tensor(mean_opt)); TORCH_CHECK(std_s >= 0.0, op_name, " expects std >= 0.0, but found std=", std_s); @@ -153,39 +157,45 @@ if (std_t.defined()) { cachedGraph->stdTensor = mpsGraphRankedPlaceHolder(mpsGraph, std_t); - resultTensor = [mpsGraph multiplicationWithPrimaryTensor: randomTensor - secondaryTensor: cachedGraph->stdTensor - name: nil]; + resultTensor = [mpsGraph multiplicationWithPrimaryTensor:randomTensor + secondaryTensor:cachedGraph->stdTensor + name:nil]; } if (mean_t.defined()) { cachedGraph->meanTensor = mpsGraphRankedPlaceHolder(mpsGraph, mean_t); - return [mpsGraph additionWithPrimaryTensor: resultTensor - secondaryTensor: cachedGraph->meanTensor - name: nil]; + return [mpsGraph additionWithPrimaryTensor:resultTensor secondaryTensor:cachedGraph->meanTensor name:nil]; } return resultTensor; }; - return random_mps_impl(self, mean_s, std_s, mean_opt, std_opt, - MPSGraphRandomDistributionNormal, gen, - op_name + getTensorsStringKey({mean_t, std_t}), random_op_block); - + return random_mps_impl(self, + mean_s, + std_s, + mean_opt, + std_opt, + MPSGraphRandomDistributionNormal, + gen, + op_name + getTensorsStringKey({mean_t, std_t}), + random_op_block); } -Tensor& bernoulli_mps_impl(Tensor& self, const Tensor& prob_t, c10::optional gen, std::string op_name) -{ +Tensor& bernoulli_mps_impl(Tensor& self, const Tensor& prob_t, c10::optional gen, std::string op_name) { TORCH_CHECK(prob_t.is_same_size(self), op_name, ": probability and self tensor should be of the same shape") RandomOpBlock random_op_block = ^RandomOpFn(cachedGraph, randomTensor) { MPSGraph* mpsGraph = cachedGraph->graph(); cachedGraph->stdTensor = mpsGraphRankedPlaceHolder(mpsGraph, prob_t); - return [mpsGraph lessThanWithPrimaryTensor: randomTensor - secondaryTensor: cachedGraph->stdTensor - name: nil]; + return [mpsGraph lessThanWithPrimaryTensor:randomTensor secondaryTensor:cachedGraph->stdTensor name:nil]; }; // Bernoulli generates binary output so we use bool type - return mps::random_mps_impl(self, 0.0, 1.0, c10::nullopt, prob_t, - MPSGraphRandomDistributionUniform, gen, - op_name + getTensorsStringKey({prob_t}), random_op_block); + return mps::random_mps_impl(self, + 0.0, + 1.0, + c10::nullopt, + prob_t, + MPSGraphRandomDistributionUniform, + gen, + op_name + getTensorsStringKey({prob_t}), + random_op_block); } } // namespace mps @@ -196,15 +206,19 @@ const auto max = static_cast(std::numeric_limits::max()); TORCH_CHECK(from <= to, "uniform_ expects to return a [from, to) range, but found from=", from, " > to=", to); TORCH_CHECK((to - from) <= std::numeric_limits::max(), - "uniform_ expects to-from <= std::numeric_limits<", toString(self.scalar_type()), - ">::max(), but found to=", to, " and from=", from, - " which result in to-from to exceed the limit"); + "uniform_ expects to-from <= std::numeric_limits<", + toString(self.scalar_type()), + ">::max(), but found to=", + to, + " and from=", + from, + " which result in to-from to exceed the limit"); from = std::min(std::max(from, min), max); to = std::max(std::min(to, max), min); }); - return mps::random_mps_impl(self, from, to, c10::nullopt, c10::nullopt, - MPSGraphRandomDistributionUniform, gen, __func__, nullptr); + return mps::random_mps_impl( + self, from, to, c10::nullopt, c10::nullopt, MPSGraphRandomDistributionUniform, gen, __func__, nullptr); } Tensor& normal_mps_(Tensor& self, double mean, double std, c10::optional gen) { @@ -248,7 +262,7 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional gen, Tensor& result) { result.resize_(p_.sizes()); - return mps::bernoulli_mps_impl(result, p_, gen, __func__); + return mps::bernoulli_mps_impl(result, p_, gen, __func__); } Tensor& bernoulli_mps_(Tensor& self, double p, c10::optional gen) { @@ -271,22 +285,35 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional= to=", to); if (isFloatingType(input_dtype)) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_update_from_to", [&] { - from = templates::update_from(from); - to = templates::update_to(to); - TORCH_CHECK(from < to, "random_mps_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=", from, " >= to=", to); - }); + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_update_from_to", [&] { + from = templates::update_from(from); + to = templates::update_to(to); + TORCH_CHECK( + from < to, + "random_mps_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=", + from, + " >= to=", + to); + }); templates::check_from_to_in_range(from, to - 1, self.dtype()); } } else if (from != std::numeric_limits::lowest()) { // [from, std::numeric_limits::max()] if (isFloatingType(input_dtype)) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_from_to_range_calc", [&] { - constexpr int64_t scalar_t_max = static_cast(1) << std::numeric_limits::digits; - to = scalar_t_max > std::numeric_limits::max() ? std::numeric_limits::max() : static_cast(scalar_t_max); - from = templates::update_from(from); - TORCH_CHECK(from < to, "random_mps_ expects 'from' casted to dtype to be less than or equal to 'to' casted to dtype, but got from=", from, " > to=", to); - }); + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, input_dtype, "random_from_to_range_calc", [&] { + constexpr int64_t scalar_t_max = static_cast(1) << std::numeric_limits::digits; + to = scalar_t_max > std::numeric_limits::max() ? std::numeric_limits::max() + : static_cast(scalar_t_max); + from = templates::update_from(from); + TORCH_CHECK( + from < to, + "random_mps_ expects 'from' casted to dtype to be less than or equal to 'to' casted to dtype, but got from=", + from, + " > to=", + to); + }); } else if (isIntegralType(input_dtype, /*includeBool=*/true)) { AT_DISPATCH_INTEGRAL_TYPES_AND(at::ScalarType::Bool, input_dtype, "random_from_to_range_calc", [&] { if (std::is_same::value) { @@ -295,13 +322,11 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional(std::numeric_limits::max()); } }); - } - else { + } else { TORCH_CHECK(false, "random_mps_ handles only integral, floating-point and boolean types"); } templates::check_from_to_in_range(from, to, self.dtype()); - } - else { + } else { // [std::numeric_limits::lowest(), std::numeric_limits::max()] // range = 2^64 @@ -309,8 +334,8 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional max() range"); } - return mps::random_mps_impl(self, from, to - 1, c10::nullopt, c10::nullopt, - MPSGraphRandomDistributionUniform, gen, __func__, nullptr); + return mps::random_mps_impl( + self, from, to - 1, c10::nullopt, c10::nullopt, MPSGraphRandomDistributionUniform, gen, __func__, nullptr); } Tensor& random_mps_(Tensor& self, int64_t to, c10::optional gen) { @@ -323,22 +348,23 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optionalgraph(); - MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar: 1.0f - dataType: randomTensor.dataType]; - MPSGraphTensor* minusLambdaTensor = [mpsGraph constantWithScalar: -lambda - dataType: randomTensor.dataType]; - MPSGraphTensor* subtractTensor = [mpsGraph subtractionWithPrimaryTensor: unitTensor - secondaryTensor: randomTensor - name: nil]; - MPSGraphTensor* logTensor = [mpsGraph logarithmWithTensor: subtractTensor - name: nil]; - return [mpsGraph divisionWithPrimaryTensor: logTensor - secondaryTensor: minusLambdaTensor - name: nil]; + MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0f dataType:randomTensor.dataType]; + MPSGraphTensor* minusLambdaTensor = [mpsGraph constantWithScalar:-lambda dataType:randomTensor.dataType]; + MPSGraphTensor* subtractTensor = [mpsGraph subtractionWithPrimaryTensor:unitTensor + secondaryTensor:randomTensor + name:nil]; + MPSGraphTensor* logTensor = [mpsGraph logarithmWithTensor:subtractTensor name:nil]; + return [mpsGraph divisionWithPrimaryTensor:logTensor secondaryTensor:minusLambdaTensor name:nil]; }; - return mps::random_mps_impl(self, 0.0, 1.0, c10::nullopt, c10::nullopt, - MPSGraphRandomDistributionUniform, gen, - "exponential_mps_:" + std::to_string(lambda), random_op_block); + return mps::random_mps_impl(self, + 0.0, + 1.0, + c10::nullopt, + c10::nullopt, + MPSGraphRandomDistributionUniform, + gen, + "exponential_mps_:" + std::to_string(lambda), + random_op_block); } Tensor& randperm_out_mps(int64_t n, c10::optional generator, Tensor& result) { @@ -354,9 +380,12 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional= 0, "n must be non-negative, got", n); - TORCH_CHECK(!generator.has_value() || - (generator.has_value() && result.device() == generator->device()), - "Expected a '", result.device(), "' generator device but found '", generator->device(), "'"); + TORCH_CHECK(!generator.has_value() || (generator.has_value() && result.device() == generator->device()), + "Expected a '", + result.device(), + "' generator device but found '", + generator->device(), + "'"); check_supported_max_int_with_precision(n, result); result.resize_({n}); @@ -366,36 +395,34 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optionalgraph(); - MPSGraphTensor* argsortTensor = [mpsGraph argSortWithTensor:randomTensor - axis:0 - name:nil]; + MPSGraphTensor* argsortTensor = [mpsGraph argSortWithTensor:randomTensor axis:0 name:nil]; if (result.scalar_type() != kInt) { - argsortTensor = [mpsGraph castTensor:argsortTensor - toType:mps::getMPSDataType(result) - name:@"castOutput"]; + argsortTensor = [mpsGraph castTensor:argsortTensor toType:mps::getMPSDataType(result) name:@"castOutput"]; } return argsortTensor; }; - return mps::random_mps_impl(result, 0.0, 1.0, c10::nullopt, c10::nullopt, - MPSGraphRandomDistributionUniform, generator, - "ranperm_out_mps:" + mps::getTensorsStringKey({result}), random_op_block); + return mps::random_mps_impl(result, + 0.0, + 1.0, + c10::nullopt, + c10::nullopt, + MPSGraphRandomDistributionUniform, + generator, + "ranperm_out_mps:" + mps::getTensorsStringKey({result}), + random_op_block); } -Tensor& multinomial_with_replacement_mps_kernel( - const Tensor& self, - const int64_t n_sample, - c10::optional generator, - Tensor& result) { - +Tensor& multinomial_with_replacement_mps_kernel(const Tensor& self, + const int64_t n_sample, + c10::optional generator, + Tensor& result) { using namespace mps; auto mps_gen = get_generator_or_default(generator, at::mps::detail::getDefaultMPSGenerator()); int inputSize = self.dim(); - int numDist = - inputSize == 1 ? 1 : self.size(0); - int numCategories = - inputSize == 1 ? self.size(0) : self.size(1); + int numDist = inputSize == 1 ? 1 : self.size(0); + int numCategories = inputSize == 1 ? self.size(0) : self.size(1); // Restructure data for 2d auto self_v = inputSize == 1 ? self.view({numDist, numCategories}) : self; @@ -408,24 +435,22 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optionalLookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - RandomCachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + RandomCachedGraph* newCachedGraph = nil; @autoreleasepool { MPSShape* prob_shape = getMPSShape(self_v); MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new RandomCachedGraph(mpsGraph); - newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@7]); + newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @7 ]); auto prob_dtype = getMPSDataType(self_v); // This is probability weights newCachedGraph->probTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self_v), prob_shape); - MPSGraphTensor *sumProbs = [mpsGraph reductionSumWithTensor:newCachedGraph->probTensor - axis:-1 - name:nil]; + MPSGraphTensor* sumProbs = [mpsGraph reductionSumWithTensor:newCachedGraph->probTensor axis:-1 name:nil]; - MPSGraphTensor *normalizedProbs = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->probTensor + MPSGraphTensor* normalizedProbs = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->probTensor secondaryTensor:sumProbs name:nil]; @@ -433,139 +458,125 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional *generatorTensors = [mpsGraph randomTensorWithShape:@[ns_numDist, ns_n_sample, @1] + MPSGraphRandomOpDescriptor* descriptor = + [MPSGraphRandomOpDescriptor descriptorWithDistribution:MPSGraphRandomDistributionUniform + dataType:prob_dtype]; + NSArray* generatorTensors = [mpsGraph randomTensorWithShape:@[ ns_numDist, ns_n_sample, @1 ] descriptor:descriptor stateTensor:newCachedGraph->stateTensor name:nil]; - MPSGraphTensor *randomTensor = generatorTensors[0]; + MPSGraphTensor* randomTensor = generatorTensors[0]; - auto broadcastShape = @[ns_numDist ,ns_n_sample, ns_numCategories]; + auto broadcastShape = @[ ns_numDist, ns_n_sample, ns_numCategories ]; int broadcastShapeVals[3] = {numDist, static_cast(n_sample), numCategories}; - MPSGraphTensor *broadcastShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count] - shape:@[[NSNumber numberWithUnsignedInteger:broadcastShape.count]] - dataType:MPSDataTypeUInt32]; - - MPSGraphTensor *samplesTensor = [mpsGraph broadcastTensor:randomTensor - toShape:broadcastShape - name:nil]; - MPSGraphTensor *sampleAbove = [mpsGraph greaterThanWithPrimaryTensor:samplesTensor + MPSGraphTensor* broadcastShapeTensor = [mpsGraph + constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count] + shape:@[ [NSNumber numberWithUnsignedInteger:broadcastShape.count] ] + dataType:MPSDataTypeUInt32]; + + MPSGraphTensor* samplesTensor = [mpsGraph broadcastTensor:randomTensor toShape:broadcastShape name:nil]; + MPSGraphTensor* sampleAbove = [mpsGraph greaterThanWithPrimaryTensor:samplesTensor secondaryTensor:lowerProbRange name:nil]; - MPSGraphTensor *sampleBelow = [mpsGraph lessThanWithPrimaryTensor:samplesTensor + MPSGraphTensor* sampleBelow = [mpsGraph lessThanWithPrimaryTensor:samplesTensor secondaryTensor:upperProbRange name:nil]; - MPSGraphTensor *sampleWithin = [mpsGraph logicalANDWithPrimaryTensor:sampleAbove - secondaryTensor:sampleBelow - name:nil]; - MPSGraphTensor *sampleMask = [mpsGraph castTensor:sampleWithin - toType:MPSDataTypeInt32 - name:@"sampleMask"]; - MPSGraphTensor *categoriesTensor = [mpsGraph coordinateAlongAxis:-1 + MPSGraphTensor* sampleWithin = [mpsGraph logicalANDWithPrimaryTensor:sampleAbove + secondaryTensor:sampleBelow + name:nil]; + MPSGraphTensor* sampleMask = [mpsGraph castTensor:sampleWithin toType:MPSDataTypeInt32 name:@"sampleMask"]; + MPSGraphTensor* categoriesTensor = [mpsGraph coordinateAlongAxis:-1 withShapeTensor:broadcastShapeTensor name:nil]; - MPSGraphTensor *binnedSamplesTensor = [mpsGraph multiplicationWithPrimaryTensor:categoriesTensor - secondaryTensor:sampleMask - name:nil]; - MPSGraphTensor *reducedTensor = [mpsGraph reductionSumWithTensor:binnedSamplesTensor - axis:-1 - name:nil]; - MPSGraphTensor *reshapeTensor = [mpsGraph reshapeTensor:reducedTensor - withShape:@[ns_numDist ,ns_n_sample] - name:nil]; + MPSGraphTensor* binnedSamplesTensor = [mpsGraph multiplicationWithPrimaryTensor:categoriesTensor + secondaryTensor:sampleMask + name:nil]; + MPSGraphTensor* reducedTensor = [mpsGraph reductionSumWithTensor:binnedSamplesTensor axis:-1 name:nil]; + MPSGraphTensor* reshapeTensor = [mpsGraph reshapeTensor:reducedTensor + withShape:@[ ns_numDist, ns_n_sample ] + name:nil]; newCachedGraph->resultTensor = [mpsGraph castTensor:reshapeTensor toType:getMPSDataType(result) name:@"resultTensor"]; } return newCachedGraph; - }); + }); } // update the Philox state values on each run of the same graph - MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(at::mps::detail::PHILOX_STATE_N)]]; - MPSNDArray *stateNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: stateDesc] autorelease]; + MPSNDArrayDescriptor* stateDesc = + [MPSNDArrayDescriptor descriptorWithDataType:MPSDataTypeInt32 shape:@[ @(at::mps::detail::PHILOX_STATE_N) ]]; + MPSNDArray* stateNDArray = [[[MPSNDArray alloc] initWithDevice:stream->device() descriptor:stateDesc] autorelease]; { // See Note [Acquire lock when using random generators] std::lock_guard lock(mps_gen->mutex_); // update the Philox state values on each run mps_gen->update_philox_counters(); - [stateNDArray writeBytes: mps_gen->state_data() strideBytes: nil]; + [stateNDArray writeBytes:mps_gen->state_data() strideBytes:nil]; } - MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: stateNDArray] autorelease]; + MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray:stateNDArray] autorelease]; auto probPlaceholder = Placeholder(cachedGraph->probTensor, self_v); auto outputPlaceholder = Placeholder(cachedGraph->resultTensor, result_v); - NSDictionary *feeds = @{ + NSDictionary* feeds = @{ cachedGraph->stateTensor : stateTensorData, probPlaceholder.getMPSGraphTensor() : probPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } return result; - } /* The largest consecutive integer representable in float32 (2^24) */ constexpr int64_t FLOAT32_MAX_CONSECUTIVE_INT = 1 << (FLT_MANT_DIG); Tensor& multinomial_out_mps(const Tensor& self, - int64_t n_sample, - bool with_replacement, - c10::optional gen, - Tensor& result) { - + int64_t n_sample, + bool with_replacement, + c10::optional gen, + Tensor& result) { + TORCH_CHECK(result.device() == self.device(), "multinomial arguments must have the same device"); + TORCH_CHECK(self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim"); + TORCH_CHECK(at::isFloatingType(self.scalar_type()), + "multinomial only supports floating-point dtypes for input, got: ", + self.scalar_type()); TORCH_CHECK( - result.device() == self.device(), - "multinomial arguments must have the same device"); - TORCH_CHECK( - self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim"); - TORCH_CHECK( - at::isFloatingType(self.scalar_type()), - "multinomial only supports floating-point dtypes for input, got: ", - self.scalar_type()); - TORCH_CHECK(result.scalar_type() == ScalarType::Long, - "multinomial expects Long tensor out, got: ", result.scalar_type()); + result.scalar_type() == ScalarType::Long, "multinomial expects Long tensor out, got: ", result.scalar_type()); TORCH_CHECK(n_sample > 0, "cannot sample n_sample <= 0 samples"); int64_t n_categories = self.size(-1); TORCH_CHECK(with_replacement || (n_sample <= n_categories), - "cannot sample n_sample > prob_dist.size(-1) samples without replacement"); + "cannot sample n_sample > prob_dist.size(-1) samples without replacement"); // Since the index tensor is float, numCategories cannot exceed max // float integer precision - TORCH_CHECK( - n_categories <= FLOAT32_MAX_CONSECUTIVE_INT, - "number of categories cannot exceed 2^24"); + TORCH_CHECK(n_categories <= FLOAT32_MAX_CONSECUTIVE_INT, "number of categories cannot exceed 2^24"); if (self.dim() == 1) { result.resize_({n_sample}); @@ -583,19 +594,15 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional= 0)).item(); - TORCH_CHECK( - is_valid.to(), - "probability tensor contains either `inf`, `nan` or element < 0"); + TORCH_CHECK(is_valid.to(), "probability tensor contains either `inf`, `nan` or element < 0"); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) bool zero_prob_condition; - if (self.dim() == 1){ + if (self.dim() == 1) { zero_prob_condition = (self.sum() == 0).item().to(); } else { zero_prob_condition = (self.sum(1) == 0).sum().item().to(); } - TORCH_CHECK( - !zero_prob_condition, - "invalid multinomial distribution (sum of probabilities <= 0)"); + TORCH_CHECK(!zero_prob_condition, "invalid multinomial distribution (sum of probabilities <= 0)"); // The algorithm is from gumbel softmax. // s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1) @@ -625,11 +632,7 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional gen) { +Tensor multinomial_mps(const Tensor& self, int64_t n_sample, bool with_replacement, c10::optional gen) { Tensor result = at::empty({0}, self.options().dtype(kLong)); multinomial_out_mps(self, n_sample, with_replacement, gen, result); return result; diff --git a/aten/src/ATen/native/mps/operations/Eye.mm b/aten/src/ATen/native/mps/operations/Eye.mm index e65b594306273c..69bbded33d19a0 100644 --- a/aten/src/ATen/native/mps/operations/Eye.mm +++ b/aten/src/ATen/native/mps/operations/Eye.mm @@ -3,9 +3,8 @@ #include #include #include -#include #include - +#include // Steps to add op for MPS backend: // 1. Register the op in aten/src/ATen/native/native_functions.yaml with the "MPS" dispatch key @@ -29,7 +28,6 @@ // g) Then call runMPSGraph() with input params and return the result. // - namespace at::native { Tensor& eye_out_mps(int64_t n, Tensor& result) { @@ -38,7 +36,6 @@ } Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) { - // This is one example of boiler-plate error checking, taking after CPU/CUDA counterparts TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n); TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m); @@ -47,7 +44,7 @@ result.zero_(); // Handle empty outputs - if(result.numel() == 0) + if (result.numel() == 0) return result; // Get MPS stream @@ -55,25 +52,24 @@ MPSStream* stream = getCurrentMPSStream(); // Derive from MPSCachedGraph - // This structure is used to cache an MPSGraph with certain keys, so that we don't have to compile the same MPSGraph time and time again for the same operation - // The keys of this structure are based on the inputs and outputs needed for the operation - // Here, we don't have any input tensors, just an output tensor - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + // This structure is used to cache an MPSGraph with certain keys, so that we don't have to compile the same MPSGraph + // time and time again for the same operation The keys of this structure are based on the inputs and outputs needed + // for the operation Here, we don't have any input tensors, just an output tensor + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* outputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - // A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types etc match the earlier created MPSGraph + // A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types + // etc match the earlier created MPSGraph string key = "eye_out_mps:" + getTensorsStringKey({result}); CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { // Initialize graph @@ -84,11 +80,9 @@ dataType:getMPSDataType(result)]; // Here we can call the MPSGraph API needed to execute the operation. - // The API details can be found here: https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph - MPSGraphTensor* outputTensor = [mpsGraph bandPartWithTensor:onesTensor - numLower:0 - numUpper:0 - name:nil]; + // The API details can be found here: + // https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph + MPSGraphTensor* outputTensor = [mpsGraph bandPartWithTensor:onesTensor numLower:0 numUpper:0 name:nil]; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; @@ -102,9 +96,8 @@ // In this case, there are no inputs, so the feeds are nil NSDictionary* feeds = nil; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; // Run the graph runMPSGraph(stream, cachedGraph->graph(), feeds, results); @@ -113,5 +106,4 @@ return result; } - } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/GridSampler.mm b/aten/src/ATen/native/mps/operations/GridSampler.mm index 7bf2d5f471ed3a..026b4c817dabdb 100644 --- a/aten/src/ATen/native/mps/operations/GridSampler.mm +++ b/aten/src/ATen/native/mps/operations/GridSampler.mm @@ -1,12 +1,15 @@ -#include #include #include +#include namespace at { namespace native { -void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor& grid, - int64_t interpolation_mode, int64_t padding_mode, +void grid_sampler_2d_mps_impl(Tensor& output, + const Tensor& input, + const Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, bool align_corners) { // Grid Sampler support has been added in macOS 13.1 #if defined(__MAC_13_2) @@ -18,35 +21,43 @@ void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor& MPSGraphPaddingMode paddingMode; auto memory_format = input.suggest_memory_format(); - MPSGraphTensorNamedDataLayout inputTensorLayout = - (memory_format == at::MemoryFormat::Contiguous) ? MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutNHWC; + MPSGraphTensorNamedDataLayout inputTensorLayout = (memory_format == at::MemoryFormat::Contiguous) + ? MPSGraphTensorNamedDataLayoutNCHW + : MPSGraphTensorNamedDataLayoutNHWC; switch (static_cast(padding_mode)) { case GridSamplerPadding::Zeros: - paddingMode = MPSGraphPaddingModeZero; break; + paddingMode = MPSGraphPaddingModeZero; + break; case GridSamplerPadding::Border: - TORCH_CHECK(false, "MPS: Unsupported Border padding mode"); break; + TORCH_CHECK(false, "MPS: Unsupported Border padding mode"); + break; case GridSamplerPadding::Reflection: - paddingMode = align_corners == true ? MPSGraphPaddingModeReflect : MPSGraphPaddingModeSymmetric; break; + paddingMode = align_corners == true ? MPSGraphPaddingModeReflect : MPSGraphPaddingModeSymmetric; + break; default: TORCH_CHECK(false, "MPS: Unrecognised Padding Mode: ", padding_mode); } switch (static_cast(interpolation_mode)) { case GridSamplerInterpolation::Bilinear: - samplingMode = MPSGraphResizeBilinear; break; + samplingMode = MPSGraphResizeBilinear; + break; case GridSamplerInterpolation::Nearest: - samplingMode = MPSGraphResizeNearest; break; + samplingMode = MPSGraphResizeNearest; + break; case GridSamplerInterpolation::Bicubic: - TORCH_CHECK(false, "MPS: Unsupported Bicubic interpolation"); break; + TORCH_CHECK(false, "MPS: Unsupported Bicubic interpolation"); + break; default: - TORCH_CHECK(false, "MPS: Unrecognised interpolation mode: ", interpolation_mode); break; - } + TORCH_CHECK(false, "MPS: Unrecognised interpolation mode: ", interpolation_mode); + break; + } - MPSStream *stream = getCurrentMPSStream(); + MPSStream* stream = getCurrentMPSStream(); struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* gridTensor_ = nil; MPSGraphTensor* outputTensor_ = nil; @@ -55,17 +66,13 @@ void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor& MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - string key = "grid_sampler_2d_mps" + - getTensorsStringKey({input, grid}) + - ":" + std::to_string(interpolation_mode) + - ":" + std::to_string(padding_mode) + - ":" + std::to_string(align_corners); + string key = "grid_sampler_2d_mps" + getTensorsStringKey({input, grid}) + ":" + std::to_string(interpolation_mode) + + ":" + std::to_string(padding_mode) + ":" + std::to_string(align_corners); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); @@ -75,27 +82,27 @@ void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor& MPSGraphTensor* outputTensor = nil; if (static_cast(interpolation_mode) == GridSamplerInterpolation::Nearest) { - outputTensor = [mpsGraph sampleGridWithSourceTensor: inputTensor - coordinateTensor: gridTensor - layout: inputTensorLayout - normalizeCoordinates: TRUE - relativeCoordinates: FALSE - alignCorners: align_corners - paddingMode: paddingMode - nearestRoundingMode: MPSGraphResizeNearestRoundingModeRoundToEven - constantValue: 0.0f - name: nil]; + outputTensor = [mpsGraph sampleGridWithSourceTensor:inputTensor + coordinateTensor:gridTensor + layout:inputTensorLayout + normalizeCoordinates:TRUE + relativeCoordinates:FALSE + alignCorners:align_corners + paddingMode:paddingMode + nearestRoundingMode:MPSGraphResizeNearestRoundingModeRoundToEven + constantValue:0.0f + name:nil]; } else { - outputTensor = [mpsGraph sampleGridWithSourceTensor: inputTensor - coordinateTensor: gridTensor - layout: inputTensorLayout - normalizeCoordinates: TRUE - relativeCoordinates: FALSE - alignCorners: align_corners - paddingMode: paddingMode - samplingMode: samplingMode - constantValue: 0.0f - name: nil]; + outputTensor = [mpsGraph sampleGridWithSourceTensor:inputTensor + coordinateTensor:gridTensor + layout:inputTensorLayout + normalizeCoordinates:TRUE + relativeCoordinates:FALSE + alignCorners:align_corners + paddingMode:paddingMode + samplingMode:samplingMode + constantValue:0.0f + name:nil]; } newCachedGraph->inputTensor_ = inputTensor; @@ -104,29 +111,29 @@ void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor& } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); Placeholder gridPlaceholder = Placeholder(cachedGraph->gridTensor_, grid); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - NSDictionary* feeds = @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), gridPlaceholder.getMPSGraphTensor() : gridPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } #endif // defined(__MAC_13_2) } -Tensor grid_sampler_2d_mps(const Tensor& input, const Tensor& grid, - int64_t interpolation_mode, int64_t padding_mode, +Tensor grid_sampler_2d_mps(const Tensor& input, + const Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, bool align_corners) { #if defined(__MAC_13_2) bool xcode_sdk_13_2_or_higher = true; @@ -138,17 +145,16 @@ Tensor grid_sampler_2d_mps(const Tensor& input, const Tensor& grid, TORCH_WARN_ONCE("MPS: grid_sampler_2d op is supported natively starting from macOS 13.1. ", "Falling back on CPU. This may have performance implications."); - return at::grid_sampler_2d( - input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners).clone().to("mps"); + return at::grid_sampler_2d(input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners) + .clone() + .to("mps"); } auto in_size = input.sizes(); auto grid_size = grid.sizes(); - auto output = at::empty( - {in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options()); + auto output = at::empty({in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options()); - grid_sampler_2d_mps_impl( - output, input, grid, interpolation_mode, padding_mode, align_corners); + grid_sampler_2d_mps_impl(output, input, grid, interpolation_mode, padding_mode, align_corners); return output; } diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index ee5787a7f23e19..93066a158c0fea 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -3,25 +3,25 @@ #include #include -#include -#include +#include #include #include -#include +#include #include +#include +#include +#include +#include #include -#include -#include -#include #include -#include -#include #include -#include -#include +#include +#include +#include #include #include -#include +#include +#include #ifdef __OBJC__ #include @@ -29,26 +29,25 @@ namespace at::native { -static -bool dispatchIndexKernel(TensorIteratorBase& iter, - IntArrayRef index_size, - IntArrayRef index_stride, - bool index_select, - bool accumulate) { +static bool dispatchIndexKernel(TensorIteratorBase& iter, + IntArrayRef index_size, + IntArrayRef index_stride, + bool index_select, + bool accumulate) { using namespace mps; - if (iter.numel() == 0) { + if (iter.numel() == 0) { return true; } const Tensor& inputTensor = iter.tensor(1); Tensor outputTensor = iter.tensor(0); - id inputBuffer = getMTLBufferStorage(inputTensor); + id inputBuffer = getMTLBufferStorage(inputTensor); id outputBuffer = getMTLBufferStorage(outputTensor); MPSStream* mpsStream = getCurrentMPSStream(); id device = MPSDevice::getInstance()->device(); - dispatch_sync(mpsStream->queue(), ^(){ + dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { NSError* error = nil; constexpr uint32_t nOffsets = 3; @@ -59,13 +58,13 @@ bool dispatchIndexKernel(TensorIteratorBase& iter, std::vector iterShapeData(iterShape.size()); std::vector> strides(nDim); - for (const auto i: c10::irange(iterShape.size())) { + for (const auto i : c10::irange(iterShape.size())) { TORCH_CHECK(i <= UINT32_MAX); iterShapeData[i] = (uint32_t)(iterShape[i]); } - for (const auto i: c10::irange(nDim)) { - for (const auto offset: c10::irange(nOffsets)) { + for (const auto i : c10::irange(nDim)) { + for (const auto offset : c10::irange(nOffsets)) { strides[i][offset] = iter.strides(offset)[i]; } } @@ -73,12 +72,14 @@ bool dispatchIndexKernel(TensorIteratorBase& iter, MTLSize gridSize = MTLSizeMake(numThreads, 1, 1); id commandBuffer = mpsStream->commandBuffer(); id computeEncoder = [commandBuffer computeCommandEncoder]; - id kernelDataOffsetsFunction = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil); - id kernelDataOffsetsPSO = [[device newComputePipelineStateWithFunction: kernelDataOffsetsFunction - error: &error] autorelease]; - id kernelDataOffsets = [[device newBufferWithLength: numThreads * sizeof(simd_uint3) - options: 0] autorelease]; - TORCH_CHECK(kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); + id kernelDataOffsetsFunction = + MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil); + id kernelDataOffsetsPSO = + [[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease]; + id kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3) + options:0] autorelease]; + TORCH_CHECK( + kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); [computeEncoder setComputePipelineState:kernelDataOffsetsPSO]; [computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0]; @@ -89,37 +90,37 @@ bool dispatchIndexKernel(TensorIteratorBase& iter, NSUInteger kernelOffsetsTGSize = kernelDataOffsetsPSO.maxTotalThreadsPerThreadgroup; if (kernelOffsetsTGSize > numThreads) - kernelOffsetsTGSize = numThreads; + kernelOffsetsTGSize = numThreads; MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1); - [computeEncoder dispatchThreads: gridSize - threadsPerThreadgroup: kernelOffsetsThreadGroupSize]; + [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:kernelOffsetsThreadGroupSize]; MTLFunctionConstantValues* constantValues = [[MTLFunctionConstantValues new] autorelease]; - [constantValues setConstantValue: &num_indices type:MTLDataTypeUInt atIndex:0]; + [constantValues setConstantValue:&num_indices type:MTLDataTypeUInt atIndex:0]; std::string indexFunction = getIndexFunctionName(inputTensor.scalar_type(), index_select, accumulate); - id indexKernelFunction = MPSDevice::getInstance()->metalIndexingFunction(indexFunction, constantValues); + id indexKernelFunction = + MPSDevice::getInstance()->metalIndexingFunction(indexFunction, constantValues); id argumentEncoder = [[indexKernelFunction newArgumentEncoderWithBufferIndex:0] autorelease]; NSUInteger argumentBufferLength = argumentEncoder.encodedLength; id indexAB = [[device newBufferWithLength:argumentBufferLength options:0] autorelease]; [argumentEncoder setArgumentBuffer:indexAB offset:0]; for (uint32_t idx = 0; idx < num_indices; idx++) { - const Tensor& indexTensor = iter.tensor(idx+2); - [argumentEncoder setBuffer: getMTLBufferStorage(indexTensor) - offset: indexTensor.storage_offset() * indexTensor.element_size() - atIndex: idx]; + const Tensor& indexTensor = iter.tensor(idx + 2); + [argumentEncoder setBuffer:getMTLBufferStorage(indexTensor) + offset:indexTensor.storage_offset() * indexTensor.element_size() + atIndex:idx]; TORCH_CHECK(indexTensor.scalar_type() == ScalarType::Long, "index(): Expected dtype int64 for Index"); } // FIXME: PSO needs to be cached - id indexSelectPSO = [[device newComputePipelineStateWithFunction: indexKernelFunction - error: &error] autorelease]; + id indexSelectPSO = [[device newComputePipelineStateWithFunction:indexKernelFunction + error:&error] autorelease]; TORCH_CHECK(indexSelectPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); for (uint32_t idx = 0; idx < num_indices; idx++) { - const Tensor& indexTensor = iter.tensor(idx+2); + const Tensor& indexTensor = iter.tensor(idx + 2); [computeEncoder useResource:getMTLBufferStorage(indexTensor) usage:MTLResourceUsageRead]; } @@ -129,15 +130,16 @@ bool dispatchIndexKernel(TensorIteratorBase& iter, [computeEncoder setBytes:index_stride.data() length:sizeof(index_stride[0]) * index_stride.size() atIndex:2]; [computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:3]; [computeEncoder setBuffer:inputBuffer offset:inputTensor.storage_offset() * inputTensor.element_size() atIndex:4]; - [computeEncoder setBuffer:outputBuffer offset:outputTensor.storage_offset() * outputTensor.element_size() atIndex:5]; + [computeEncoder setBuffer:outputBuffer + offset:outputTensor.storage_offset() * outputTensor.element_size() + atIndex:5]; NSUInteger tgSize = indexSelectPSO.maxTotalThreadsPerThreadgroup; if (tgSize > numThreads) - tgSize = numThreads; + tgSize = numThreads; MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreads: gridSize - threadsPerThreadgroup: threadGroupSize]; + [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize]; [computeEncoder endEncoding]; mpsStream->synchronize(SyncType::COMMIT_AND_CONTINUE); @@ -147,7 +149,11 @@ bool dispatchIndexKernel(TensorIteratorBase& iter, return true; } -static void validateInputData(const TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride, const std::string& op, bool accumulate) { +static void validateInputData(const TensorIteratorBase& iter, + IntArrayRef index_size, + IntArrayRef index_stride, + const std::string& op, + bool accumulate) { using namespace mps; int64_t num_indices = index_size.size(); @@ -159,13 +165,11 @@ static void validateInputData(const TensorIteratorBase& iter, IntArrayRef index_ if (accumulate) { // No atomic support for the rest of dtypes - TORCH_CHECK(inputTensor.scalar_type() == ScalarType::Float || - inputTensor.scalar_type() == ScalarType::Int || + TORCH_CHECK(inputTensor.scalar_type() == ScalarType::Float || inputTensor.scalar_type() == ScalarType::Int || inputTensor.scalar_type() == ScalarType::Bool); } else { TORCH_CHECK(c10::isIntegralType(inputTensor.scalar_type(), /*includesBool=*/true) || - inputTensor.scalar_type() == ScalarType::Float || - inputTensor.scalar_type() == ScalarType::Half, + inputTensor.scalar_type() == ScalarType::Float || inputTensor.scalar_type() == ScalarType::Half, getMPSTypeString(inputTensor) + std::string(" not supported for index.Tensor_out")); } } @@ -186,46 +190,42 @@ void index_put_kernel_mps(TensorIterator& iter, IntArrayRef index_size, IntArray } } -static Tensor & masked_select_out_mps_impl(Tensor & result, const Tensor & self, const Tensor & mask) { +static Tensor& masked_select_out_mps_impl(Tensor& result, const Tensor& self, const Tensor& mask) { NoNamesGuard guard; - TORCH_CHECK(mask.scalar_type() == ScalarType::Bool, - "masked_select: expected BoolTensor for mask"); + TORCH_CHECK(mask.scalar_type() == ScalarType::Bool, "masked_select: expected BoolTensor for mask"); TORCH_CHECK(self.scalar_type() == result.scalar_type(), "masked_select(): self and result must have the same scalar type"); - auto mask_temp = (mask.dim() == 0) - ? c10::MaybeOwned::owned(mask.unsqueeze(0)) - : c10::MaybeOwned::borrowed(mask); - auto self_temp = (self.dim() == 0) - ? c10::MaybeOwned::owned(self.unsqueeze(0)) - : c10::MaybeOwned::borrowed(self); + auto mask_temp = + (mask.dim() == 0) ? c10::MaybeOwned::owned(mask.unsqueeze(0)) : c10::MaybeOwned::borrowed(mask); + auto self_temp = + (self.dim() == 0) ? c10::MaybeOwned::owned(self.unsqueeze(0)) : c10::MaybeOwned::borrowed(self); // Cannot reassign to mask_temp and self_temp here! if they are // owning and expand_outplace returns a borrow, the returned borrow // would dangle. auto mask_self_expanded = expand_outplace(*mask_temp, *self_temp); - at::index_out( - result, *std::get<1>(mask_self_expanded), - c10::List>({*std::move(std::get<0>(mask_self_expanded))})); + at::index_out(result, + *std::get<1>(mask_self_expanded), + c10::List>({*std::move(std::get<0>(mask_self_expanded))})); return result; } -static -Tensor nonzero_fallback(const Tensor& self) { +static Tensor nonzero_fallback(const Tensor& self) { TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 13.0. ", "Falling back on CPU. This may have performance implications."); return at::nonzero(self.to("cpu")).clone().to("mps"); } -Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_){ +Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) { if (!is_macos_13_or_newer()) { - Tensor out_fallback = nonzero_fallback(self); - at::native::resize_output(out_, out_fallback.sizes()); - out_.copy_(out_fallback.to("mps")); - return out_; + Tensor out_fallback = nonzero_fallback(self); + at::native::resize_output(out_, out_fallback.sizes()); + out_.copy_(out_fallback.to("mps")); + return out_; } int64_t nDim = self.dim(); @@ -237,18 +237,22 @@ Tensor nonzero_fallback(const Tensor& self) { using namespace mps; const uint32_t maxDimensions = 16; - TORCH_CHECK(self.numel() < std::numeric_limits::max(), "nonzero is not supported for tensors with more than INT_MAX elements, \ + TORCH_CHECK(self.numel() < std::numeric_limits::max(), + "nonzero is not supported for tensors with more than INT_MAX elements, \ file a support request"); - TORCH_CHECK(out_.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out_.dtype()); - TORCH_CHECK(self.device() == out_.device(), "expected self and out to be on the same device, but got out on ", - out_.device(), " and self on ", self.device()); + TORCH_CHECK( + out_.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out_.dtype()); + TORCH_CHECK(self.device() == out_.device(), + "expected self and out to be on the same device, but got out on ", + out_.device(), + " and self on ", + self.device()); TORCH_CHECK(self.dim() <= maxDimensions, "nonzero is not supported for tensor with more than ", 16, " dimensions"); TORCH_CHECK(out_.is_mps()); - MPSStream *stream = getCurrentMPSStream(); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSStream* stream = getCurrentMPSStream(); + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* outputTensor_ = nil; MPSGraphTensor* scatterDataTensor_ = nil; @@ -257,20 +261,15 @@ Tensor nonzero_fallback(const Tensor& self) { stream->synchronize(SyncType::COMMIT_AND_WAIT); Tensor count_nonzero = at::empty({1}, self.options().dtype(kInt)); - Tensor out = at::native::empty_mps( - {self.numel(), nDim == 0 ? 1 : nDim}, - out_.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + Tensor out = at::native::empty_mps( + {self.numel(), nDim == 0 ? 1 : nDim}, out_.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); int64_t _apparentInputShape = 1; for (auto dim : self.sizes()) { _apparentInputShape *= dim; } - MPSShape *apparentOutputShape = @[@(self.numel() * nDim)]; - MPSShape *apparentInputShape = @[@(_apparentInputShape)]; + MPSShape* apparentOutputShape = @[ @(self.numel() * nDim) ]; + MPSShape* apparentInputShape = @[ @(_apparentInputShape) ]; // Pseudocode: // @@ -284,67 +283,68 @@ Tensor nonzero_fallback(const Tensor& self) { MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { string key = "nonzero_out_mps" + getTensorsStringKey(self); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSDataType inputDataType = getMPSDataType(self); MPSShape* inputShape = getMPSShape(self); MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), apparentInputShape); - MPSGraphTensor *scatterDataTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(out.scalar_type())); - MPSGraphTensor *zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputDataType]; - MPSGraphTensor *oneTensor = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeInt32]; - MPSGraphTensor *minusMaxDimTensor = [mpsGraph constantWithScalar:-maxDimensions dataType:MPSDataTypeInt32]; - MPSGraphTensor *inputNotEqualToZeroTensor = [mpsGraph notEqualWithPrimaryTensor:inputTensor + MPSGraphTensor* inputTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), apparentInputShape); + MPSGraphTensor* scatterDataTensor = + mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(out.scalar_type())); + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputDataType]; + MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeInt32]; + MPSGraphTensor* minusMaxDimTensor = [mpsGraph constantWithScalar:-maxDimensions dataType:MPSDataTypeInt32]; + MPSGraphTensor* inputNotEqualToZeroTensor = [mpsGraph notEqualWithPrimaryTensor:inputTensor secondaryTensor:zeroTensor name:nil]; - MPSGraphTensor *countNonzero = [mpsGraph reductionSumWithTensor:inputNotEqualToZeroTensor - axis:0 - name:nil]; - MPSGraphTensor *maskTensor = [mpsGraph castTensor:inputNotEqualToZeroTensor + MPSGraphTensor* countNonzero = [mpsGraph reductionSumWithTensor:inputNotEqualToZeroTensor axis:0 name:nil]; + MPSGraphTensor* maskTensor = [mpsGraph castTensor:inputNotEqualToZeroTensor toType:MPSDataTypeInt32 name:@"castToInt32"]; - MPSGraphTensor *indicesTensor = [mpsGraph cumulativeSumWithTensor:maskTensor - axis:0 - name:nil]; - MPSGraphTensor *indicesMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:indicesTensor - secondaryTensor:oneTensor - name:nil]; - MPSGraphTensor *maskedIndicesTensor = [mpsGraph selectWithPredicateTensor:inputNotEqualToZeroTensor + MPSGraphTensor* indicesTensor = [mpsGraph cumulativeSumWithTensor:maskTensor axis:0 name:nil]; + MPSGraphTensor* indicesMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:indicesTensor + secondaryTensor:oneTensor + name:nil]; + MPSGraphTensor* maskedIndicesTensor = [mpsGraph selectWithPredicateTensor:inputNotEqualToZeroTensor truePredicateTensor:indicesMinusOneTensor falsePredicateTensor:minusMaxDimTensor name:nil]; - MPSGraphTensor *coordinatesTensor = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:0 withShape:inputShape name:nil] - withShape:@[@-1] - name:nil]; + MPSGraphTensor* coordinatesTensor = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:0 + withShape:inputShape + name:nil] + withShape:@[ @-1 ] + name:nil]; if (nDim > 1) { - NSMutableArray *maskedIndicesTensorArray = [NSMutableArray arrayWithCapacity:nDim]; - NSMutableArray *coordinatesTensorArray = [NSMutableArray arrayWithCapacity:nDim]; + NSMutableArray* maskedIndicesTensorArray = [NSMutableArray arrayWithCapacity:nDim]; + NSMutableArray* coordinatesTensorArray = [NSMutableArray arrayWithCapacity:nDim]; - MPSGraphTensor *constantRankTensor = [mpsGraph constantWithScalar:nDim - dataType:MPSDataTypeInt32]; + MPSGraphTensor* constantRankTensor = [mpsGraph constantWithScalar:nDim dataType:MPSDataTypeInt32]; maskedIndicesTensorArray[0] = [mpsGraph multiplicationWithPrimaryTensor:maskedIndicesTensor secondaryTensor:constantRankTensor name:nil]; coordinatesTensorArray[0] = coordinatesTensor; - for (int i = 1; i < nDim; i++){ + for (int i = 1; i < nDim; i++) { maskedIndicesTensorArray[i] = [mpsGraph additionWithPrimaryTensor:maskedIndicesTensorArray[i - 1] secondaryTensor:oneTensor name:nil]; - coordinatesTensorArray[i] = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:i withShape:inputShape name:nil] - withShape:@[@-1] + coordinatesTensorArray[i] = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:i + withShape:inputShape + name:nil] + withShape:@[ @-1 ] name:nil]; } maskedIndicesTensor = [mpsGraph concatTensors:maskedIndicesTensorArray dimension:0 interleave:YES name:nil]; coordinatesTensor = [mpsGraph concatTensors:coordinatesTensorArray dimension:0 interleave:YES name:nil]; } - MPSGraphTensor *outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor + MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor updatesTensor:coordinatesTensor indicesTensor:maskedIndicesTensor axis:0 @@ -358,7 +358,7 @@ Tensor nonzero_fallback(const Tensor& self) { } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, apparentInputShape); @@ -386,7 +386,7 @@ Tensor nonzero_fallback(const Tensor& self) { return out_; } -Tensor nonzero_mps(const Tensor& self){ +Tensor nonzero_mps(const Tensor& self) { if (!is_macos_13_or_newer()) { return nonzero_fallback(self); } @@ -395,13 +395,13 @@ Tensor nonzero_mps(const Tensor& self){ return nonzero_out_mps(self, out); } -Tensor masked_select_mps(const Tensor & self, const Tensor & mask) { +Tensor masked_select_mps(const Tensor& self, const Tensor& mask) { namedinference::compute_broadcast_outnames(self, mask); Tensor result = at::empty({0}, self.options()); return masked_select_out_mps_impl(result, self, mask); } -Tensor & masked_select_out_mps(const Tensor & self, const Tensor & mask, Tensor & result) { +Tensor& masked_select_out_mps(const Tensor& self, const Tensor& mask, Tensor& result) { namedinference::compute_broadcast_outnames(self, mask); return masked_select_out_mps_impl(result, self, mask); } @@ -409,27 +409,22 @@ Tensor masked_select_mps(const Tensor & self, const Tensor & mask) { Tensor flip_mps(const Tensor& self, IntArrayRef dims) { using namespace mps; - Tensor result = at::native::empty_mps( - self.sizes(), - self.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + Tensor result = + at::native::empty_mps(self.sizes(), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); auto total_dims = self.dim(); // It wraps the dims and checks that there are no repeated dims auto flip_dims_b = at::dim_list_to_bitset(dims, total_dims); - NSMutableArray * ns_dims = [[NSMutableArray new] autorelease]; + NSMutableArray* ns_dims = [[NSMutableArray new] autorelease]; for (const auto i : c10::irange(total_dims)) { - if(flip_dims_b[i] && self.size(i) > 1 && self.stride(i) != 0) { + if (flip_dims_b[i] && self.size(i) > 1 && self.stride(i) != 0) { [ns_dims addObject:[NSNumber numberWithInt:i]]; } } // Nothing to do, we return fast - if (dims.size() == 0 || self.numel() <=1) { + if (dims.size() == 0 || self.numel() <= 1) { result.copy_(self); return result; } @@ -442,31 +437,29 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { MPSDataType inputDataType = getMPSScalarType(self.scalar_type()); MPSDataType outputDataType = getMPSScalarType(self.scalar_type()); if (!is_macos_13_or_newer()) { - if (self.scalar_type() == kBool) { + if (self.scalar_type() == kBool) { inputDataType = MPSDataTypeInt8; - } - if (result.scalar_type() == kBool) { + } + if (result.scalar_type() == kBool) { outputDataType = MPSDataTypeInt8; - } + } } @autoreleasepool { NSString* ns_dims_key = [[ns_dims valueForKey:@"description"] componentsJoinedByString:@","]; - // A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types etc match the earlier created MPSGraph + // A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types + // etc match the earlier created MPSGraph string key = "flip_mps:" + getTensorsStringKey({self}) + ":" + string([ns_dims_key UTF8String]); auto cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(self)); - MPSGraphTensor* outputTensor = [mpsGraph reverseTensor:inputTensor - axes:ns_dims - name:nil]; + MPSGraphTensor* outputTensor = [mpsGraph reverseTensor:inputTensor axes:ns_dims name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; } @@ -475,36 +468,31 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { } // Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation - Placeholder inputPlaceholder = Placeholder( - cachedGraph->inputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/true, inputDataType); - Placeholder outputPlaceholder = Placeholder( - cachedGraph->outputTensor_, result, /*mpsShape*/nil, /*gatherTensorData=*/false, outputDataType); + Placeholder inputPlaceholder = + Placeholder(cachedGraph->inputTensor_, self, /*mpsShape*/ nil, /*gatherTensorData=*/true, inputDataType); + Placeholder outputPlaceholder = + Placeholder(cachedGraph->outputTensor_, result, /*mpsShape*/ nil, /*gatherTensorData=*/false, outputDataType); + NSDictionary* feeds = + @{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()}; - NSDictionary* feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData() - }; - - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; // Run the graph runMPSGraph(stream, cachedGraph->graph(), feeds, results); } return result; - } -TORCH_IMPL_FUNC(index_add_mps_out)( - const Tensor& self, - int64_t dim, - const Tensor& index, - const Tensor& source, - const Scalar& alpha, - const Tensor& result) { - +TORCH_IMPL_FUNC(index_add_mps_out) +(const Tensor& self, + int64_t dim, + const Tensor& index, + const Tensor& source, + const Scalar& alpha, + const Tensor& result) { using namespace mps; MPSStream* stream = getCurrentMPSStream(); dim = maybe_wrap_dim(dim, self.dim()); @@ -515,9 +503,8 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { TORCH_CHECK(source.scalar_type() != ScalarType::Long, "index_add(): Expected non int64 dtype for source."); auto casted_type = isFloatingType(source.scalar_type()) ? ScalarType::Float : ScalarType::Int; - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* indexTensor_ = nil; MPSGraphTensor* sourceTensor_ = nil; @@ -528,13 +515,12 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - string key = "index_add_mps_out" + getTensorsStringKey({self, index, source}) + ":" + std::to_string(dim); CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -547,21 +533,21 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { MPSGraphTensor* castedInputTensor = inputTensor; MPSGraphTensor* castedSourceTensor = sourceTensor; if (source.scalar_type() != casted_type) { - castedInputTensor = castMPSTensor(mpsGraph, castedInputTensor, casted_type); - castedSourceTensor = castMPSTensor(mpsGraph, castedSourceTensor, casted_type); + castedInputTensor = castMPSTensor(mpsGraph, castedInputTensor, casted_type); + castedSourceTensor = castMPSTensor(mpsGraph, castedSourceTensor, casted_type); } MPSGraphTensor* alphaSourceSlice = [mpsGraph multiplicationWithPrimaryTensor:castedSourceTensor secondaryTensor:alphaTensor name:nil]; MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor:castedInputTensor - updatesTensor:alphaSourceSlice - indicesTensor:indexTensor - axis:dim - mode:MPSGraphScatterModeAdd - name:nil]; + updatesTensor:alphaSourceSlice + indicesTensor:indexTensor + axis:dim + mode:MPSGraphScatterModeAdd + name:nil]; if (source.scalar_type() != casted_type) { - outputTensor = castMPSTensor(mpsGraph, outputTensor, source.scalar_type()); + outputTensor = castMPSTensor(mpsGraph, outputTensor, source.scalar_type()); } newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->indexTensor_ = indexTensor; @@ -585,17 +571,14 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { sourcePlaceholder.getMPSGraphTensor() : sourcePlaceholder.getMPSGraphTensorData(), cachedGraph->alphaTensor_ : getMPSGraphTensorFromScalar(stream, alpha_scalar), }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } } -Tensor index_select_mps(const Tensor & self, - int64_t dim, - const Tensor & index) { +Tensor index_select_mps(const Tensor& self, int64_t dim, const Tensor& index) { IntArrayRef input_shape = self.sizes(); auto num_input_dims = input_shape.size(); @@ -606,7 +589,7 @@ Tensor index_select_mps(const Tensor & self, std::vector shape_data(num_input_dims); // Calculate new shape - for(auto i : c10::irange(num_input_dims)) { + for (auto i : c10::irange(num_input_dims)) { if (i == dim) { shape_data[i] = num_indices; } else { @@ -616,33 +599,24 @@ Tensor index_select_mps(const Tensor & self, IntArrayRef output_shape = IntArrayRef(shape_data.data(), num_input_dims); - Tensor result = at::native::empty_mps( - output_shape, - self.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + Tensor result = + at::native::empty_mps(output_shape, self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); index_select_out_mps(self, dim, index, result); return result; } -Tensor& index_select_out_mps(const Tensor & self, - int64_t dim, - const Tensor & index, - Tensor & output) { - +Tensor& index_select_out_mps(const Tensor& self, int64_t dim, const Tensor& index, Tensor& output) { using namespace mps; MPSStream* stream = getCurrentMPSStream(); dim = maybe_wrap_dim(dim, self.dim()); // Checks TORCH_CHECK_INDEX(index.dim() <= 1, "index_select(): Index is supposed to be a vector"); - TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_select(): Expected dtype int32 or int64 for index"); + TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, + "index_select(): Expected dtype int32 or int64 for index"); TORCH_CHECK(self.scalar_type() == output.scalar_type(), "index_select(): self and output must have the same scalar type"); - TORCH_CHECK(dim == 0 || dim < self.dim(), - "index_select(): Indexing dim ", dim, " is out of bounds of tensor"); + TORCH_CHECK(dim == 0 || dim < self.dim(), "index_select(): Indexing dim ", dim, " is out of bounds of tensor"); // Empty index if (index.numel() == 0) { @@ -650,15 +624,14 @@ Tensor index_select_mps(const Tensor & self, } // Scalar input - if (self.dim() == 0 && self.numel() == 1){ + if (self.dim() == 0 && self.numel() == 1) { output.copy_(self); return output; } // Derive from MPSCachedGraph - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* indexTensor_ = nil; MPSGraphTensor* outputTensor_ = nil; @@ -667,23 +640,20 @@ Tensor index_select_mps(const Tensor & self, MPSGraphCache* cache_ = MPSGraphCache::getInstance(); auto inputType = getMPSDataType(self); auto outputType = getMPSDataType(output); - if (inputType == MPSDataTypeUInt8 || - (!is_macos_13_or_newer() && inputType == MPSDataTypeBool)) { + if (inputType == MPSDataTypeUInt8 || (!is_macos_13_or_newer() && inputType == MPSDataTypeBool)) { inputType = MPSDataTypeInt8; } - if (outputType == MPSDataTypeUInt8 || - (!is_macos_13_or_newer() && outputType == MPSDataTypeBool)) { + if (outputType == MPSDataTypeUInt8 || (!is_macos_13_or_newer() && outputType == MPSDataTypeBool)) { outputType = MPSDataTypeInt8; } @autoreleasepool { - string key = "index_select_out_mps" + getTensorsStringKey({self, index}) + ":" + std::to_string(dim); CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -706,48 +676,55 @@ Tensor index_select_mps(const Tensor & self, }); } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, - /*mpsShape=*/nullptr, /*gatherTensorData=*/true, /*dataType=*/inputType); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, + self, + /*mpsShape=*/nullptr, + /*gatherTensorData=*/true, + /*dataType=*/inputType); Placeholder indexPlaceholder = Placeholder(cachedGraph->indexTensor_, index); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, - /*mpsShape=*/nullptr, /*gatherTensorData=*/false, /*dataType=*/outputType); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, + output, + /*mpsShape=*/nullptr, + /*gatherTensorData=*/false, + /*dataType=*/outputType); NSDictionary* feeds = @{ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), indexPlaceholder.getMPSGraphTensor() : indexPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } return output; - } -Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value) { +Tensor& masked_fill__mps(Tensor& self, const Tensor& mask, const Scalar& value) { using namespace mps; if (self.numel() == 0) { return self; } - TORCH_CHECK(self.device() == mask.device(), "expected self and mask to be on the same device, but got mask on ", - mask.device(), " and self on ", self.device()); + TORCH_CHECK(self.device() == mask.device(), + "expected self and mask to be on the same device, but got mask on ", + mask.device(), + " and self on ", + self.device()); TORCH_CHECK(mask.scalar_type() == kByte || mask.scalar_type() == kBool, - "expected mask dtype to be Bool but got ", mask.scalar_type()); + "expected mask dtype to be Bool but got ", + mask.scalar_type()); auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_"); c10::MaybeOwned b_mask = expand_inplace(self, mask, "masked_fill_"); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *maskTensor_ = nil; - MPSGraphTensor *valueTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* maskTensor_ = nil; + MPSGraphTensor* valueTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -757,12 +734,12 @@ Tensor index_select_mps(const Tensor & self, // Workaround for `selectWithPredicateTensor` on macOS Monterey where bool data type may cause a hang // The issue is fixed in macOS Ventura (13.0) if (!is_macos_13_or_newer()) { - if (self.scalar_type() == kBool) { + if (self.scalar_type() == kBool) { inputDataType = MPSDataTypeInt8; - } - if (mask.scalar_type() == kBool) { + } + if (mask.scalar_type() == kBool) { maskDataType = MPSDataTypeInt8; - } + } } MPSStream* stream = getCurrentMPSStream(); @@ -770,10 +747,9 @@ Tensor index_select_mps(const Tensor & self, @autoreleasepool { string key = "masked_fill" + getTensorsStringKey({self, *b_mask}) + ":" + getMPSTypeString(value.type()); CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -786,15 +762,13 @@ Tensor index_select_mps(const Tensor & self, MPSDataType valueType = getMPSScalarType(value.type()); MPSGraphTensor* castValueTensor = valueTensor; if (valueType != inputDataType) { - castValueTensor = [mpsGraph castTensor:valueTensor - toType:inputDataType - name:@"castValueTensor"]; + castValueTensor = [mpsGraph castTensor:valueTensor toType:inputDataType name:@"castValueTensor"]; } MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:maskTensor - truePredicateTensor:castValueTensor + truePredicateTensor:castValueTensor falsePredicateTensor:inputTensor - name:nil]; + name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->maskTensor_ = maskTensor; @@ -805,12 +779,12 @@ Tensor index_select_mps(const Tensor & self, }); } - Placeholder selfPlaceholder = Placeholder( - cachedGraph->inputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/true, inputDataType); - Placeholder maskPlaceholder = Placeholder( - cachedGraph->maskTensor_, *b_mask, /*mpsShape*/nil, /*gatherTensorData=*/true, maskDataType); - Placeholder outputPlaceholder = Placeholder( - cachedGraph->outputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/false, inputDataType); + Placeholder selfPlaceholder = + Placeholder(cachedGraph->inputTensor_, self, /*mpsShape*/ nil, /*gatherTensorData=*/true, inputDataType); + Placeholder maskPlaceholder = + Placeholder(cachedGraph->maskTensor_, *b_mask, /*mpsShape*/ nil, /*gatherTensorData=*/true, maskDataType); + Placeholder outputPlaceholder = + Placeholder(cachedGraph->outputTensor_, self, /*mpsShape*/ nil, /*gatherTensorData=*/false, inputDataType); // Create dictionary of inputs and outputs NSDictionary* feeds = @{ @@ -819,9 +793,8 @@ Tensor index_select_mps(const Tensor & self, cachedGraph->valueTensor_ : getMPSGraphTensorFromScalar(stream, valueScalar) }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } @@ -829,127 +802,122 @@ Tensor index_select_mps(const Tensor & self, return self; } -Tensor embedding_dense_backward_mps( - const Tensor & grad_, const Tensor & indices, int64_t num_weights, - int64_t padding_idx, bool scale_grad_by_freq) -{ - // TODO: implement padding_idx & scale_grad_by_freq. - namespace native_mps = at::native::mps; - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *incomingGradTensor_ = nil; - MPSGraphTensor *indicesTensor_ = nil; - MPSGraphTensor *outgoingGradTensor_ = nil; - }; - - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); - - IntArrayRef incoming_gradient_shape = grad_.sizes(); - int64_t num_incoming_gradient_dims = incoming_gradient_shape.size(); +Tensor embedding_dense_backward_mps(const Tensor& grad_, + const Tensor& indices, + int64_t num_weights, + int64_t padding_idx, + bool scale_grad_by_freq) { + // TODO: implement padding_idx & scale_grad_by_freq. + namespace native_mps = at::native::mps; + struct CachedGraph : public native_mps::MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* incomingGradTensor_ = nil; + MPSGraphTensor* indicesTensor_ = nil; + MPSGraphTensor* outgoingGradTensor_ = nil; + }; - IntArrayRef indices_shape = indices.sizes(); - int64_t num_indices_dims = indices_shape.size(); + native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); - int64_t D = incoming_gradient_shape[num_incoming_gradient_dims - 1]; - c10::SmallVector outgoing_gradient_shape{num_weights, D}; - Tensor outgoing_gradient = at::native::empty_mps( - IntArrayRef(outgoing_gradient_shape), - grad_.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + IntArrayRef incoming_gradient_shape = grad_.sizes(); + int64_t num_incoming_gradient_dims = incoming_gradient_shape.size(); - if (outgoing_gradient.numel() == 0) { - return outgoing_gradient; - } + IntArrayRef indices_shape = indices.sizes(); + int64_t num_indices_dims = indices_shape.size(); - auto stream = at::mps::getCurrentMPSStream(); + int64_t D = incoming_gradient_shape[num_incoming_gradient_dims - 1]; + c10::SmallVector outgoing_gradient_shape{num_weights, D}; + Tensor outgoing_gradient = at::native::empty_mps( + IntArrayRef(outgoing_gradient_shape), grad_.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); - @autoreleasepool { - string key = "edb_mps:" + native_mps::getMPSTypeString(grad_) + ":indices" + std::to_string(num_indices_dims) + ":num_weights" + std::to_string(num_weights) + ":padding_idx" + std::to_string(padding_idx) + ":scaled" + std::to_string(scale_grad_by_freq); - CachedGraph* cachedGraph = cache_->LookUpAs(key); - // Initialize once if configuration not found in cache - if(!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ native_mps::MPSCachedGraph * () { + if (outgoing_gradient.numel() == 0) { + return outgoing_gradient; + } - CachedGraph *newCachedGraph = nil; + auto stream = at::mps::getCurrentMPSStream(); - @autoreleasepool { - MPSGraph* mpsGraph = native_mps::make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); + @autoreleasepool { + string key = "edb_mps:" + native_mps::getMPSTypeString(grad_) + ":indices" + std::to_string(num_indices_dims) + + ":num_weights" + std::to_string(num_weights) + ":padding_idx" + std::to_string(padding_idx) + ":scaled" + + std::to_string(scale_grad_by_freq); + CachedGraph* cachedGraph = cache_->LookUpAs(key); + // Initialize once if configuration not found in cache + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^native_mps::MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; - MPSGraphTensor* incomingGradTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(grad_)); + @autoreleasepool { + MPSGraph* mpsGraph = native_mps::make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* indicesTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(indices)); + MPSGraphTensor* incomingGradTensor = + native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(grad_)); - MPSGraphTensor* reshapedIndicesTensor = indicesTensor; + MPSGraphTensor* indicesTensor = + native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(indices)); - MPSGraphTensor* castGradTensor = incomingGradTensor; - MPSDataType dataType = mps::getMPSDataType(grad_); - // issue 105486100, scatterNDWithUpdatesTensor produces wrong result for float16 - if (dataType == MPSDataTypeFloat16) { - castGradTensor = [mpsGraph castTensor: incomingGradTensor - toType: MPSDataTypeFloat32 - name: @"castGradTensor"]; - } - if (num_indices_dims != 0) { - reshapedIndicesTensor = [mpsGraph expandDimsOfTensor: indicesTensor - axes: @[@-1] - name: nil]; - } + MPSGraphTensor* reshapedIndicesTensor = indicesTensor; - auto outgoingGradTensor = [mpsGraph scatterNDWithUpdatesTensor: castGradTensor - indicesTensor: reshapedIndicesTensor - shape: native_mps::getMPSShape(IntArrayRef(outgoing_gradient_shape)) - batchDimensions: 0 - mode: MPSGraphScatterModeAdd - name: @"edb"]; - if (dataType == MPSDataTypeFloat16) { - outgoingGradTensor = [mpsGraph castTensor: outgoingGradTensor - toType: MPSDataTypeFloat16 - name: @"castGradTensor"]; - } - newCachedGraph->incomingGradTensor_ = incomingGradTensor; - newCachedGraph->indicesTensor_ = indicesTensor; - newCachedGraph->outgoingGradTensor_ = outgoingGradTensor; + MPSGraphTensor* castGradTensor = incomingGradTensor; + MPSDataType dataType = mps::getMPSDataType(grad_); + // issue 105486100, scatterNDWithUpdatesTensor produces wrong result for float16 + if (dataType == MPSDataTypeFloat16) { + castGradTensor = [mpsGraph castTensor:incomingGradTensor toType:MPSDataTypeFloat32 name:@"castGradTensor"]; + } + if (num_indices_dims != 0) { + reshapedIndicesTensor = [mpsGraph expandDimsOfTensor:indicesTensor axes:@[ @-1 ] name:nil]; + } + auto outgoingGradTensor = + [mpsGraph scatterNDWithUpdatesTensor:castGradTensor + indicesTensor:reshapedIndicesTensor + shape:native_mps::getMPSShape(IntArrayRef(outgoing_gradient_shape)) + batchDimensions:0 + mode:MPSGraphScatterModeAdd + name:@"edb"]; + if (dataType == MPSDataTypeFloat16) { + outgoingGradTensor = [mpsGraph castTensor:outgoingGradTensor + toType:MPSDataTypeFloat16 + name:@"castGradTensor"]; } - return newCachedGraph; - }); - } - auto incomingGradPlaceholder = native_mps::Placeholder(cachedGraph->incomingGradTensor_, grad_); - auto indicesPlaceholder = native_mps::Placeholder(cachedGraph->indicesTensor_, indices); - auto outgoingGradPlaceholder = native_mps::Placeholder(cachedGraph->outgoingGradTensor_, outgoing_gradient); - - NSDictionary *feeds = @{ - incomingGradPlaceholder.getMPSGraphTensor() : incomingGradPlaceholder.getMPSGraphTensorData(), - indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData() - }; - - NSDictionary *results = @{ - outgoingGradPlaceholder.getMPSGraphTensor() : outgoingGradPlaceholder.getMPSGraphTensorData() - }; - native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); + newCachedGraph->incomingGradTensor_ = incomingGradTensor; + newCachedGraph->indicesTensor_ = indicesTensor; + newCachedGraph->outgoingGradTensor_ = outgoingGradTensor; + } + return newCachedGraph; + }); } - return outgoing_gradient; + auto incomingGradPlaceholder = native_mps::Placeholder(cachedGraph->incomingGradTensor_, grad_); + auto indicesPlaceholder = native_mps::Placeholder(cachedGraph->indicesTensor_, indices); + auto outgoingGradPlaceholder = native_mps::Placeholder(cachedGraph->outgoingGradTensor_, outgoing_gradient); + + NSDictionary* feeds = @{ + incomingGradPlaceholder.getMPSGraphTensor() : incomingGradPlaceholder.getMPSGraphTensorData(), + indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData() + }; + + NSDictionary* results = + @{outgoingGradPlaceholder.getMPSGraphTensor() : outgoingGradPlaceholder.getMPSGraphTensorData()}; + native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } + return outgoing_gradient; } -Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Tensor & value) { - TORCH_CHECK(value.dim() == 0, "masked_fill_ only supports a 0-dimensional value tensor, but got tensor " - "with ", value.dim(), " dimension(s)."); +Tensor& masked_fill__mps(Tensor& self, const Tensor& mask, const Tensor& value) { + TORCH_CHECK(value.dim() == 0, + "masked_fill_ only supports a 0-dimensional value tensor, but got tensor " + "with ", + value.dim(), + " dimension(s)."); return masked_fill__mps(self, mask, value.item()); } -Tensor & masked_scatter__mps(Tensor& self, const Tensor& mask, const Tensor& source) { +Tensor& masked_scatter__mps(Tensor& self, const Tensor& mask, const Tensor& source) { at::assert_no_internal_overlap(self); - TORCH_CHECK( - self.scalar_type() == source.scalar_type(), - "masked_scatter: expected self and source to have same dtypes but got", - self.scalar_type(), - " and ", - source.scalar_type()); + TORCH_CHECK(self.scalar_type() == source.scalar_type(), + "masked_scatter: expected self and source to have same dtypes but got", + self.scalar_type(), + " and ", + source.scalar_type()); if (self.numel() == 0) { return self; @@ -958,25 +926,22 @@ Tensor embedding_dense_backward_mps( TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool, "masked_scatter: expected BoolTensor or ByteTensor for mask"); - auto mask_temp = (mask.dim() == 0) - ? c10::MaybeOwned::owned(mask.unsqueeze(0)) - : c10::MaybeOwned::borrowed(mask); - auto self_temp = (self.dim() == 0) - ? c10::MaybeOwned::owned(self.unsqueeze(0)) - : c10::MaybeOwned::borrowed(self); + auto mask_temp = + (mask.dim() == 0) ? c10::MaybeOwned::owned(mask.unsqueeze(0)) : c10::MaybeOwned::borrowed(mask); + auto self_temp = + (self.dim() == 0) ? c10::MaybeOwned::owned(self.unsqueeze(0)) : c10::MaybeOwned::borrowed(self); // Cannot reassign to mask_temp and self_temp here! if they are // owning and expand_outplace returns a borrow, the returned borrow // would dangle. auto mask_self_expanded = expand_outplace(*mask_temp, *self_temp); - auto indices = at::native::expandTensors( - *std::get<1>(mask_self_expanded), - c10::List>({*std::move(std::get<0>(mask_self_expanded))}) - ); + auto indices = + at::native::expandTensors(*std::get<1>(mask_self_expanded), + c10::List>({*std::move(std::get<0>(mask_self_expanded))})); // next broadcast all index tensors together try { indices = at::expand_outplace(indices); - } catch (std::exception &e) { + } catch (std::exception& e) { TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together"); } @@ -987,15 +952,10 @@ Tensor embedding_dense_backward_mps( c10::List> final_indices; final_indices.reserve(indices.size()); - for (const auto index: indices) { + for (const auto index : indices) { final_indices.push_back(index); } - return at::index_put_out( - self, - *std::get<1>(mask_self_expanded), - final_indices, - source.resize_(indices[0].numel()) - ); + return at::index_put_out(self, *std::get<1>(mask_self_expanded), final_indices, source.resize_(indices[0].numel())); } REGISTER_DISPATCH(index_stub, &index_kernel_mps); diff --git a/aten/src/ATen/native/mps/operations/Inverse.mm b/aten/src/ATen/native/mps/operations/Inverse.mm index 519de6afa3b85a..e1ee0490061293 100644 --- a/aten/src/ATen/native/mps/operations/Inverse.mm +++ b/aten/src/ATen/native/mps/operations/Inverse.mm @@ -1,90 +1,82 @@ #include -#include #include -#include +#include #include - +#include namespace at::native { -TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) -{ - TORCH_CHECK(result.is_mps(), "Output tensor is not MPS"); - if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) { - TORCH_WARN_ONCE("torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU."); - auto cpu_info = at::empty({0}, kInt, c10::nullopt, kCPU, c10::nullopt, c10::nullopt); - auto cpu_result = result.clone().to("cpu"); - at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu")); - info.copy_(cpu_info); - result.copy_(cpu_result); - return; - } - - using namespace mps; - MPSStream* stream = getCurrentMPSStream(); - info.zero_(); - - if (A.numel() == 0) { - return; - } - - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* inputTensor_ = nil; - MPSGraphTensor* outputTensor_ = nil; - }; +TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) { + TORCH_CHECK(result.is_mps(), "Output tensor is not MPS"); + if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) { + TORCH_WARN_ONCE( + "torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU."); + auto cpu_info = at::empty({0}, kInt, c10::nullopt, kCPU, c10::nullopt, c10::nullopt); + auto cpu_result = result.clone().to("cpu"); + at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu")); + info.copy_(cpu_info); + result.copy_(cpu_result); + return; + } + + using namespace mps; + MPSStream* stream = getCurrentMPSStream(); + info.zero_(); + + if (A.numel() == 0) { + return; + } + + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + }; + + Tensor output = result; + bool isContiguous = true; + if (!result.is_contiguous()) { + output = result.contiguous(); + isContiguous = false; + } + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + @autoreleasepool { + string key = "inv_out_mps" + getTensorsStringKey({A}); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, A); + MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor:inputTensor name:nil]; + + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->outputTensor_ = outputTensor; + } - Tensor output = result; - bool isContiguous = true; - if (!result.is_contiguous()) { - output = result.contiguous(); - isContiguous = false; + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); } - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - - @autoreleasepool { - string key = "inv_out_mps" + getTensorsStringKey({A}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) - { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor= mpsGraphRankedPlaceHolder(mpsGraph, A); - MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor: inputTensor - name: nil]; + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, isContiguous ? result : output); - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - } + NSDictionary* feeds = + @{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()}; - return newCachedGraph; - - }); - cachedGraph = static_cast(tmpCachedGraph); - } + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, isContiguous ? result : output); - - NSDictionary* feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData() - }; - - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - - runMPSGraph(stream, cachedGraph->graph(), feeds, results); - if (!isContiguous) { - result.copy_(output); - } + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + if (!isContiguous) { + result.copy_(output); } + } } } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/Linear.mm b/aten/src/ATen/native/mps/operations/Linear.mm index 529c26ded00249..8988e146d1e5dc 100644 --- a/aten/src/ATen/native/mps/operations/Linear.mm +++ b/aten/src/ATen/native/mps/operations/Linear.mm @@ -6,17 +6,14 @@ using namespace mps; -Tensor _mps_linear( - const Tensor& input, - const Tensor& weight_arg, - const c10::optional& bias_opt) { +Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const c10::optional& bias_opt) { // wT = transpose(weight); // y=x*wT+b auto weight = (weight_arg.dim() == 1) ? weight_arg.view({1, weight_arg.size(0)}) : weight_arg; - TORCH_CHECK(input.scalar_type() == ScalarType::Float || - input.scalar_type() == ScalarType::Half, "MPS device does not support linear for non-float inputs"); + TORCH_CHECK(input.scalar_type() == ScalarType::Float || input.scalar_type() == ScalarType::Half, + "MPS device does not support linear for non-float inputs"); const Tensor& bias = *(at::borrow_from_optional_tensor(bias_opt)); bool is_bias_defined = bias.defined(); @@ -24,24 +21,19 @@ Tensor _mps_linear( auto input_size = input.sizes(); std::vector output_size(input_size.begin(), input_size.end() - 1); output_size.push_back(weight.size(0)); - Tensor output = at::native::empty_mps(output_size, - input.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - input.suggest_memory_format()); + Tensor output = at::native::empty_mps( + output_size, input.scalar_type(), c10::nullopt, kMPS, c10::nullopt, input.suggest_memory_format()); TORCH_CHECK(output.is_mps()); - if(output.numel() == 0) { + if (output.numel() == 0) { return output; } - MPSStream *stream = getCurrentMPSStream(); + MPSStream* stream = getCurrentMPSStream(); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* weightTensor_ = nil; MPSGraphTensor* biasTensor_ = nil; @@ -51,14 +43,12 @@ Tensor _mps_linear( MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - string key = "mps_linear" + getTensorsStringKey({input, weight, bias}) ; + string key = "mps_linear" + getTensorsStringKey({input, weight, bias}); CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); @@ -71,14 +61,11 @@ Tensor _mps_linear( name:nil]; MPSGraphTensor* outputTensor = nil; - if (!is_bias_defined) - { + if (!is_bias_defined) { outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:inputTensor secondaryTensor:weightTransposeTensor name:nil]; - } - else - { + } else { MPSGraphTensor* inputFlattened = inputTensor; bool doReshape = false; // workaround to improve the performance with 3D+ inputs @@ -92,9 +79,10 @@ Tensor _mps_linear( secondaryTensor:weightTransposeTensor name:nil]; MPSGraphTensor* biasedTensor = [mpsGraph additionWithPrimaryTensor:xMulWTTensor - secondaryTensor:newCachedGraph->biasTensor_ - name:nil]; - outputTensor = doReshape ? [mpsGraph reshapeTensor:biasedTensor withShape:getMPSShape(output_size) name:nil] : biasedTensor; + secondaryTensor:newCachedGraph->biasTensor_ + name:nil]; + outputTensor = doReshape ? [mpsGraph reshapeTensor:biasedTensor withShape:getMPSShape(output_size) name:nil] + : biasedTensor; } newCachedGraph->inputTensor_ = inputTensor; @@ -110,89 +98,76 @@ Tensor _mps_linear( Placeholder biasPlaceholder = Placeholder(); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - NSMutableDictionary* feeds =[NSMutableDictionary dictionary]; - feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); + NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; + feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData(); if (is_bias_defined) { biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias); feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData(); } - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } // Shave off '1' present at the end of the shape - if(weight_arg.dim() == 1) { + if (weight_arg.dim() == 1) { // Number of elements in new output shape auto output_sizes = output.sizes(); - std::vector out_shape(output_sizes.begin(), output_sizes.end()-1); + std::vector out_shape(output_sizes.begin(), output_sizes.end() - 1); return output.view(IntArrayRef(out_shape)); } return output; } -Tensor _mps_linear_backward_input( - IntArrayRef input_size, - const Tensor & grad_output, - const Tensor & weight) -{ - TORCH_CHECK(grad_output.is_mps(), - "mps_linear_backward: grad_output needs to be mps layout"); - TORCH_CHECK(weight.device().is_mps() && - (weight.scalar_type() == kFloat || (weight.scalar_type() == kHalf)), - "mps_linear_backward: unsupported weights data type: ", weight.scalar_type()); - - TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double - || grad_output.scalar_type() == ScalarType::Float - || grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs"); +Tensor _mps_linear_backward_input(IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight) { + TORCH_CHECK(grad_output.is_mps(), "mps_linear_backward: grad_output needs to be mps layout"); + TORCH_CHECK(weight.device().is_mps() && (weight.scalar_type() == kFloat || (weight.scalar_type() == kHalf)), + "mps_linear_backward: unsupported weights data type: ", + weight.scalar_type()); + + TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double || grad_output.scalar_type() == ScalarType::Float || + grad_output.scalar_type() == ScalarType::Half, + "MPS device does not support linear backward for non-float inputs"); const Tensor weight_reshaped = weight.is_contiguous() ? weight : weight.contiguous(); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *weightTensor_ = nil; - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* weightTensor_ = nil; + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; - Tensor output = at::native::empty_mps(input_size, - grad_output.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - grad_output.suggest_memory_format()); + Tensor output = at::native::empty_mps( + input_size, grad_output.scalar_type(), c10::nullopt, kMPS, c10::nullopt, grad_output.suggest_memory_format()); TORCH_CHECK(output.is_mps()); if (grad_output.numel() == 0) { return output; } - MPSGraphCache *cache_ = MPSGraphCache::getInstance(); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - MPSStream *stream= getCurrentMPSStream(); + MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - - string key = "mps_linear_backward_input" + getTensorsStringKey({grad_output, weight_reshaped}); + string key = "mps_linear_backward_input" + getTensorsStringKey({grad_output, weight_reshaped}); CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { - MPSGraph *mpsGraph = make_mps_graph(); + MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_reshaped); - MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); + MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_reshaped); + MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor *outputTensor = - [mpsGraph matrixMultiplicationWithPrimaryTensor: gradOutputTensor - secondaryTensor: weightTensor - name: nil]; + MPSGraphTensor* outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:gradOutputTensor + secondaryTensor:weightTensor + name:nil]; newCachedGraph->weightTensor_ = weightTensor; newCachedGraph->gradOutputTensor_ = gradOutputTensor; @@ -211,9 +186,8 @@ Tensor _mps_linear_backward_input( gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); @@ -221,27 +195,27 @@ Tensor _mps_linear_backward_input( } } -std::tuple _mps_linear_backward_weights( - const Tensor& grad_output, const Tensor& input, const Tensor& weight, bool bias_defined) -{ +std::tuple _mps_linear_backward_weights(const Tensor& grad_output, + const Tensor& input, + const Tensor& weight, + bool bias_defined) { TORCH_CHECK(grad_output.is_mps() && input.is_mps(), - "_mps_linear_backward: grad_output and input needs to be mps layout"); - - TORCH_CHECK(grad_output.scalar_type() == ScalarType::Float || - grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs"); - - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *weightTensor_ = nil; - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - MPSGraphTensor *biasTensor_ = nil; + "_mps_linear_backward: grad_output and input needs to be mps layout"); + + TORCH_CHECK(grad_output.scalar_type() == ScalarType::Float || grad_output.scalar_type() == ScalarType::Half, + "MPS device does not support linear backward for non-float inputs"); + + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* weightTensor_ = nil; + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + MPSGraphTensor* biasTensor_ = nil; }; - auto grad_output_reshaped = grad_output.dim() != 2 ? - grad_output.reshape({-1, grad_output.size(grad_output.dim() - 1)}) : grad_output; + auto grad_output_reshaped = + grad_output.dim() != 2 ? grad_output.reshape({-1, grad_output.size(grad_output.dim() - 1)}) : grad_output; auto input_reshaped = input.dim() != 2 ? input.reshape({-1, input.size(input.dim() - 1)}) : input; TORCH_CHECK(grad_output_reshaped.is_mps()); @@ -254,59 +228,52 @@ Tensor _mps_linear_backward_input( c10::nullopt, grad_output.suggest_memory_format()); Tensor bias = at::native::empty_mps({grad_output_reshaped.size(1)}, - grad_output.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - grad_output.suggest_memory_format()); + grad_output.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + grad_output.suggest_memory_format()); TORCH_CHECK(output.is_mps()); TORCH_CHECK(bias.is_mps()); if (grad_output.numel() == 0) { output.zero_(); bias.zero_(); - return std::tuple{ output, bias }; + return std::tuple{output, bias}; } - MPSGraphCache *cache_ = MPSGraphCache::getInstance(); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - MPSStream *stream= getCurrentMPSStream(); + MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - - string key = "mps_linear_backward_weights:" + to_string(bias_defined) + ":" + - getTensorsStringKey({input_reshaped, weight, grad_output_reshaped}); + string key = "mps_linear_backward_weights:" + to_string(bias_defined) + ":" + + getTensorsStringKey({input_reshaped, weight, grad_output_reshaped}); CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { - MPSGraph *mpsGraph = make_mps_graph(); + MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped); - MPSGraphTensor *weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight); - MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output_reshaped); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped); + MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight); + MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output_reshaped); - MPSGraphTensor *gradOutputTransposeTensor = - [mpsGraph transposeTensor: gradOutputTensor - dimension: -1 - withDimension: -2 - name: nil]; + MPSGraphTensor* gradOutputTransposeTensor = [mpsGraph transposeTensor:gradOutputTensor + dimension:-1 + withDimension:-2 + name:nil]; // grad_weight - MPSGraphTensor *outputTensor = - [mpsGraph matrixMultiplicationWithPrimaryTensor: gradOutputTransposeTensor - secondaryTensor: inputTensor - name: nil]; - MPSGraphTensor *biasTensor = nil; - if (bias_defined) - { - // grad_bias - biasTensor = [mpsGraph reductionSumWithTensor: gradOutputTensor - axis: 0 - name: nil]; - + MPSGraphTensor* outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:gradOutputTransposeTensor + secondaryTensor:inputTensor + name:nil]; + MPSGraphTensor* biasTensor = nil; + if (bias_defined) { + // grad_bias + biasTensor = [mpsGraph reductionSumWithTensor:gradOutputTensor axis:0 name:nil]; } newCachedGraph->inputTensor_ = inputTensor; @@ -338,14 +305,14 @@ Tensor _mps_linear_backward_input( runMPSGraph(stream, cachedGraph->graph(), feeds, results); - return std::tuple{ output, bias }; + return std::tuple{output, bias}; } } - -std::tuple mps_linear_backward( - const Tensor& input, const Tensor& grad_output, - const Tensor& weight, std::array output_mask) { +std::tuple mps_linear_backward(const Tensor& input, + const Tensor& grad_output, + const Tensor& weight, + std::array output_mask) { Tensor grad_input, grad_weight, grad_bias; if (output_mask[0]) { grad_input = _mps_linear_backward_input(input.sizes(), grad_output, weight); diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index 1fe50ad582b20f..20d3e9d4877f7e 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -1,8 +1,8 @@ // Copyright © 2022 Apple Inc. -#include #include #include +#include namespace at::native { @@ -12,22 +12,21 @@ */ static Tensor prepare_batch_matrix_by_transposing(const Tensor& tensor, - bool& transpose_tensor, - int64_t& ld_tensor, - bool transpose_result, - int64_t m, int64_t n) { + bool& transpose_tensor, + int64_t& ld_tensor, + bool transpose_result, + int64_t m, + int64_t n) { IntArrayRef tensor_strides = tensor.strides(); Tensor tensor_; int fast_dim = transpose_result ? 2 : 1; int leading_dim = transpose_result ? 1 : 2; - if (tensor_strides[fast_dim] == 1 && - (tensor_strides[leading_dim] >= std::max(1, m))) { + if (tensor_strides[fast_dim] == 1 && (tensor_strides[leading_dim] >= std::max(1, m))) { transpose_tensor = false; tensor_ = tensor; ld_tensor = tensor_strides[leading_dim]; - } else if ((tensor_strides[leading_dim] == 1) && - (tensor_strides[fast_dim] >= std::max(1, n))) { + } else if ((tensor_strides[leading_dim] == 1) && (tensor_strides[fast_dim] >= std::max(1, n))) { transpose_tensor = true; tensor_ = tensor; ld_tensor = tensor_strides[fast_dim]; @@ -50,14 +49,13 @@ static Tensor prepare_batch_matrix_by_transposing(const Tensor& tensor, * Helper functions to be used for mm/addmm for detecting the Transpositions * when doing GEMM operations. */ -void prepare_matrices_for_broadcasting( - const Tensor * bias, - const Tensor & self, - const Tensor & other, - const Scalar * beta, - bool * transpose_mat1_times_mat2, - bool & transpose_mat1, - bool & transpose_mat2) { +void prepare_matrices_for_broadcasting(const Tensor* bias, + const Tensor& self, + const Tensor& other, + const Scalar* beta, + bool* transpose_mat1_times_mat2, + bool& transpose_mat1, + bool& transpose_mat2) { TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D"); if (bias && beta->toDouble() != 0.0f) { TORCH_CHECK(bias->dim() == 2, "tensors must be 2-D"); @@ -79,20 +77,14 @@ void prepare_matrices_for_broadcasting( } } -enum LinearAlgebraOpType { - ADDBMM_OP_TYPE, - BADDBMM_OP_TYPE -}; +enum LinearAlgebraOpType { ADDBMM_OP_TYPE, BADDBMM_OP_TYPE }; -Tensor& mm_out_mps_impl( - const Tensor& self, - const Tensor& other, - Tensor& output) { +Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) { using namespace mps; TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK(self.scalar_type() == ScalarType::Double - || self.scalar_type() == ScalarType::Float - || self.scalar_type() == ScalarType::Half, "MPS device does not support mm for non-float inputs"); + TORCH_CHECK(self.scalar_type() == ScalarType::Double || self.scalar_type() == ScalarType::Float || + self.scalar_type() == ScalarType::Half, + "MPS device does not support mm for non-float inputs"); TensorArg args[]{{output, "out", 0}, {self, "mat1", 1}, {other, "mat2", 2}}; checkAllSameGPU("mm", args); @@ -105,47 +97,41 @@ void prepare_matrices_for_broadcasting( return output; } - struct CachedGraph : public mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *selfTensor_ = nil; - MPSGraphTensor *otherTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + struct CachedGraph : public mps::MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* selfTensor_ = nil; + MPSGraphTensor* otherTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; MPSStream* stream = getCurrentMPSStream(); - mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance(); + mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance(); @autoreleasepool { - string key = "mm_out_mps_impl" + getTensorsStringKey({self, other}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - - mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; - @autoreleasepool{ - MPSGraph *mpsGraph = mps::make_mps_graph(); + @autoreleasepool { + MPSGraph* mpsGraph = mps::make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *selfTensor = nil; - MPSGraphTensor *otherTensor = nil; - MPSGraphTensor *outputTensor = nil; - - if(self.numel() == 0 || other.numel() == 0) { + MPSGraphTensor* selfTensor = nil; + MPSGraphTensor* otherTensor = nil; + MPSGraphTensor* outputTensor = nil; + if (self.numel() == 0 || other.numel() == 0) { outputTensor = [mpsGraph constantWithScalar:0. shape:getMPSShape(output_sizes) - dataType:getMPSDataType(output)]; - - } - else { + dataType:getMPSDataType(output)]; + } else { selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self); - otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other); + otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other); outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfTensor secondaryTensor:otherTensor name:nil]; @@ -157,11 +143,11 @@ void prepare_matrices_for_broadcasting( } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(); Placeholder otherPlaceholder = Placeholder(); - if(!(self.numel() == 0 || other.numel() == 0)) { + if (!(self.numel() == 0 || other.numel() == 0)) { selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self); otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other); } @@ -169,15 +155,14 @@ void prepare_matrices_for_broadcasting( NSDictionary* feeds = nil; - if(!(self.numel() == 0 || other.numel() == 0)) + if (!(self.numel() == 0 || other.numel() == 0)) feeds = @{ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); } @@ -185,26 +170,25 @@ void prepare_matrices_for_broadcasting( return output; } - -Tensor addr_mps(const Tensor& self, - const Tensor& vec1, const Tensor& vec2, - const Scalar& beta, const Scalar& alpha) { +Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) { Tensor result = at::empty({0}, self.options()); - addr_out_mps(self, vec1,vec2,beta,alpha,result); + addr_out_mps(self, vec1, vec2, beta, alpha, result); return result; } - Tensor& addr_out_mps(const Tensor& self, - const Tensor& vec1, const Tensor& vec2, - const Scalar& beta, const Scalar& alpha, Tensor &result) { + const Tensor& vec1, + const Tensor& vec2, + const Scalar& beta, + const Scalar& alpha, + Tensor& result) { using namespace mps; TORCH_CHECK(result.is_mps()); TORCH_CHECK(vec1.dim() == 1 && vec2.dim() == 1, "tensors must be 1-D"); - TORCH_CHECK(vec1.scalar_type() == ScalarType::Double - || vec1.scalar_type() == ScalarType::Float - || vec1.scalar_type() == ScalarType::Half, "MPS device does not support addr for non-float input"); + TORCH_CHECK(vec1.scalar_type() == ScalarType::Double || vec1.scalar_type() == ScalarType::Float || + vec1.scalar_type() == ScalarType::Half, + "MPS device does not support addr for non-float input"); TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {vec1, "vec1", 2}, {vec2, "vec2", 3}}; checkAllSameGPU(__func__, args); @@ -239,37 +223,34 @@ Tensor addr_mps(const Tensor& self, MPSStream* stream = getCurrentMPSStream(); bool is_beta_non_zero = beta.toDouble() != 0.0; - MPSShape* inputShape = @[@(vec1.numel()), @(1)]; - MPSShape* otherShape = @[@(1), @(vec2.numel())]; - - struct CachedGraph : public mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *vec1Tensor_ = nil; - MPSGraphTensor *vec2Tensor_ = nil; - MPSGraphTensor *selfTensor_ = nil; - MPSGraphTensor *resultTensor_ = nil; + MPSShape* inputShape = @[ @(vec1.numel()), @(1) ]; + MPSShape* otherShape = @[ @(1), @(vec2.numel()) ]; + + struct CachedGraph : public mps::MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* vec1Tensor_ = nil; + MPSGraphTensor* vec2Tensor_ = nil; + MPSGraphTensor* selfTensor_ = nil; + MPSGraphTensor* resultTensor_ = nil; }; - mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance(); + mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance(); @autoreleasepool { - string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) - + ":" + to_string(beta.toDouble()) - + ":" + to_string(alpha.toDouble()); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - - mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; - - @autoreleasepool{ - MPSGraph *mpsGraph = mps::make_mps_graph(); + string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" + to_string(beta.toDouble()) + + ":" + to_string(alpha.toDouble()); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = mps::make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1), inputShape); - MPSGraphTensor *t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2), otherShape); - MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *self_); + MPSGraphTensor* t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1), inputShape); + MPSGraphTensor* t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2), otherShape); + MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *self_); // Intermediate as placeholder MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:t1 @@ -280,7 +261,7 @@ Tensor addr_mps(const Tensor& self, MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta.toDouble() dataType:getMPSScalarType((*self_).scalar_type())]; MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.toDouble() - dataType:getMPSScalarType(vec1.scalar_type())]; + dataType:getMPSScalarType(vec1.scalar_type())]; // Intermediates for multiplying by beta and alpha MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor:productTensor @@ -298,7 +279,7 @@ Tensor addr_mps(const Tensor& self, resultTensor = [mpsGraph additionWithPrimaryTensor:productTimesAlphaTensor secondaryTensor:selfTimesBetaTensor name:@"MM/beta*input+alpha*(vec1@vec2)"]; - } + } newCachedGraph->vec1Tensor_ = t1; newCachedGraph->vec2Tensor_ = t2; @@ -307,7 +288,7 @@ Tensor addr_mps(const Tensor& self, } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder vec1Placeholder = Placeholder(cachedGraph->vec1Tensor_, vec1, inputShape); @@ -321,9 +302,8 @@ Tensor addr_mps(const Tensor& self, selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData()}; mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); } @@ -331,20 +311,19 @@ Tensor addr_mps(const Tensor& self, return result; } -Tensor& addmm_out_mps_impl( - const Tensor& bias, - const Tensor& self, // input - const Tensor& other, // weight - const Scalar& beta, - const Scalar& alpha, - Tensor& output) { +Tensor& addmm_out_mps_impl(const Tensor& bias, + const Tensor& self, // input + const Tensor& other, // weight + const Scalar& beta, + const Scalar& alpha, + Tensor& output) { using namespace mps; TORCH_CHECK(output.is_mps()); TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK(self.scalar_type() == ScalarType::Double - || self.scalar_type() == ScalarType::Float - || self.scalar_type() == ScalarType::Half, "MPS device does not support addmm for non-float input"); + TORCH_CHECK(self.scalar_type() == ScalarType::Double || self.scalar_type() == ScalarType::Float || + self.scalar_type() == ScalarType::Half, + "MPS device does not support addmm for non-float input"); TensorArg args[]{{output, "out", 0}, {bias, "self", 1}, {self, "mat1", 2}, {other, "mat2", 3}}; checkAllSameGPU(__func__, args); @@ -378,62 +357,52 @@ Tensor addr_mps(const Tensor& self, MPSStream* stream = getCurrentMPSStream(); bool transpose_mat1_times_mat2 = false; - bool transpose_mat1 = false; - bool transpose_mat2 = false; - bool is_beta_non_zero = beta.toDouble() != 0.0; - - prepare_matrices_for_broadcasting(&(*bias_), self, other, &beta, &transpose_mat1_times_mat2, transpose_mat1, transpose_mat2); - - struct CachedGraph : public mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *selfTensor_ = nil; - MPSGraphTensor *otherTensor_ = nil; - MPSGraphTensor *biasTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + bool transpose_mat1 = false; + bool transpose_mat2 = false; + bool is_beta_non_zero = beta.toDouble() != 0.0; + + prepare_matrices_for_broadcasting( + &(*bias_), self, other, &beta, &transpose_mat1_times_mat2, transpose_mat1, transpose_mat2); + + struct CachedGraph : public mps::MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* selfTensor_ = nil; + MPSGraphTensor* otherTensor_ = nil; + MPSGraphTensor* biasTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; - mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance(); + mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance(); @autoreleasepool { - string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) - + ":" + to_string(transpose_mat1) + ":" + to_string(transpose_mat2) - + ":" + to_string(beta.toDouble()) - + ":" + to_string(alpha.toDouble()); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - - mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; - - @autoreleasepool{ - MPSGraph *mpsGraph = mps::make_mps_graph(); + string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" + to_string(transpose_mat1) + + ":" + to_string(transpose_mat2) + ":" + to_string(beta.toDouble()) + ":" + to_string(alpha.toDouble()); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = mps::make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor *otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other); - MPSGraphTensor *biasTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *bias_); + MPSGraphTensor* selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other); + MPSGraphTensor* biasTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *bias_); MPSGraphTensor* t1 = nil; MPSGraphTensor* t2 = nil; - if(transpose_mat1) - t1 = [mpsGraph transposeTensor:selfTensor - dimension:-1 - withDimension:-2 - name:nil]; + if (transpose_mat1) + t1 = [mpsGraph transposeTensor:selfTensor dimension:-1 withDimension:-2 name:nil]; else t1 = selfTensor; - if(transpose_mat2) - t2 = [mpsGraph transposeTensor:otherTensor - dimension:-1 - withDimension:-2 - name:nil]; + if (transpose_mat2) + t2 = [mpsGraph transposeTensor:otherTensor dimension:-1 withDimension:-2 name:nil]; else t2 = otherTensor; - // TODO: Use alpha and beta here with fill_.Scalar and mul // Intermediate as placeholder MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:t1 @@ -444,7 +413,7 @@ Tensor addr_mps(const Tensor& self, MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta.toDouble() dataType:getMPSScalarType((*bias_).scalar_type())]; MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.toDouble() - dataType:getMPSScalarType(self.scalar_type())]; + dataType:getMPSScalarType(self.scalar_type())]; // Intermediates for multiplying by beta and alpha MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor:productTensor @@ -458,17 +427,14 @@ Tensor addr_mps(const Tensor& self, } if (transpose_mat1_times_mat2) - biasTimesBetaTensor = [mpsGraph transposeTensor: biasTimesBetaTensor - dimension: -1 - withDimension: -2 - name: nil]; + biasTimesBetaTensor = [mpsGraph transposeTensor:biasTimesBetaTensor dimension:-1 withDimension:-2 name:nil]; MPSGraphTensor* outputTensor = productTimesAlphaTensor; if (is_beta_non_zero) { outputTensor = [mpsGraph additionWithPrimaryTensor:productTimesAlphaTensor secondaryTensor:biasTimesBetaTensor name:@"MM/beta*input + alpha*(mat1@mat2)"]; - } + } newCachedGraph->selfTensor_ = selfTensor; newCachedGraph->otherTensor_ = otherTensor; @@ -477,7 +443,7 @@ Tensor addr_mps(const Tensor& self, } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self); @@ -491,9 +457,8 @@ Tensor addr_mps(const Tensor& self, biasPlaceholder.getMPSGraphTensor() : biasPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); } @@ -501,16 +466,12 @@ Tensor addr_mps(const Tensor& self, return output; } - -Tensor& bmm_out_mps_impl( - const Tensor & batch1, - const Tensor & batch2, - Tensor & result) { +Tensor& bmm_out_mps_impl(const Tensor& batch1, const Tensor& batch2, Tensor& result) { using namespace mps; - TORCH_CHECK(batch1.scalar_type() == ScalarType::Double - || batch1.scalar_type() == ScalarType::Float - || batch1.scalar_type() == ScalarType::Half, "MPS device does not support bmm for non-float inputs"); + TORCH_CHECK(batch1.scalar_type() == ScalarType::Double || batch1.scalar_type() == ScalarType::Float || + batch1.scalar_type() == ScalarType::Half, + "MPS device does not support bmm for non-float inputs"); if (batch1.numel() == 0 || batch2.numel() == 0) { result.zero_(); @@ -519,31 +480,29 @@ Tensor addr_mps(const Tensor& self, MPSStream* stream = getCurrentMPSStream(); - struct CachedGraph : public mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *batch1Tensor_ = nil; - MPSGraphTensor *batch2Tensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + struct CachedGraph : public mps::MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* batch1Tensor_ = nil; + MPSGraphTensor* batch2Tensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; - mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance(); + mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance(); @autoreleasepool { string key = "bmm_out_mps_impl" + getTensorsStringKey({batch1, batch2}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; - mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; - - @autoreleasepool{ - MPSGraph *mpsGraph = mps::make_mps_graph(); + @autoreleasepool { + MPSGraph* mpsGraph = mps::make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *batch1Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch1); - MPSGraphTensor *batch2Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch2); + MPSGraphTensor* batch1Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch1); + MPSGraphTensor* batch2Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch2); MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:batch1Tensor secondaryTensor:batch2Tensor @@ -555,7 +514,7 @@ Tensor addr_mps(const Tensor& self, } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder batch1Placeholder = Placeholder(cachedGraph->batch1Tensor_, batch1); Placeholder batch2Placeholder = Placeholder(cachedGraph->batch2Tensor_, batch2); @@ -566,9 +525,8 @@ Tensor addr_mps(const Tensor& self, batch2Placeholder.getMPSGraphTensor() : batch2Placeholder.getMPSGraphTensorData(), }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); } @@ -576,14 +534,13 @@ Tensor addr_mps(const Tensor& self, return result; } -Tensor& addbmm_or_baddbmm_out_mps_impl( - const Tensor & input, - const Tensor & batch1, - const Tensor & batch2, - const Scalar & beta, - const Scalar & alpha, - Tensor & result, - LinearAlgebraOpType opType) { +Tensor& addbmm_or_baddbmm_out_mps_impl(const Tensor& input, + const Tensor& batch1, + const Tensor& batch2, + const Scalar& beta, + const Scalar& alpha, + Tensor& result, + LinearAlgebraOpType opType) { using namespace mps; TORCH_CHECK(input.is_mps()); @@ -591,22 +548,29 @@ Tensor addr_mps(const Tensor& self, TORCH_CHECK(batch2.is_mps()); TORCH_CHECK(result.is_mps()); - TORCH_CHECK(batch1.scalar_type() == ScalarType::Double - || batch1.scalar_type() == ScalarType::Float - || batch1.scalar_type() == ScalarType::Half, "MPS device does not support addbmm or baddbmm for non-float inputs"); + TORCH_CHECK(batch1.scalar_type() == ScalarType::Double || batch1.scalar_type() == ScalarType::Float || + batch1.scalar_type() == ScalarType::Half, + "MPS device does not support addbmm or baddbmm for non-float inputs"); TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor"); TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor"); TORCH_CHECK(batch1.size(0) == batch2.size(0), - "batch1 and batch2 must have same number of batches, got ", - batch1.size(0), " and ", batch2.size(0)); + "batch1 and batch2 must have same number of batches, got ", + batch1.size(0), + " and ", + batch2.size(0)); TORCH_CHECK(batch1.size(2) == batch2.size(1), - "Incompatible matrix sizes for bmm (", - batch1.size(1), "x", batch1.size(2), " and ", - batch2.size(1), "x", batch2.size(2), ")"); - - if (opType == ADDBMM_OP_TYPE) - { + "Incompatible matrix sizes for bmm (", + batch1.size(1), + "x", + batch1.size(2), + " and ", + batch2.size(1), + "x", + batch2.size(2), + ")"); + + if (opType == ADDBMM_OP_TYPE) { result.resize_as_(input); const int64_t num_batches = batch1.size(0); @@ -619,42 +583,39 @@ Tensor addr_mps(const Tensor& self, MPSStream* stream = getCurrentMPSStream(); - struct CachedGraph : public mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *batch1Tensor_ = nil; - MPSGraphTensor *batch2Tensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + struct CachedGraph : public mps::MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* batch1Tensor_ = nil; + MPSGraphTensor* batch2Tensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; - mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance(); + mps::MPSGraphCache* cache_ = mps::MPSGraphCache::getInstance(); @autoreleasepool { string key = (opType == ADDBMM_OP_TYPE) ? ("addbmm_out_mps_impl") : ("baddbmm_out_mps_impl"); - key += getTensorsStringKey({batch1, batch2, input}) - + ":" + to_string(beta.toDouble()) - + ":" + to_string(alpha.toDouble()); - - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { + key += getTensorsStringKey({batch1, batch2, input}) + ":" + to_string(beta.toDouble()) + ":" + + to_string(alpha.toDouble()); - mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^mps::MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; - @autoreleasepool{ - MPSGraph *mpsGraph = mps::make_mps_graph(); + @autoreleasepool { + MPSGraph* mpsGraph = mps::make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *inputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, input); - MPSGraphTensor *batch1Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch1); - MPSGraphTensor *batch2Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch2); + MPSGraphTensor* inputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, input); + MPSGraphTensor* batch1Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch1); + MPSGraphTensor* batch2Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch2); // Intermediates for beta and alpha - MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar: beta.toDouble() - dataType: getMPSScalarType(input.scalar_type())]; - MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar: alpha.toDouble() - dataType: getMPSScalarType(batch1.scalar_type())]; + MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta.toDouble() + dataType:getMPSScalarType(input.scalar_type())]; + MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.toDouble() + dataType:getMPSScalarType(batch1.scalar_type())]; MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:batch1Tensor secondaryTensor:batch2Tensor @@ -662,46 +623,46 @@ Tensor addr_mps(const Tensor& self, MPSGraphTensor* reductionSumTensor = productTensor; if (opType == ADDBMM_OP_TYPE) { - reductionSumTensor = [mpsGraph reductionSumWithTensor: productTensor - axis: 0 - name: @"reductionSum(batch1@batch2)"]; + reductionSumTensor = [mpsGraph reductionSumWithTensor:productTensor + axis:0 + name:@"reductionSum(batch1@batch2)"]; } // Intermediates for multiplying by beta and alpha - MPSGraphTensor* reductionSumTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor: reductionSumTensor - secondaryTensor: alphaTensor - name: @"alpha*(batch1@batch2)"]; - MPSGraphTensor* biasTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor: inputTensor - secondaryTensor: betaTensor - name: @"beta*input"]; + MPSGraphTensor* reductionSumTimesAlphaTensor = + [mpsGraph multiplicationWithPrimaryTensor:reductionSumTensor + secondaryTensor:alphaTensor + name:@"alpha*(batch1@batch2)"]; + MPSGraphTensor* biasTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor + secondaryTensor:betaTensor + name:@"beta*input"]; MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor:reductionSumTimesAlphaTensor secondaryTensor:biasTimesBetaTensor name:@"beta*input + alpha*(batch1@batch2)"]; - newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->batch1Tensor_ = batch1Tensor; newCachedGraph->batch2Tensor_ = batch2Tensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); Placeholder batch1Placeholder = Placeholder(cachedGraph->batch1Tensor_, batch1); Placeholder batch2Placeholder = Placeholder(cachedGraph->batch2Tensor_, batch2); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); NSDictionary* feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), batch1Placeholder.getMPSGraphTensor() : batch1Placeholder.getMPSGraphTensorData(), batch2Placeholder.getMPSGraphTensor() : batch2Placeholder.getMPSGraphTensorData(), }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); } @@ -713,40 +674,67 @@ Tensor addr_mps(const Tensor& self, mm_out_mps_impl(self, mat2, const_cast(result)); } -TORCH_IMPL_FUNC(addmm_out_mps)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, const Tensor& result) { +TORCH_IMPL_FUNC(addmm_out_mps) +(const Tensor& self, + const Tensor& mat1, + const Tensor& mat2, + const Scalar& beta, + const Scalar& alpha, + const Tensor& result) { addmm_out_mps_impl(self, mat1, mat2, beta, alpha, const_cast(result)); } -TORCH_IMPL_FUNC(bmm_out_mps) (const Tensor & batch1, const Tensor & batch2, const Tensor & result) { +TORCH_IMPL_FUNC(bmm_out_mps)(const Tensor& batch1, const Tensor& batch2, const Tensor& result) { bmm_out_mps_impl(batch1, batch2, const_cast(result)); } -TORCH_IMPL_FUNC(baddbmm_out_mps) (const Tensor & self, const Tensor & batch1, const Tensor & batch2, const Scalar& beta, const Scalar& alpha, const Tensor& result) { +TORCH_IMPL_FUNC(baddbmm_out_mps) +(const Tensor& self, + const Tensor& batch1, + const Tensor& batch2, + const Scalar& beta, + const Scalar& alpha, + const Tensor& result) { addbmm_or_baddbmm_out_mps_impl(self, batch1, batch2, beta, alpha, const_cast(result), BADDBMM_OP_TYPE); } -Tensor& addbmm_out_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, Tensor& result) { +Tensor& addbmm_out_mps(const Tensor& self, + const Tensor& batch1, + const Tensor& batch2, + const Scalar& beta, + const Scalar& alpha, + Tensor& result) { auto b_self = expand_size(self, {batch1.size(1), batch2.size(2)}, "addbmm_out"); addbmm_or_baddbmm_out_mps_impl(*b_self, batch1, batch2, beta, alpha, result, ADDBMM_OP_TYPE); return result; } -Tensor addbmm_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) { +Tensor addbmm_mps(const Tensor& self, + const Tensor& batch1, + const Tensor& batch2, + const Scalar& beta, + const Scalar& alpha) { Tensor result = at::empty({0}, self.options()); return addbmm_out_mps(self, batch1, batch2, beta, alpha, result); } -Tensor &addbmm_mps_(Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) { +Tensor& addbmm_mps_(Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) { return addbmm_out_mps(self, batch1, batch2, beta, alpha, self); } -Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool upper, bool transpose, bool left, bool unitriangular, Tensor& out) { +Tensor& linalg_solve_triangular_mps_impl(const Tensor& A, + const Tensor& B, + bool upper, + bool transpose, + bool left, + bool unitriangular, + Tensor& out) { using namespace mps; checkInputsSolver(A, B, left, "linalg.solve_triangular"); Tensor A_t, B_t; - std::tie(B_t, A_t) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/nullptr); + std::tie(B_t, A_t) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/ nullptr); at::native::resize_output(out, B_t.sizes()); if (A.numel() == 0 || B.numel() == 0 || out.numel() == 0) { @@ -768,7 +756,7 @@ Tensor addbmm_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2 MPSStream* mpsStream = getCurrentMPSStream(); id device = MPSDevice::getInstance()->device(); - dispatch_sync(mpsStream->queue(), ^(){ + dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { id commandBuffer = mpsStream->commandBuffer(); uint64_t batchSize = A_.sizes().size() > 2 ? A_.size(0) : 1; @@ -779,7 +767,7 @@ Tensor addbmm_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2 uint64_t aElemSize = A_.element_size(); uint64_t bElemSize = B_.element_size(); - MPSMatrixSolveTriangular *filter = [[[MPSMatrixSolveTriangular alloc] initWithDevice:device + MPSMatrixSolveTriangular* filter = [[[MPSMatrixSolveTriangular alloc] initWithDevice:device right:!left upper:upper transpose:transpose @@ -794,22 +782,24 @@ Tensor addbmm_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2 rowBytes:aCols * aElemSize matrixBytes:aRows * aCols * aElemSize dataType:getMPSDataType(A_)]; - MPSMatrixDescriptor* rightHandSideMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:bRows - columns:bCols - matrices:batchSize - rowBytes:bCols * bElemSize - matrixBytes:bRows * bCols * bElemSize - dataType:getMPSDataType(B_)]; - for (const auto i: c10::irange(batchSize)) { + MPSMatrixDescriptor* rightHandSideMatrixDesc = + [MPSMatrixDescriptor matrixDescriptorWithRows:bRows + columns:bCols + matrices:batchSize + rowBytes:bCols * bElemSize + matrixBytes:bRows * bCols * bElemSize + dataType:getMPSDataType(B_)]; + for (const auto i : c10::irange(batchSize)) { const uint64_t aBatchOffset = i * aRows * aCols; const uint64_t bBatchOffset = i * bRows * bCols; MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer offset:(A_t.storage_offset() + aBatchOffset) * aElemSize descriptor:sourceMatrixDesc] autorelease]; - MPSMatrix* rightHandSideMatrix = [[[MPSMatrix alloc] initWithBuffer:bBuffer - offset:(B_t.storage_offset() + bBatchOffset) * bElemSize - descriptor:rightHandSideMatrixDesc] autorelease]; - MPSMatrix *solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:outBuffer + MPSMatrix* rightHandSideMatrix = + [[[MPSMatrix alloc] initWithBuffer:bBuffer + offset:(B_t.storage_offset() + bBatchOffset) * bElemSize + descriptor:rightHandSideMatrixDesc] autorelease]; + MPSMatrix* solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:outBuffer offset:(out.storage_offset() + bBatchOffset) * bElemSize descriptor:rightHandSideMatrixDesc] autorelease]; @@ -824,7 +814,12 @@ Tensor addbmm_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2 return out; } -Tensor& linalg_solve_triangular_mps_out( const Tensor& A, const Tensor& B, bool upper, bool left, bool unitriangular, Tensor& out) { +Tensor& linalg_solve_triangular_mps_out(const Tensor& A, + const Tensor& B, + bool upper, + bool left, + bool unitriangular, + Tensor& out) { return linalg_solve_triangular_mps_impl(A, B, upper, /*transpose=*/false, left, unitriangular, out); } @@ -834,7 +829,14 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper, return out; } -TORCH_IMPL_FUNC(triangular_solve_mps_out)(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular, const Tensor& result, const Tensor& clone_A) { +TORCH_IMPL_FUNC(triangular_solve_mps_out) +(const Tensor& self, + const Tensor& A, + bool upper, + bool transpose, + bool unitriangular, + const Tensor& result, + const Tensor& clone_A) { clone_A.copy_(A); Tensor out = empty_mps({0}, A.scalar_type(), c10::nullopt, kMPS, c10::nullopt, MemoryFormat::Contiguous); linalg_solve_triangular_mps_impl(A, self, upper, transpose, /*left=*/true, unitriangular, out); diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm index ad21c1a7fd03f5..1b86f4e11defa2 100644 --- a/aten/src/ATen/native/mps/operations/LossOps.mm +++ b/aten/src/ATen/native/mps/operations/LossOps.mm @@ -5,717 +5,670 @@ namespace at::native { namespace mps { -string reductionToString(int64_t reduction) -{ - switch(reduction) { - case Reduction::Mean: return "Mean"; - case Reduction::Sum: return "Sum"; - default: return "None"; - } +string reductionToString(int64_t reduction) { + switch (reduction) { + case Reduction::Mean: + return "Mean"; + case Reduction::Sum: + return "Sum"; + default: + return "None"; + } } -MPSGraphTensor* reduceTensor(MPSGraphTensor *tensor, int64_t reduction, MPSGraph *mpsGraph, NSUInteger axesCount) -{ - NSMutableArray *axes = [NSMutableArray arrayWithCapacity:axesCount]; - for (NSUInteger i = 0; i < axesCount; i++) axes[i] = @(i); - - switch(reduction) { - case Reduction::Mean: - return [mpsGraph meanOfTensor: tensor axes: axes name: @"reductionMeanTensor"]; - case Reduction::Sum: - return [mpsGraph reductionSumWithTensor: tensor axes: axes name: @"reductionSumTensor"]; - default: - assert(reduction == Reduction::None); - return tensor; - } +MPSGraphTensor* reduceTensor(MPSGraphTensor* tensor, int64_t reduction, MPSGraph* mpsGraph, NSUInteger axesCount) { + NSMutableArray* axes = [NSMutableArray arrayWithCapacity:axesCount]; + for (NSUInteger i = 0; i < axesCount; i++) + axes[i] = @(i); + + switch (reduction) { + case Reduction::Mean: + return [mpsGraph meanOfTensor:tensor axes:axes name:@"reductionMeanTensor"]; + case Reduction::Sum: + return [mpsGraph reductionSumWithTensor:tensor axes:axes name:@"reductionSumTensor"]; + default: + assert(reduction == Reduction::None); + return tensor; + } } -Tensor& mse_loss_backward_out_impl(const Tensor& grad_output, const Tensor& input, const Tensor& target, - int64_t reduction, Tensor& grad_input, const string op_name) -{ - TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes") - auto norm = reduction == Reduction::Mean ? 2. / static_cast(input.numel()) : 2.; - - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor = nil, *targetTensor = nil; - MPSGraphTensor *gradInputTensor = nil, *gradOutputTensor = nil; - }; - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); +Tensor& mse_loss_backward_out_impl(const Tensor& grad_output, + const Tensor& input, + const Tensor& target, + int64_t reduction, + Tensor& grad_input, + const string op_name) { + TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes") + auto norm = reduction == Reduction::Mean ? 2. / static_cast(input.numel()) : 2.; + + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *inputTensor = nil, *targetTensor = nil; + MPSGraphTensor *gradInputTensor = nil, *gradOutputTensor = nil; + }; + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + @autoreleasepool { + string key = op_name + reductionToString(reduction) + ":" + to_string(grad_input.sizes()[1]) + + getTensorsStringKey({input, target, grad_output}); + + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; - @autoreleasepool { - string key = op_name + reductionToString(reduction) + ":" + - to_string(grad_input.sizes()[1]) + - getTensorsStringKey({input, target, grad_output}); - - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); - newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target); - newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - - MPSGraphTensor *normTensor = [mpsGraph constantWithScalar: norm - dataType: MPSDataTypeFloat32]; - MPSGraphTensor *diffTensor = [mpsGraph subtractionWithPrimaryTensor: newCachedGraph->inputTensor - secondaryTensor: newCachedGraph->targetTensor - name: nil]; - MPSGraphTensor *diffGradientTensor = [mpsGraph multiplicationWithPrimaryTensor: diffTensor - secondaryTensor: newCachedGraph->gradOutputTensor - name: nil]; - newCachedGraph->gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor: diffGradientTensor - secondaryTensor: normTensor - name: nil]; - } - return newCachedGraph; - })); + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); + newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target); + newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); + + MPSGraphTensor* normTensor = [mpsGraph constantWithScalar:norm dataType:MPSDataTypeFloat32]; + MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:newCachedGraph->inputTensor + secondaryTensor:newCachedGraph->targetTensor + name:nil]; + MPSGraphTensor* diffGradientTensor = + [mpsGraph multiplicationWithPrimaryTensor:diffTensor + secondaryTensor:newCachedGraph->gradOutputTensor + name:nil]; + newCachedGraph->gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:diffGradientTensor + secondaryTensor:normTensor + name:nil]; } - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input); - Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target); - Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor, grad_input); - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor, grad_output); - - NSDictionary* feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData(), - gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() :gradInputPlaceholder.getMPSGraphTensorData() - }; - - runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + return newCachedGraph; + })); } + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input); + Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target); + Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor, grad_input); + Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor, grad_output); - return grad_input; + NSDictionary* feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData(), + gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData() + }; + NSDictionary* results = + @{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()}; + + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + } + + return grad_input; } // namespace to localize the CachedGraph struct for Binary Cross Entropy -namespace BCELoss -{ - -struct CachedGraph : public MPSCachedGraph -{ - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor = nil, *targetTensor = nil; - // gradOutput only used on backward pass - MPSGraphTensor *weightTensor = nil, *gradOutputTensor = nil; - // lossTensor used for forward, and gradInputTensor for backward pass - union { MPSGraphTensor *lossTensor = nil; MPSGraphTensor *gradInputTensor; }; +namespace BCELoss { + +struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *inputTensor = nil, *targetTensor = nil; + // gradOutput only used on backward pass + MPSGraphTensor *weightTensor = nil, *gradOutputTensor = nil; + // lossTensor used for forward, and gradInputTensor for backward pass + union { + MPSGraphTensor* lossTensor = nil; + MPSGraphTensor* gradInputTensor; + }; }; -MPSGraphTensor* bce_forward_mps(CachedGraph *bceGraph) -{ - MPSGraph *mpsGraph = bceGraph->graph(); - - // Forward BCE: L = -w (y ln(x) + (1-y) ln(1-x)) - MPSGraphTensor *one = [mpsGraph constantWithScalar: 1.0 - dataType: MPSDataTypeFloat32]; - // -100 is the hard limit value defined in BCELoss Spec. to clamp the log - MPSGraphTensor *neg100 = [mpsGraph constantWithScalar: -100.0 - dataType: MPSDataTypeFloat32]; - // 1 - x - MPSGraphTensor *one_Input = [mpsGraph subtractionWithPrimaryTensor: one - secondaryTensor: bceGraph->inputTensor - name: nil]; - // log(x) - MPSGraphTensor *logInput = [mpsGraph logarithmWithTensor: bceGraph->inputTensor - name: nil]; - // max(log(x), -100) - MPSGraphTensor *clampedLogInput = [mpsGraph maximumWithPrimaryTensor: logInput - secondaryTensor: neg100 - name: nil]; - // log(1 - x) - MPSGraphTensor *log1_Input = [mpsGraph logarithmWithTensor: one_Input - name: nil]; - // max(log(1 - x), -100) - MPSGraphTensor *clampedLog1_Input = [mpsGraph maximumWithPrimaryTensor: log1_Input - secondaryTensor: neg100 - name: nil]; - // (y - 1) resulted from -(1 - y) - MPSGraphTensor *target_1 = [mpsGraph subtractionWithPrimaryTensor: bceGraph->targetTensor - secondaryTensor: one - name: nil]; - // (y - 1) * max(log(1 - x), -100) - MPSGraphTensor *target_1TimesLog1_Input = [mpsGraph multiplicationWithPrimaryTensor: target_1 - secondaryTensor: clampedLog1_Input - name: nil]; - // y * max(log(x), -100) - MPSGraphTensor *targetTimesLogInput = [mpsGraph multiplicationWithPrimaryTensor: bceGraph->targetTensor - secondaryTensor: clampedLogInput - name: nil]; - // ((y - 1) * max(log(1 - x), -100)) - (y * max(log(x), -100)) - MPSGraphTensor *bceLoss = [mpsGraph subtractionWithPrimaryTensor: target_1TimesLog1_Input - secondaryTensor: targetTimesLogInput - name: nil]; - return bceLoss; +MPSGraphTensor* bce_forward_mps(CachedGraph* bceGraph) { + MPSGraph* mpsGraph = bceGraph->graph(); + + // Forward BCE: L = -w (y ln(x) + (1-y) ln(1-x)) + MPSGraphTensor* one = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeFloat32]; + // -100 is the hard limit value defined in BCELoss Spec. to clamp the log + MPSGraphTensor* neg100 = [mpsGraph constantWithScalar:-100.0 dataType:MPSDataTypeFloat32]; + // 1 - x + MPSGraphTensor* one_Input = [mpsGraph subtractionWithPrimaryTensor:one + secondaryTensor:bceGraph->inputTensor + name:nil]; + // log(x) + MPSGraphTensor* logInput = [mpsGraph logarithmWithTensor:bceGraph->inputTensor name:nil]; + // max(log(x), -100) + MPSGraphTensor* clampedLogInput = [mpsGraph maximumWithPrimaryTensor:logInput secondaryTensor:neg100 name:nil]; + // log(1 - x) + MPSGraphTensor* log1_Input = [mpsGraph logarithmWithTensor:one_Input name:nil]; + // max(log(1 - x), -100) + MPSGraphTensor* clampedLog1_Input = [mpsGraph maximumWithPrimaryTensor:log1_Input secondaryTensor:neg100 name:nil]; + // (y - 1) resulted from -(1 - y) + MPSGraphTensor* target_1 = [mpsGraph subtractionWithPrimaryTensor:bceGraph->targetTensor + secondaryTensor:one + name:nil]; + // (y - 1) * max(log(1 - x), -100) + MPSGraphTensor* target_1TimesLog1_Input = [mpsGraph multiplicationWithPrimaryTensor:target_1 + secondaryTensor:clampedLog1_Input + name:nil]; + // y * max(log(x), -100) + MPSGraphTensor* targetTimesLogInput = [mpsGraph multiplicationWithPrimaryTensor:bceGraph->targetTensor + secondaryTensor:clampedLogInput + name:nil]; + // ((y - 1) * max(log(1 - x), -100)) - (y * max(log(x), -100)) + MPSGraphTensor* bceLoss = [mpsGraph subtractionWithPrimaryTensor:target_1TimesLog1_Input + secondaryTensor:targetTimesLogInput + name:nil]; + return bceLoss; } -MPSGraphTensor* bce_backward_mps(CachedGraph *bceGraph) -{ - MPSGraph *mpsGraph = bceGraph->graph(); - - // Backward BCE: d(L)/d(x) = -w (y - x) / (x - x^2) - MPSGraphTensor *one = [mpsGraph constantWithScalar: 1.0 - dataType: MPSDataTypeFloat32]; - // epsilon used to clamp the grad input denominator - MPSGraphTensor *epsilon = [mpsGraph constantWithScalar: 1e-12 - dataType: MPSDataTypeFloat32]; - // 1 - x - MPSGraphTensor *one_Input = [mpsGraph subtractionWithPrimaryTensor: one - secondaryTensor: bceGraph->inputTensor - name: nil]; - // x * (1 - x) - MPSGraphTensor *inputTimes1_Input = [mpsGraph multiplicationWithPrimaryTensor: bceGraph->inputTensor - secondaryTensor: one_Input - name: nil]; - // max(x * (1 - x), epsilon) - MPSGraphTensor *gradInputDenominator = [mpsGraph maximumWithPrimaryTensor: inputTimes1_Input - secondaryTensor: epsilon - name: nil]; - // (x - y) - MPSGraphTensor *input_target = [mpsGraph subtractionWithPrimaryTensor: bceGraph->inputTensor - secondaryTensor: bceGraph->targetTensor - name: nil]; - // (x - y) / max(x * (1 - x), epsilon) - MPSGraphTensor *inputDivGradInputDenom = [mpsGraph divisionWithPrimaryTensor: input_target - secondaryTensor: gradInputDenominator - name: nil]; - // gradOutput * (((x - y) / max(x * (1 - x), epsilon))) - MPSGraphTensor *gradInput = [mpsGraph multiplicationWithPrimaryTensor: bceGraph->gradOutputTensor - secondaryTensor: inputDivGradInputDenom - name: nil]; - return gradInput; +MPSGraphTensor* bce_backward_mps(CachedGraph* bceGraph) { + MPSGraph* mpsGraph = bceGraph->graph(); + + // Backward BCE: d(L)/d(x) = -w (y - x) / (x - x^2) + MPSGraphTensor* one = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeFloat32]; + // epsilon used to clamp the grad input denominator + MPSGraphTensor* epsilon = [mpsGraph constantWithScalar:1e-12 dataType:MPSDataTypeFloat32]; + // 1 - x + MPSGraphTensor* one_Input = [mpsGraph subtractionWithPrimaryTensor:one + secondaryTensor:bceGraph->inputTensor + name:nil]; + // x * (1 - x) + MPSGraphTensor* inputTimes1_Input = [mpsGraph multiplicationWithPrimaryTensor:bceGraph->inputTensor + secondaryTensor:one_Input + name:nil]; + // max(x * (1 - x), epsilon) + MPSGraphTensor* gradInputDenominator = [mpsGraph maximumWithPrimaryTensor:inputTimes1_Input + secondaryTensor:epsilon + name:nil]; + // (x - y) + MPSGraphTensor* input_target = [mpsGraph subtractionWithPrimaryTensor:bceGraph->inputTensor + secondaryTensor:bceGraph->targetTensor + name:nil]; + // (x - y) / max(x * (1 - x), epsilon) + MPSGraphTensor* inputDivGradInputDenom = [mpsGraph divisionWithPrimaryTensor:input_target + secondaryTensor:gradInputDenominator + name:nil]; + // gradOutput * (((x - y) / max(x * (1 - x), epsilon))) + MPSGraphTensor* gradInput = [mpsGraph multiplicationWithPrimaryTensor:bceGraph->gradOutputTensor + secondaryTensor:inputDivGradInputDenom + name:nil]; + return gradInput; } // Binary Cross Enropy (Forward/Backward BCELoss) // NOTE: "loss" tensor would be "grad_input" if it's a backward pass -Tensor& bce_loss_out_impl(const Tensor& input, const Tensor& target, - const c10::optional& weight_opt, int64_t reduction, Tensor& loss, - const c10::optional& grad_output_opt, const string op_name) -{ - // TODO: add sanity check for the elements of input tensor to be within [0..1] - TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes") +Tensor& bce_loss_out_impl(const Tensor& input, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction, + Tensor& loss, + const c10::optional& grad_output_opt, + const string op_name) { + // TODO: add sanity check for the elements of input tensor to be within [0..1] + TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes") - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - c10::MaybeOwned grad_output_maybe_owned = at::borrow_from_optional_tensor(grad_output_opt); - const Tensor& weight = *weight_maybe_owned; - const Tensor& grad_output = *grad_output_maybe_owned; + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + c10::MaybeOwned grad_output_maybe_owned = at::borrow_from_optional_tensor(grad_output_opt); + const Tensor& weight = *weight_maybe_owned; + const Tensor& grad_output = *grad_output_maybe_owned; - loss.resize_((reduction == Reduction::None || grad_output.defined()) ? target.sizes() : IntArrayRef({})); - TORCH_CHECK(loss.is_mps()); + loss.resize_((reduction == Reduction::None || grad_output.defined()) ? target.sizes() : IntArrayRef({})); + TORCH_CHECK(loss.is_mps()); - Tensor loss_squeezed = at::squeeze(loss); - Tensor input_squeezed = at::squeeze(input); - Tensor target_squeezed = at::squeeze(target); + Tensor loss_squeezed = at::squeeze(loss); + Tensor input_squeezed = at::squeeze(input); + Tensor target_squeezed = at::squeeze(target); - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - @autoreleasepool { - string key = op_name + reductionToString(reduction) + getTensorsStringKey({input_squeezed, target_squeezed, weight}); - - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_squeezed); - newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target_squeezed); - - MPSGraphTensor *bceLossUnweighted = nil; - // if grad_output is defined, then it's a backward pass - if (grad_output.defined()) { - newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - bceLossUnweighted = bce_backward_mps(newCachedGraph); - } else { - bceLossUnweighted = bce_forward_mps(newCachedGraph); - } - - MPSGraphTensor *bceLoss = bceLossUnweighted; - if (weight.defined()) { - newCachedGraph->weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight); - bceLoss = [mpsGraph multiplicationWithPrimaryTensor: bceLossUnweighted - secondaryTensor: newCachedGraph->weightTensor - name: nil]; - } - - if (grad_output.defined()) { - if (reduction == at::Reduction::Mean) { - MPSGraphTensor *inputNumel = [mpsGraph constantWithScalar: static_cast(input.numel()) - dataType: MPSDataTypeFloat32]; - newCachedGraph->gradInputTensor = [mpsGraph divisionWithPrimaryTensor: bceLoss - secondaryTensor: inputNumel - name: nil]; - } else { - newCachedGraph->gradInputTensor = bceLoss; - } - } else { - newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input_squeezed.sizes().size()); - } - } - return newCachedGraph; - })); - } - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_squeezed); - Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target_squeezed); - Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss_squeezed); + @autoreleasepool { + string key = + op_name + reductionToString(reduction) + getTensorsStringKey({input_squeezed, target_squeezed, weight}); - NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; - feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); - feeds[targetPlaceholder.getMPSGraphTensor()] = targetPlaceholder.getMPSGraphTensorData(); - if (weight.defined()) { - Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor, weight); - feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData(); - } - if (grad_output.defined()) { - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor, grad_output); - feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData(); + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_squeezed); + newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target_squeezed); + + MPSGraphTensor* bceLossUnweighted = nil; + // if grad_output is defined, then it's a backward pass + if (grad_output.defined()) { + newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); + bceLossUnweighted = bce_backward_mps(newCachedGraph); + } else { + bceLossUnweighted = bce_forward_mps(newCachedGraph); + } + + MPSGraphTensor* bceLoss = bceLossUnweighted; + if (weight.defined()) { + newCachedGraph->weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight); + bceLoss = [mpsGraph multiplicationWithPrimaryTensor:bceLossUnweighted + secondaryTensor:newCachedGraph->weightTensor + name:nil]; + } + + if (grad_output.defined()) { + if (reduction == at::Reduction::Mean) { + MPSGraphTensor* inputNumel = [mpsGraph constantWithScalar:static_cast(input.numel()) + dataType:MPSDataTypeFloat32]; + newCachedGraph->gradInputTensor = [mpsGraph divisionWithPrimaryTensor:bceLoss + secondaryTensor:inputNumel + name:nil]; + } else { + newCachedGraph->gradInputTensor = bceLoss; + } + } else { + newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input_squeezed.sizes().size()); + } } + return newCachedGraph; + })); + } + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_squeezed); + Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target_squeezed); + Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss_squeezed); - NSDictionary* results = @{ - lossPlaceholder.getMPSGraphTensor() : lossPlaceholder.getMPSGraphTensorData() - }; + NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; - runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); + feeds[targetPlaceholder.getMPSGraphTensor()] = targetPlaceholder.getMPSGraphTensorData(); + if (weight.defined()) { + Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor, weight); + feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData(); } + if (grad_output.defined()) { + Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor, grad_output); + feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData(); + } + + NSDictionary* results = + @{lossPlaceholder.getMPSGraphTensor() : lossPlaceholder.getMPSGraphTensorData()}; - return loss; + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + } + + return loss; } } // namespace BCELoss // NLLLoss -void nllnd_loss_backward_impl( - Tensor& grad_input_arg, - const Tensor& grad_output_arg, - const Tensor& input_arg, - const Tensor& target_arg, - const Tensor& weight_arg, - int64_t reduction, - int64_t ignore_index, - const Tensor& total_weight, - bool is2D) { - - if (grad_input_arg.numel() == 0) { - return; - } - struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* inputTensor_ = nil; - MPSGraphTensor* targetTensor_ = nil; - MPSGraphTensor* weightTensor_ = nil; - MPSGraphTensor* totalWeightTensor_ = nil; - MPSGraphTensor* gradInputTensor_ = nil; - MPSGraphTensor* gradOutputTensor_ = nil; - }; - bool isWeightsArrayValid = weight_arg.defined() && weight_arg.numel() > 0; - int64_t channel_dim = grad_input_arg.dim() < 2 ? 0 : 1; - auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg; - auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg; - auto grad_input = grad_input_arg.dim() == 1 ? grad_input_arg.view({1, grad_input_arg.size(0)}) : grad_input_arg; - auto numClasses = grad_input.sizes()[1]; - auto weight = weight_arg; - auto grad_output = grad_output_arg; +void nllnd_loss_backward_impl(Tensor& grad_input_arg, + const Tensor& grad_output_arg, + const Tensor& input_arg, + const Tensor& target_arg, + const Tensor& weight_arg, + int64_t reduction, + int64_t ignore_index, + const Tensor& total_weight, + bool is2D) { + if (grad_input_arg.numel() == 0) { + return; + } + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* targetTensor_ = nil; + MPSGraphTensor* weightTensor_ = nil; + MPSGraphTensor* totalWeightTensor_ = nil; + MPSGraphTensor* gradInputTensor_ = nil; + MPSGraphTensor* gradOutputTensor_ = nil; + }; + bool isWeightsArrayValid = weight_arg.defined() && weight_arg.numel() > 0; + int64_t channel_dim = grad_input_arg.dim() < 2 ? 0 : 1; + auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg; + auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg; + auto grad_input = grad_input_arg.dim() == 1 ? grad_input_arg.view({1, grad_input_arg.size(0)}) : grad_input_arg; + auto numClasses = grad_input.sizes()[1]; + auto weight = weight_arg; + auto grad_output = grad_output_arg; + + if (isWeightsArrayValid) { + std::vector weightShape(input.dim(), 1); + weightShape[1] = input.size(1); + weight = weight_arg.view(weightShape); + } + if (grad_output_arg.dim() < grad_input.dim() && grad_output_arg.dim() > 0) { + grad_output = grad_output_arg.unsqueeze(channel_dim); + } + @autoreleasepool { + string key = "nllnd_loss_backward" + getTensorsStringKey({input, grad_output, target, weight, total_weight}) + + to_string(numClasses) + ":" + to_string(ignore_index) + ":" + to_string(isWeightsArrayValid) + ":" + + reductionToString(reduction); - if (isWeightsArrayValid) { - std::vector weightShape(input.dim(), 1); - weightShape[1] = input.size(1); - weight = weight_arg.view(weightShape); - } - if (grad_output_arg.dim() < grad_input.dim() && grad_output_arg.dim() > 0) { - grad_output = grad_output_arg.unsqueeze(channel_dim); - } - @autoreleasepool { - string key = "nllnd_loss_backward" + getTensorsStringKey({input, grad_output, target, weight, total_weight}) - + to_string(numClasses) + ":" + to_string(ignore_index) + ":" - + to_string(isWeightsArrayValid) + ":" + reductionToString(reduction); - - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); - MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target); - MPSGraphTensor* weightTensor = nil; - if (isWeightsArrayValid) { - weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight); - } - MPSGraphTensor* totalWeightTensor = mpsGraphRankedPlaceHolder(mpsGraph, total_weight); - MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - - MPSGraphTensor *udpatedTargetTensor = targetTensor; - - // Replace ignored_index with length depth + 1 so that oneHotAPI ignores it - if (ignore_index != -100) { - MPSGraphTensor *ignoreIndexTensor = [mpsGraph constantWithScalar: ignore_index - dataType: MPSDataTypeInt64]; - MPSGraphTensor *numClassesTensor = [mpsGraph constantWithScalar: (numClasses + 1) - dataType: MPSDataTypeInt64]; - MPSGraphTensor* isEqualTensor = [mpsGraph equalWithPrimaryTensor: targetTensor - secondaryTensor: ignoreIndexTensor - name: @"isEqualTensor"]; - udpatedTargetTensor = [mpsGraph selectWithPredicateTensor: isEqualTensor - truePredicateTensor: numClassesTensor - falsePredicateTensor: targetTensor - name: @"predicateTensor"]; - } - MPSGraphTensor *oneHotTensor = [mpsGraph oneHotWithIndicesTensor: udpatedTargetTensor - depth: numClasses - axis: 1 - dataType: inputTensor.dataType - onValue: -1.0f - offValue: 0.0f - name: nil]; - if (isWeightsArrayValid) { - oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor: oneHotTensor - secondaryTensor: weightTensor - name: @"scaleByWeightTensor"]; - } - if (reduction == Reduction::Mean) { - oneHotTensor = [mpsGraph divisionNoNaNWithPrimaryTensor: oneHotTensor - secondaryTensor: totalWeightTensor - name: @"divisionTensor"]; - } - MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor: oneHotTensor - secondaryTensor: gradOutputTensor - name: nil]; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->targetTensor_ = targetTensor; - newCachedGraph->weightTensor_ = weightTensor; - newCachedGraph->totalWeightTensor_ = totalWeightTensor; - newCachedGraph->gradInputTensor_ = gradInputTensor; - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - } - return newCachedGraph; - }); - } + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + CachedGraph* cachedGraph = cache_->LookUpAs(key); + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); - auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); - auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); - auto targetPlaceholder = Placeholder(cachedGraph->targetTensor_, target); - Placeholder weightPlaceholder = Placeholder(); - if(isWeightsArrayValid) { - weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); + MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target); + MPSGraphTensor* weightTensor = nil; + if (isWeightsArrayValid) { + weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight); + } + MPSGraphTensor* totalWeightTensor = mpsGraphRankedPlaceHolder(mpsGraph, total_weight); + MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); + + MPSGraphTensor* udpatedTargetTensor = targetTensor; + + // Replace ignored_index with length depth + 1 so that oneHotAPI ignores it + if (ignore_index != -100) { + MPSGraphTensor* ignoreIndexTensor = [mpsGraph constantWithScalar:ignore_index dataType:MPSDataTypeInt64]; + MPSGraphTensor* numClassesTensor = [mpsGraph constantWithScalar:(numClasses + 1) dataType:MPSDataTypeInt64]; + MPSGraphTensor* isEqualTensor = [mpsGraph equalWithPrimaryTensor:targetTensor + secondaryTensor:ignoreIndexTensor + name:@"isEqualTensor"]; + udpatedTargetTensor = [mpsGraph selectWithPredicateTensor:isEqualTensor + truePredicateTensor:numClassesTensor + falsePredicateTensor:targetTensor + name:@"predicateTensor"]; + } + MPSGraphTensor* oneHotTensor = [mpsGraph oneHotWithIndicesTensor:udpatedTargetTensor + depth:numClasses + axis:1 + dataType:inputTensor.dataType + onValue:-1.0f + offValue:0.0f + name:nil]; + if (isWeightsArrayValid) { + oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor + secondaryTensor:weightTensor + name:@"scaleByWeightTensor"]; + } + if (reduction == Reduction::Mean) { + oneHotTensor = [mpsGraph divisionNoNaNWithPrimaryTensor:oneHotTensor + secondaryTensor:totalWeightTensor + name:@"divisionTensor"]; + } + MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor + secondaryTensor:gradOutputTensor + name:nil]; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->targetTensor_ = targetTensor; + newCachedGraph->weightTensor_ = weightTensor; + newCachedGraph->totalWeightTensor_ = totalWeightTensor; + newCachedGraph->gradInputTensor_ = gradInputTensor; + newCachedGraph->gradOutputTensor_ = gradOutputTensor; } - auto totalWeightPlaceholder = Placeholder(cachedGraph->totalWeightTensor_, total_weight); - auto gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input); + return newCachedGraph; + }); + } - NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; - feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); - feeds[targetPlaceholder.getMPSGraphTensor()] = targetPlaceholder.getMPSGraphTensorData(); - feeds[totalWeightPlaceholder.getMPSGraphTensor()] = totalWeightPlaceholder.getMPSGraphTensorData(); - feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData(); + auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); + auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); + auto targetPlaceholder = Placeholder(cachedGraph->targetTensor_, target); + Placeholder weightPlaceholder = Placeholder(); + if (isWeightsArrayValid) { + weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight); + } + auto totalWeightPlaceholder = Placeholder(cachedGraph->totalWeightTensor_, total_weight); + auto gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input); - if (isWeightsArrayValid) { - feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData(); - } - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; + NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; + feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); + feeds[targetPlaceholder.getMPSGraphTensor()] = targetPlaceholder.getMPSGraphTensorData(); + feeds[totalWeightPlaceholder.getMPSGraphTensor()] = totalWeightPlaceholder.getMPSGraphTensorData(); + feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData(); - runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + if (isWeightsArrayValid) { + feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData(); } -} + NSDictionary* results = + @{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()}; -void nllnd_loss_forward_impl -(Tensor& output, - Tensor& total_weight, - const Tensor& input_arg, - const Tensor& target_arg, - const Tensor& weight, - int64_t reduction, - int64_t ignore_index, - bool is2D) -{ - std::vector reshapedTarget(target_arg.sizes().begin(), target_arg.sizes().end()); - reshapedTarget.push_back(1); - - Tensor batchSizeTensor = at::empty_like(input_arg).resize_(IntArrayRef(1)); - float batchVal = 1.0f; - for(size_t i = 0; i < reshapedTarget.size(); ++i) - batchVal *= reshapedTarget[i]; - batchSizeTensor[0] = batchVal; - - if(reduction == Reduction::None) - output.resize_(target_arg.sizes()); - if(reduction == Reduction::Sum) - output.resize_({}); - if(reduction == Reduction::Mean) - output.resize_({}); - - TORCH_CHECK(output.is_mps()); - - // Empty output - if(output.numel() == 0) - return; - - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* inputTensor_ = nil; - MPSGraphTensor* targetTensor_ = nil; - MPSGraphTensor* weightTensor_ = nil; - MPSGraphTensor* batchSizeTensor_ = nil; - MPSGraphTensor* totalWeightTensor_ = nil; - MPSGraphTensor* outputTensor_ = nil; - }; + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + } +} - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); +void nllnd_loss_forward_impl(Tensor& output, + Tensor& total_weight, + const Tensor& input_arg, + const Tensor& target_arg, + const Tensor& weight, + int64_t reduction, + int64_t ignore_index, + bool is2D) { + std::vector reshapedTarget(target_arg.sizes().begin(), target_arg.sizes().end()); + reshapedTarget.push_back(1); + + Tensor batchSizeTensor = at::empty_like(input_arg).resize_(IntArrayRef(1)); + float batchVal = 1.0f; + for (size_t i = 0; i < reshapedTarget.size(); ++i) + batchVal *= reshapedTarget[i]; + batchSizeTensor[0] = batchVal; + + if (reduction == Reduction::None) + output.resize_(target_arg.sizes()); + if (reduction == Reduction::Sum) + output.resize_({}); + if (reduction == Reduction::Mean) + output.resize_({}); - MPSStream* stream = getCurrentMPSStream(); + TORCH_CHECK(output.is_mps()); - auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg; - auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg; + // Empty output + if (output.numel() == 0) + return; - @autoreleasepool { + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* targetTensor_ = nil; + MPSGraphTensor* weightTensor_ = nil; + MPSGraphTensor* batchSizeTensor_ = nil; + MPSGraphTensor* totalWeightTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + }; - bool isWeightsArrayValid = (weight.numel() > 0); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - MPSShape* input_shape = getMPSShape(input); - MPSShape* target_shape = getMPSShape(target); - MPSShape* weight_shape = getMPSShape(weight); + MPSStream* stream = getCurrentMPSStream(); - NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; + auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg; + auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg; - // TODO: Make the key - string key = "nllnd_loss_forward_impl:" + to_string(ignore_index) + ":" + - to_string(isWeightsArrayValid) + ":" + - reductionToString(reduction) + ":" + - [ns_shape_key UTF8String] + ":" + - getMPSTypeString(input) + ":" + - getMPSTypeString(target) + ":" + - getMPSTypeString(weight); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + @autoreleasepool { + bool isWeightsArrayValid = (weight.numel() > 0); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + MPSShape* input_shape = getMPSShape(input); + MPSShape* target_shape = getMPSShape(target); + MPSShape* weight_shape = getMPSShape(weight); - CachedGraph *newCachedGraph = nil; + NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); + // TODO: Make the key + string key = "nllnd_loss_forward_impl:" + to_string(ignore_index) + ":" + to_string(isWeightsArrayValid) + ":" + + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" + + getMPSTypeString(target) + ":" + getMPSTypeString(weight); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape); - MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(target), target_shape); - MPSGraphTensor* weightTensor = nil; - if(isWeightsArrayValid) - weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(weight), weight_shape); - MPSGraphTensor* mps_batchSizeTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(batchSizeTensor)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; - MPSGraphTensor* mpsGraphBatchSizeTensor = mps_batchSizeTensor; + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); - // The transposes are needed to get the class dimension (dim 1) to the inner most dim for gather op. - // The transpose become nop in the 2D case. - MPSGraphTensor* mpsTransposeTensor = inputTensor; - int classDim = 1; - int lastDim = input.sizes().size()-1; - mpsTransposeTensor = [mpsGraph transposeTensor:inputTensor - dimension:classDim - withDimension:lastDim - name:nil]; - for(int i = 0; i < lastDim - 2; ++i) - { - mpsTransposeTensor = [mpsGraph transposeTensor:mpsTransposeTensor - dimension:classDim+i - withDimension:classDim+i+1 name:nil]; - } - - - MPSGraphTensor* mpsGatherTensor = [mpsGraph gatherWithUpdatesTensor:mpsTransposeTensor - indicesTensor:targetTensor - axis:lastDim - batchDimensions:lastDim - name:@"gatherTensor"]; - - bool isIgnoreIndexValid = (ignore_index != -100); - MPSGraphTensor* weightGatherTensor; - - if(isWeightsArrayValid) - { - weightGatherTensor = [mpsGraph gatherWithUpdatesTensor:weightTensor - indicesTensor:targetTensor - axis:0 - batchDimensions:0 - name:@"weightGatherTensor"]; - MPSGraphTensor *mpsGatherCopyTensor = [mpsGraph identityWithTensor:mpsGatherTensor - name:@"identityTensor"]; - mpsGatherTensor = [mpsGraph multiplicationWithPrimaryTensor:weightGatherTensor - secondaryTensor:mpsGatherCopyTensor - name:@"scaledLossTensor"]; - } - - // Both these cases need recomputation of denominator when reductionMode == mean - if(isIgnoreIndexValid || isWeightsArrayValid) - { - // Setup tensors - MPSGraphTensor *mpsGraphZeroTensor = [mpsGraph constantWithScalar:0.0 - dataType:mpsGatherTensor.dataType]; - MPSGraphTensor *mpsGraphOneTensor = [mpsGraph constantWithScalar:1.0 - dataType:mpsGatherTensor.dataType]; - // @TODO: Remove this identity call with ToT StarSky MPSGraph - MPSGraphTensor *mpsGraphOneTensorCopy = [mpsGraph identityWithTensor:mpsGraphOneTensor - name:@"IdentityHackTensor"]; - - MPSGraphTensor *mpsGraphIsEqualTensor; - - if(isIgnoreIndexValid) - { - MPSGraphTensor *mpsGraphIndexTensor = [mpsGraph constantWithScalar:ignore_index - dataType:MPSDataTypeInt64]; - // Equal tensor - mpsGraphIsEqualTensor = [mpsGraph equalWithPrimaryTensor:targetTensor - secondaryTensor:mpsGraphIndexTensor - name:@"isEqualTensor"]; - // Zero out loss - MPSGraphTensor *mpsGatherCopyTensor = [mpsGraph identityWithTensor:mpsGatherTensor - name:@"identityTensor"]; - mpsGatherTensor = [mpsGraph selectWithPredicateTensor:mpsGraphIsEqualTensor - truePredicateTensor:mpsGraphZeroTensor - falsePredicateTensor:mpsGatherCopyTensor - name:@"predicateTensor"]; - } - - if(isWeightsArrayValid) - { - mpsGraphOneTensorCopy = weightGatherTensor; - if(!isIgnoreIndexValid) - { - mpsGraphIsEqualTensor = [mpsGraph constantWithScalar: 0.0 - shape: targetTensor.shape - dataType: targetTensor.dataType]; - } - } - - // Compute new batch size - MPSGraphTensor* mpsSelectOneTensor = [mpsGraph selectWithPredicateTensor:mpsGraphIsEqualTensor - truePredicateTensor:mpsGraphZeroTensor - falsePredicateTensor:mpsGraphOneTensorCopy - name:@"predicateOneTensor"]; - mpsGraphBatchSizeTensor = [mpsGraph reductionSumWithTensor:mpsSelectOneTensor - axes:nil - name:@"batchSizeReductionTensor"]; - } - - MPSGraphTensor *mpsGraphNegTensor = [mpsGraph negativeWithTensor:mpsGatherTensor - name:@"negativeTensor"]; - - MPSGraphTensor* mpsGraphReducedTensor = mpsGraphNegTensor; - - if(!(reduction == Reduction::None)) - { - mpsGraphReducedTensor = [mpsGraph reductionSumWithTensor:mpsGraphNegTensor - axes:nil - name:@"reductionSumTensor"]; - if(reduction == Reduction::Mean) - { - mpsGraphReducedTensor = [mpsGraph divisionNoNaNWithPrimaryTensor:mpsGraphReducedTensor - secondaryTensor:mpsGraphBatchSizeTensor - name:@"divisionTensor"]; - } - } - - mpsGraphReducedTensor = [mpsGraph reshapeTensor:mpsGraphReducedTensor - withShape:getMPSShape(output) - name:nil]; + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape); + MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(target), target_shape); + MPSGraphTensor* weightTensor = nil; + if (isWeightsArrayValid) + weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(weight), weight_shape); + MPSGraphTensor* mps_batchSizeTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(batchSizeTensor)); + + MPSGraphTensor* mpsGraphBatchSizeTensor = mps_batchSizeTensor; + + // The transposes are needed to get the class dimension (dim 1) to the inner most dim for gather op. + // The transpose become nop in the 2D case. + MPSGraphTensor* mpsTransposeTensor = inputTensor; + int classDim = 1; + int lastDim = input.sizes().size() - 1; + mpsTransposeTensor = [mpsGraph transposeTensor:inputTensor dimension:classDim withDimension:lastDim name:nil]; + for (int i = 0; i < lastDim - 2; ++i) { + mpsTransposeTensor = [mpsGraph transposeTensor:mpsTransposeTensor + dimension:classDim + i + withDimension:classDim + i + 1 + name:nil]; + } - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->targetTensor_ = targetTensor; - newCachedGraph->weightTensor_ = weightTensor; - newCachedGraph->batchSizeTensor_ = mps_batchSizeTensor; - newCachedGraph->totalWeightTensor_ = mpsGraphBatchSizeTensor; - newCachedGraph->outputTensor_ = mpsGraphReducedTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } + MPSGraphTensor* mpsGatherTensor = [mpsGraph gatherWithUpdatesTensor:mpsTransposeTensor + indicesTensor:targetTensor + axis:lastDim + batchDimensions:lastDim + name:@"gatherTensor"]; + + bool isIgnoreIndexValid = (ignore_index != -100); + MPSGraphTensor* weightGatherTensor; + + if (isWeightsArrayValid) { + weightGatherTensor = [mpsGraph gatherWithUpdatesTensor:weightTensor + indicesTensor:targetTensor + axis:0 + batchDimensions:0 + name:@"weightGatherTensor"]; + MPSGraphTensor* mpsGatherCopyTensor = [mpsGraph identityWithTensor:mpsGatherTensor name:@"identityTensor"]; + mpsGatherTensor = [mpsGraph multiplicationWithPrimaryTensor:weightGatherTensor + secondaryTensor:mpsGatherCopyTensor + name:@"scaledLossTensor"]; + } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, input); - Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor_, target); - Placeholder weightPlaceholder = Placeholder(); - if(isWeightsArrayValid) - weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight); - Placeholder batchSizePlaceholder = Placeholder(cachedGraph->batchSizeTensor_, batchSizeTensor); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - Placeholder totalWeightsPlaceholder = Placeholder(cachedGraph->totalWeightTensor_, total_weight); + // Both these cases need recomputation of denominator when reductionMode == mean + if (isIgnoreIndexValid || isWeightsArrayValid) { + // Setup tensors + MPSGraphTensor* mpsGraphZeroTensor = [mpsGraph constantWithScalar:0.0 dataType:mpsGatherTensor.dataType]; + MPSGraphTensor* mpsGraphOneTensor = [mpsGraph constantWithScalar:1.0 dataType:mpsGatherTensor.dataType]; + // @TODO: Remove this identity call with ToT StarSky MPSGraph + MPSGraphTensor* mpsGraphOneTensorCopy = [mpsGraph identityWithTensor:mpsGraphOneTensor + name:@"IdentityHackTensor"]; + + MPSGraphTensor* mpsGraphIsEqualTensor; + + if (isIgnoreIndexValid) { + MPSGraphTensor* mpsGraphIndexTensor = [mpsGraph constantWithScalar:ignore_index + dataType:MPSDataTypeInt64]; + // Equal tensor + mpsGraphIsEqualTensor = [mpsGraph equalWithPrimaryTensor:targetTensor + secondaryTensor:mpsGraphIndexTensor + name:@"isEqualTensor"]; + // Zero out loss + MPSGraphTensor* mpsGatherCopyTensor = [mpsGraph identityWithTensor:mpsGatherTensor + name:@"identityTensor"]; + mpsGatherTensor = [mpsGraph selectWithPredicateTensor:mpsGraphIsEqualTensor + truePredicateTensor:mpsGraphZeroTensor + falsePredicateTensor:mpsGatherCopyTensor + name:@"predicateTensor"]; + } + + if (isWeightsArrayValid) { + mpsGraphOneTensorCopy = weightGatherTensor; + if (!isIgnoreIndexValid) { + mpsGraphIsEqualTensor = [mpsGraph constantWithScalar:0.0 + shape:targetTensor.shape + dataType:targetTensor.dataType]; + } + } + + // Compute new batch size + MPSGraphTensor* mpsSelectOneTensor = [mpsGraph selectWithPredicateTensor:mpsGraphIsEqualTensor + truePredicateTensor:mpsGraphZeroTensor + falsePredicateTensor:mpsGraphOneTensorCopy + name:@"predicateOneTensor"]; + mpsGraphBatchSizeTensor = [mpsGraph reductionSumWithTensor:mpsSelectOneTensor + axes:nil + name:@"batchSizeReductionTensor"]; + } - // Create dictionary of inputs and outputs - NSMutableDictionary* feeds = [[[NSMutableDictionary alloc] initWithCapacity: 4] autorelease]; - feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData(); - feeds[targetPlaceholder.getMPSGraphTensor()] = targetPlaceholder.getMPSGraphTensorData(); - feeds[batchSizePlaceholder.getMPSGraphTensor()] = batchSizePlaceholder.getMPSGraphTensorData(); + MPSGraphTensor* mpsGraphNegTensor = [mpsGraph negativeWithTensor:mpsGatherTensor name:@"negativeTensor"]; - if(isWeightsArrayValid) - feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData(); + MPSGraphTensor* mpsGraphReducedTensor = mpsGraphNegTensor; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(), - totalWeightsPlaceholder.getMPSGraphTensor() : totalWeightsPlaceholder.getMPSGraphTensorData() - }; + if (!(reduction == Reduction::None)) { + mpsGraphReducedTensor = [mpsGraph reductionSumWithTensor:mpsGraphNegTensor + axes:nil + name:@"reductionSumTensor"]; + if (reduction == Reduction::Mean) { + mpsGraphReducedTensor = [mpsGraph divisionNoNaNWithPrimaryTensor:mpsGraphReducedTensor + secondaryTensor:mpsGraphBatchSizeTensor + name:@"divisionTensor"]; + } + } - runMPSGraph(stream, cachedGraph->graph(), feeds, results); + mpsGraphReducedTensor = [mpsGraph reshapeTensor:mpsGraphReducedTensor withShape:getMPSShape(output) name:nil]; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->targetTensor_ = targetTensor; + newCachedGraph->weightTensor_ = weightTensor; + newCachedGraph->batchSizeTensor_ = mps_batchSizeTensor; + newCachedGraph->totalWeightTensor_ = mpsGraphBatchSizeTensor; + newCachedGraph->outputTensor_ = mpsGraphReducedTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); } - return; + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, input); + Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor_, target); + Placeholder weightPlaceholder = Placeholder(); + if (isWeightsArrayValid) + weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight); + Placeholder batchSizePlaceholder = Placeholder(cachedGraph->batchSizeTensor_, batchSizeTensor); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); + Placeholder totalWeightsPlaceholder = Placeholder(cachedGraph->totalWeightTensor_, total_weight); + + // Create dictionary of inputs and outputs + NSMutableDictionary* feeds = + [[[NSMutableDictionary alloc] initWithCapacity:4] autorelease]; + feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData(); + feeds[targetPlaceholder.getMPSGraphTensor()] = targetPlaceholder.getMPSGraphTensorData(); + feeds[batchSizePlaceholder.getMPSGraphTensor()] = batchSizePlaceholder.getMPSGraphTensorData(); + + if (isWeightsArrayValid) + feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData(); + + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(), + totalWeightsPlaceholder.getMPSGraphTensor() : totalWeightsPlaceholder.getMPSGraphTensorData() + }; + + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } + + return; } -void smooth_l1_loss_impl( - const Tensor &input, - const Tensor &target, - const int64_t reduction, - double beta, - const Tensor &output, - MPSShape *mpsInputShape, - MPSShape *mpsOutputShape) -{ - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *targetTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; +void smooth_l1_loss_impl(const Tensor& input, + const Tensor& target, + const int64_t reduction, + double beta, + const Tensor& output, + MPSShape* mpsInputShape, + MPSShape* mpsOutputShape) { + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* targetTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; - MPSGraphCache *cache_ = MPSGraphCache::getInstance(); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - MPSStream *stream= getCurrentMPSStream(); + MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { MPSShape* input_shape = getMPSShape(input); NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = "smooth_l1_loss_impl:" + reductionToString(reduction) + ":" + - [ns_shape_key UTF8String] + ":" + - to_string(beta) + ":" + - getMPSTypeString(input) + ":" + - getMPSTypeString(target); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + string key = "smooth_l1_loss_impl:" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + + to_string(beta) + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; // smooth_l1_loss_mps: // ln = 0.5 * ( xn - yn ) ^ 2 / beta, if |xn - yn| < beta @@ -723,74 +676,69 @@ void smooth_l1_loss_impl( @autoreleasepool { // Initialize graph - MPSGraph *mpsGraph = make_mps_graph(); + MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input)); - MPSGraphTensor *targetTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(target)); + MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input)); + MPSGraphTensor* targetTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(target)); // Setup tensors - MPSGraphTensor *mpsGraphHalfTensor = [mpsGraph constantWithScalar: 0.5 - dataType: inputTensor.dataType]; - MPSGraphTensor *betaTensor = [mpsGraph constantWithScalar: beta - dataType: inputTensor.dataType]; + MPSGraphTensor* mpsGraphHalfTensor = [mpsGraph constantWithScalar:0.5 dataType:inputTensor.dataType]; + MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:inputTensor.dataType]; // 0.5 * beta - MPSGraphTensor *halfTensorMulBetaTensor = [mpsGraph constantWithScalar: beta * 0.5 - dataType: inputTensor.dataType]; + MPSGraphTensor* halfTensorMulBetaTensor = [mpsGraph constantWithScalar:beta * 0.5 + dataType:inputTensor.dataType]; // Calculating first part of the equation: // ln = 0.5(xn - yn)^2/beta, if |xn - yn| < beta // xn - yn - MPSGraphTensor *diffTensor = [mpsGraph subtractionWithPrimaryTensor: inputTensor - secondaryTensor: targetTensor - name: nil]; + MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor + secondaryTensor:targetTensor + name:nil]; // | xn - yn | - MPSGraphTensor *diffAbsTensor = [mpsGraph absoluteWithTensor: diffTensor - name: nil]; + MPSGraphTensor* diffAbsTensor = [mpsGraph absoluteWithTensor:diffTensor name:nil]; // | xn - yn | < beta - MPSGraphTensor *diffAbsLessThanBetaTensor = [mpsGraph lessThanWithPrimaryTensor: diffAbsTensor - secondaryTensor: betaTensor - name: nil]; + MPSGraphTensor* diffAbsLessThanBetaTensor = [mpsGraph lessThanWithPrimaryTensor:diffAbsTensor + secondaryTensor:betaTensor + name:nil]; // ( xn - yn ) ^ 2 - MPSGraphTensor *diffSquare = [mpsGraph squareWithTensor: diffTensor - name: nil]; + MPSGraphTensor* diffSquare = [mpsGraph squareWithTensor:diffTensor name:nil]; // 0.5 * ( xn - yn ) ^ 2 - MPSGraphTensor *diffSquareMulHalfTensor = [mpsGraph multiplicationWithPrimaryTensor: diffSquare - secondaryTensor: mpsGraphHalfTensor - name: nil]; + MPSGraphTensor* diffSquareMulHalfTensor = [mpsGraph multiplicationWithPrimaryTensor:diffSquare + secondaryTensor:mpsGraphHalfTensor + name:nil]; // 0.5 * ( xn - yn ) ^ 2 / beta - MPSGraphTensor *loss1Temp = [mpsGraph divisionWithPrimaryTensor: diffSquareMulHalfTensor - secondaryTensor: betaTensor - name: nil]; + MPSGraphTensor* loss1Temp = [mpsGraph divisionWithPrimaryTensor:diffSquareMulHalfTensor + secondaryTensor:betaTensor + name:nil]; // Calculating second part of the equation: // | xn - yn | - 0.5 * beta, if | xn - yn | >= beta // | xn - yn | - 0.5 * beta - MPSGraphTensor *loss2Temp = [mpsGraph subtractionWithPrimaryTensor: diffAbsTensor - secondaryTensor: halfTensorMulBetaTensor - name: nil]; + MPSGraphTensor* loss2Temp = [mpsGraph subtractionWithPrimaryTensor:diffAbsTensor + secondaryTensor:halfTensorMulBetaTensor + name:nil]; - MPSGraphTensor *lossTensor = [mpsGraph selectWithPredicateTensor: diffAbsLessThanBetaTensor - truePredicateTensor: loss1Temp - falsePredicateTensor: loss2Temp - name: @"lossTensor"]; + MPSGraphTensor* lossTensor = [mpsGraph selectWithPredicateTensor:diffAbsLessThanBetaTensor + truePredicateTensor:loss1Temp + falsePredicateTensor:loss2Temp + name:@"lossTensor"]; - MPSGraphTensor *outputTensor = reduceTensor(lossTensor, reduction, mpsGraph, 1); + MPSGraphTensor* outputTensor = reduceTensor(lossTensor, reduction, mpsGraph, 1); newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->targetTensor_ = targetTensor; newCachedGraph->outputTensor_ = outputTensor; - } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input, mpsInputShape); @@ -799,148 +747,128 @@ void smooth_l1_loss_impl( NSDictionary* feeds = @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - targetPlaceholder.getMPSGraphTensor() : targetPlaceholder .getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData() }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } } -void smooth_l1_loss_template( - const Tensor &input, - const Tensor &target, - const int64_t reduction, - double beta, - const Tensor &output) -{ - TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta."); +void smooth_l1_loss_template(const Tensor& input, + const Tensor& target, + const int64_t reduction, + double beta, + const Tensor& output) { + TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta."); TORCH_CHECK(input.is_mps()); TORCH_CHECK(target.is_mps()); - MPSShape *mpsInputShape = nil; - MPSShape *mpsOutputShape = nil; + MPSShape* mpsInputShape = nil; + MPSShape* mpsOutputShape = nil; // Determine the shape of the output // If the reduction is 'mean' or 'sum', the output shape is a scalar, // otherwise, the output shape is the same shape as input - if (reduction == Reduction::Mean || reduction == Reduction::Sum) - { - // Output: scalar, if reduction is 'mean' or 'sum' - IntArrayRef input_shape = input.sizes(); - int64_t num_input_dims = input_shape.size(); - NSMutableArray *apparent_input_shape = [NSMutableArray arrayWithCapacity:1]; - int64_t num_in_elements = 1; - for(int i = 0; i < num_input_dims; i++) { - num_in_elements *= input_shape[i]; - } - apparent_input_shape[0] = [NSNumber numberWithInt:num_in_elements]; - - // Output is a single value in case reduction is set to mean or sum - NSMutableArray *apparent_out_shape = [NSMutableArray arrayWithCapacity:1]; - apparent_out_shape[0] = @1; - mpsInputShape = apparent_input_shape; - mpsOutputShape = apparent_out_shape; - } - else - { - // Output: If reduction is 'none', then (N, *); same shape as the input - assert(reduction == Reduction::None); - mpsInputShape = getMPSShape(input); - mpsOutputShape = mpsInputShape; - //resize_tensor(&output); + if (reduction == Reduction::Mean || reduction == Reduction::Sum) { + // Output: scalar, if reduction is 'mean' or 'sum' + IntArrayRef input_shape = input.sizes(); + int64_t num_input_dims = input_shape.size(); + NSMutableArray* apparent_input_shape = [NSMutableArray arrayWithCapacity:1]; + int64_t num_in_elements = 1; + for (int i = 0; i < num_input_dims; i++) { + num_in_elements *= input_shape[i]; + } + apparent_input_shape[0] = [NSNumber numberWithInt:num_in_elements]; + + // Output is a single value in case reduction is set to mean or sum + NSMutableArray* apparent_out_shape = [NSMutableArray arrayWithCapacity:1]; + apparent_out_shape[0] = @1; + mpsInputShape = apparent_input_shape; + mpsOutputShape = apparent_out_shape; + } else { + // Output: If reduction is 'none', then (N, *); same shape as the input + assert(reduction == Reduction::None); + mpsInputShape = getMPSShape(input); + mpsOutputShape = mpsInputShape; + // resize_tensor(&output); } TORCH_CHECK(output.is_mps()); - smooth_l1_loss_impl( - input, - target, - reduction, - beta, - output, - mpsInputShape, - mpsOutputShape); + smooth_l1_loss_impl(input, target, reduction, beta, output, mpsInputShape, mpsOutputShape); } -void smooth_l1_loss_backward_impl( - const Tensor& grad_output, - const Tensor& input, - const Tensor& target, - int64_t reduction, - double beta, - Tensor& grad_input) -{ +void smooth_l1_loss_backward_impl(const Tensor& grad_output, + const Tensor& input, + const Tensor& target, + int64_t reduction, + double beta, + Tensor& grad_input) { if (grad_input.numel() == 0) { return; } TORCH_CHECK(beta >= 0, "smooth_l1_loss_backward does not support negative values for beta."); struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *targetTensor_ = nil; - MPSGraphTensor *gradInputTensor_ = nil; + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* targetTensor_ = nil; + MPSGraphTensor* gradInputTensor_ = nil; MPSGraphTensor* gradOutputTensor_ = nil; }; @autoreleasepool { - string key = "smooth_l1_loss_backward" + getTensorsStringKey({input, grad_output, grad_input, target}) + ":" - + reductionToString(reduction) + ":" + to_string(beta); + string key = "smooth_l1_loss_backward" + getTensorsStringKey({input, grad_output, grad_input, target}) + ":" + + reductionToString(reduction) + ":" + to_string(beta); - MPSGraphCache *cache_ = MPSGraphCache::getInstance(); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); CachedGraph* cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { - MPSGraph *mpsGraph = make_mps_graph(); + MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); - MPSGraphTensor *targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target); - MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); + MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target); + MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor *betaTensor = [mpsGraph constantWithScalar: beta - dataType: MPSDataTypeFloat32]; + MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:MPSDataTypeFloat32]; // xn - yn - MPSGraphTensor *diffTensor = [mpsGraph subtractionWithPrimaryTensor: inputTensor - secondaryTensor: targetTensor - name: nil]; + MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor + secondaryTensor:targetTensor + name:nil]; // | xn - yn | - MPSGraphTensor *diffAbsTensor = [mpsGraph absoluteWithTensor: diffTensor - name: nil]; + MPSGraphTensor* diffAbsTensor = [mpsGraph absoluteWithTensor:diffTensor name:nil]; // | xn - yn | < beta - MPSGraphTensor *diffAbsLessThanBetaTensor = [mpsGraph lessThanWithPrimaryTensor: diffAbsTensor - secondaryTensor: betaTensor - name: nil]; + MPSGraphTensor* diffAbsLessThanBetaTensor = [mpsGraph lessThanWithPrimaryTensor:diffAbsTensor + secondaryTensor:betaTensor + name:nil]; // ( xn - yn ) / beta - MPSGraphTensor *truePredicateTensor = [mpsGraph divisionWithPrimaryTensor: diffTensor - secondaryTensor: betaTensor - name: nil]; + MPSGraphTensor* truePredicateTensor = [mpsGraph divisionWithPrimaryTensor:diffTensor + secondaryTensor:betaTensor + name:nil]; // ( x - y ) / | x - y | - MPSGraphTensor *falsePredicateTensor = [mpsGraph divisionWithPrimaryTensor: diffTensor - secondaryTensor: diffAbsTensor - name: nil]; - - MPSGraphTensor *lossTensor = [mpsGraph selectWithPredicateTensor: diffAbsLessThanBetaTensor - truePredicateTensor: truePredicateTensor - falsePredicateTensor: falsePredicateTensor - name: @"lossTensor"]; - MPSGraphTensor *outputTensor = lossTensor; + MPSGraphTensor* falsePredicateTensor = [mpsGraph divisionWithPrimaryTensor:diffTensor + secondaryTensor:diffAbsTensor + name:nil]; + + MPSGraphTensor* lossTensor = [mpsGraph selectWithPredicateTensor:diffAbsLessThanBetaTensor + truePredicateTensor:truePredicateTensor + falsePredicateTensor:falsePredicateTensor + name:@"lossTensor"]; + MPSGraphTensor* outputTensor = lossTensor; if (reduction == Reduction::Mean) { - MPSGraphTensor *numelTensor = [mpsGraph constantWithScalar: (double) input.numel() - dataType: MPSDataTypeFloat32]; - outputTensor = [mpsGraph divisionWithPrimaryTensor: lossTensor - secondaryTensor: numelTensor - name: nil]; + MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel() + dataType:MPSDataTypeFloat32]; + outputTensor = [mpsGraph divisionWithPrimaryTensor:lossTensor secondaryTensor:numelTensor name:nil]; } - MPSGraphTensor *gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor: outputTensor - secondaryTensor: gradOutputTensor - name: nil]; + MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor + secondaryTensor:gradOutputTensor + name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->targetTensor_ = targetTensor; newCachedGraph->gradInputTensor_ = gradInputTensor; @@ -959,9 +887,8 @@ void smooth_l1_loss_backward_impl( targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData(), gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); } @@ -973,362 +900,343 @@ void smooth_l1_loss_backward_impl( // HuberLoss -Tensor& huber_loss_out_mps(const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& output){ - string op_name = __func__; - using namespace mps; - TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.") - TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes") - TORCH_CHECK(output.is_mps()); - - if(reduction == Reduction::None) - output.resize_(target.sizes()); - if(reduction == Reduction::Sum) - output.resize_({}); - if(reduction == Reduction::Mean) - output.resize_({}); - - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* inputTensor_ = nil; - MPSGraphTensor* targetTensor_ = nil; - MPSGraphTensor* outputTensor_ = nil; - }; - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); +Tensor& huber_loss_out_mps(const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& output) { + string op_name = __func__; + using namespace mps; + TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.") + TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes") + TORCH_CHECK(output.is_mps()); - @autoreleasepool { - string key = op_name + ":" + reductionToString(reduction) + ":" + std::to_string(delta) + ":" + getTensorsStringKey({input, target}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); - MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target); - - MPSDataType input_type = getMPSScalarType(input.scalar_type()); - MPSGraphTensor* deltaTensor = [mpsGraph constantWithScalar:delta - shape:@[@1] - dataType:input_type]; - MPSGraphTensor* halfTensor = [mpsGraph constantWithScalar:.5f - shape:@[@1] - dataType:input_type]; - - MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor: inputTensor - secondaryTensor: targetTensor - name: nil]; - MPSGraphTensor* absDiffTensor = [mpsGraph absoluteWithTensor: diffTensor - name: nil]; - MPSGraphTensor* firstCondTensor = [mpsGraph multiplicationWithPrimaryTensor: absDiffTensor - secondaryTensor: absDiffTensor - name: nil]; - firstCondTensor = [mpsGraph multiplicationWithPrimaryTensor: firstCondTensor - secondaryTensor: halfTensor - name: nil]; - MPSGraphTensor* secondCondTensor = [mpsGraph multiplicationWithPrimaryTensor: deltaTensor - secondaryTensor: halfTensor - name: nil]; - secondCondTensor = [mpsGraph subtractionWithPrimaryTensor: absDiffTensor - secondaryTensor: secondCondTensor - name: nil]; - secondCondTensor = [mpsGraph multiplicationWithPrimaryTensor: deltaTensor - secondaryTensor: secondCondTensor - name: nil]; - MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor: [mpsGraph lessThanOrEqualToWithPrimaryTensor:absDiffTensor - secondaryTensor:deltaTensor - name:nil] - truePredicateTensor: firstCondTensor - falsePredicateTensor: secondCondTensor - name:nil]; + if (reduction == Reduction::None) + output.resize_(target.sizes()); + if (reduction == Reduction::Sum) + output.resize_({}); + if (reduction == Reduction::Mean) + output.resize_({}); + + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* targetTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + }; + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->targetTensor_ = targetTensor; - newCachedGraph->outputTensor_ = reduceTensor(outputTensor, reduction, mpsGraph, input.sizes().size()); - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); + @autoreleasepool { + string key = op_name + ":" + reductionToString(reduction) + ":" + std::to_string(delta) + ":" + + getTensorsStringKey({input, target}); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); + MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target); + + MPSDataType input_type = getMPSScalarType(input.scalar_type()); + MPSGraphTensor* deltaTensor = [mpsGraph constantWithScalar:delta shape:@[ @1 ] dataType:input_type]; + MPSGraphTensor* halfTensor = [mpsGraph constantWithScalar:.5f shape:@[ @1 ] dataType:input_type]; + + MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor + secondaryTensor:targetTensor + name:nil]; + MPSGraphTensor* absDiffTensor = [mpsGraph absoluteWithTensor:diffTensor name:nil]; + MPSGraphTensor* firstCondTensor = [mpsGraph multiplicationWithPrimaryTensor:absDiffTensor + secondaryTensor:absDiffTensor + name:nil]; + firstCondTensor = [mpsGraph multiplicationWithPrimaryTensor:firstCondTensor + secondaryTensor:halfTensor + name:nil]; + MPSGraphTensor* secondCondTensor = [mpsGraph multiplicationWithPrimaryTensor:deltaTensor + secondaryTensor:halfTensor + name:nil]; + secondCondTensor = [mpsGraph subtractionWithPrimaryTensor:absDiffTensor + secondaryTensor:secondCondTensor + name:nil]; + secondCondTensor = [mpsGraph multiplicationWithPrimaryTensor:deltaTensor + secondaryTensor:secondCondTensor + name:nil]; + MPSGraphTensor* outputTensor = + [mpsGraph selectWithPredicateTensor:[mpsGraph lessThanOrEqualToWithPrimaryTensor:absDiffTensor + secondaryTensor:deltaTensor + name:nil] + truePredicateTensor:firstCondTensor + falsePredicateTensor:secondCondTensor + name:nil]; + + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->targetTensor_ = targetTensor; + newCachedGraph->outputTensor_ = reduceTensor(outputTensor, reduction, mpsGraph, input.sizes().size()); } - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); - Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor_, target); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - - // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - - runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); } - return output; + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); + Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor_, target); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); + + // Create dictionary of inputs and outputs + NSDictionary* feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData() + }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; + + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + } + return output; } Tensor huber_loss_mps(const Tensor& input, const Tensor& target, int64_t reduction, double delta) { TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta."); - Tensor output = at::native::empty_mps( - input.sizes(), - input.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + Tensor output = + at::native::empty_mps(input.sizes(), input.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); return huber_loss_out_mps(input, target, reduction, delta, output); } -Tensor& huber_loss_backward_out_mps( - const Tensor& grad_output, - const Tensor& input, - const Tensor& target, - int64_t reduction, - double delta, - Tensor& grad_input) -{ - using namespace mps; - auto is_mean_reduction = reduction == Reduction::Mean; - auto input_numel = input.numel(); - - auto new_grad_output = grad_output.contiguous(); - - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *targetTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; +Tensor& huber_loss_backward_out_mps(const Tensor& grad_output, + const Tensor& input, + const Tensor& target, + int64_t reduction, + double delta, + Tensor& grad_input) { + using namespace mps; + auto is_mean_reduction = reduction == Reduction::Mean; + auto input_numel = input.numel(); - MPSGraphCache *cache_ = MPSGraphCache::getInstance(); - - MPSStream *stream= getCurrentMPSStream(); - - @autoreleasepool { - MPSShape* input_shape = getMPSShape(input); - NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - - string key = "huber_loss_backward_out_mps:" + reductionToString(reduction) + ":" + - std::to_string(delta) + ":" + - [ns_shape_key UTF8String] + ":" + - getMPSTypeString(input) + ":" + - getMPSTypeString(target); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - // Initialize graph - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(new_grad_output), getMPSShape(new_grad_output)); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape); - MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(target), getMPSShape(target)); - MPSGraphTensor* isMeanReductionTensor = [mpsGraph constantWithScalar:is_mean_reduction - dataType:MPSDataTypeInt64]; // constant does not support MPSDataTypeBool - MPSGraphTensor* inputNumelTensor = [mpsGraph constantWithScalar:input_numel - dataType:getMPSDataType(new_grad_output)]; - - MPSGraphTensor* normGradOutputTensor = [mpsGraph selectWithPredicateTensor:isMeanReductionTensor - truePredicateTensor: [mpsGraph divisionWithPrimaryTensor:gradOutputTensor - secondaryTensor:inputNumelTensor - name:nil] - falsePredicateTensor: gradOutputTensor - name:nil]; - MPSGraphTensor* deltaTensor = [mpsGraph constantWithScalar:delta - shape:getMPSShape(target) - dataType:getMPSDataType(target)]; - MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor - secondaryTensor:targetTensor - name:nil]; - MPSGraphTensor* normGradOutputDeltaTensor = [mpsGraph multiplicationWithPrimaryTensor:normGradOutputTensor - secondaryTensor:deltaTensor - name:nil]; - // first condition: (input - target) <= -delta - // formula: -norm * grad_output * delta - MPSGraphTensor* firstCondTensor = [mpsGraph negativeWithTensor: normGradOutputDeltaTensor - name: nil]; - // second condition: (input - target) >= delta - // formula: norm * grad_output * delta - MPSGraphTensor* secondCondTensor = normGradOutputDeltaTensor; - - // third condition: (input - target) within -delta to delta - // formula: norm * (input - target) * grad_output - MPSGraphTensor* thirdCondTensor = [mpsGraph multiplicationWithPrimaryTensor:normGradOutputTensor - secondaryTensor:diffTensor - name:nil]; + auto new_grad_output = grad_output.contiguous(); - MPSGraphTensor* secondThirdTensor = [mpsGraph selectWithPredicateTensor: [mpsGraph greaterThanOrEqualToWithPrimaryTensor:diffTensor - secondaryTensor:deltaTensor - name:nil] - truePredicateTensor: secondCondTensor - falsePredicateTensor: thirdCondTensor - name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor: [mpsGraph lessThanOrEqualToWithPrimaryTensor: diffTensor - secondaryTensor:[mpsGraph negativeWithTensor: deltaTensor - name: nil] - name:nil] - truePredicateTensor: firstCondTensor - falsePredicateTensor: secondThirdTensor - name:nil]; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* targetTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + }; - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->targetTensor_ = targetTensor; - newCachedGraph->outputTensor_ = outputTensor; - } - return newCachedGraph; - })); - } + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + MPSStream* stream = getCurrentMPSStream(); - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, new_grad_output); - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); - Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor_, target); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, grad_input); - - NSDictionary* feeds = @{ - gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - - runMPSGraph(stream, cachedGraph->graph(), feeds, results); + @autoreleasepool { + MPSShape* input_shape = getMPSShape(input); + NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; + + string key = "huber_loss_backward_out_mps:" + reductionToString(reduction) + ":" + std::to_string(delta) + ":" + + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + + @autoreleasepool { + // Initialize graph + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + MPSGraphTensor* gradOutputTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(new_grad_output), getMPSShape(new_grad_output)); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape); + MPSGraphTensor* targetTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(target), getMPSShape(target)); + MPSGraphTensor* isMeanReductionTensor = + [mpsGraph constantWithScalar:is_mean_reduction + dataType:MPSDataTypeInt64]; // constant does not support MPSDataTypeBool + MPSGraphTensor* inputNumelTensor = [mpsGraph constantWithScalar:input_numel + dataType:getMPSDataType(new_grad_output)]; + + MPSGraphTensor* normGradOutputTensor = + [mpsGraph selectWithPredicateTensor:isMeanReductionTensor + truePredicateTensor:[mpsGraph divisionWithPrimaryTensor:gradOutputTensor + secondaryTensor:inputNumelTensor + name:nil] + falsePredicateTensor:gradOutputTensor + name:nil]; + MPSGraphTensor* deltaTensor = [mpsGraph constantWithScalar:delta + shape:getMPSShape(target) + dataType:getMPSDataType(target)]; + MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor + secondaryTensor:targetTensor + name:nil]; + MPSGraphTensor* normGradOutputDeltaTensor = [mpsGraph multiplicationWithPrimaryTensor:normGradOutputTensor + secondaryTensor:deltaTensor + name:nil]; + // first condition: (input - target) <= -delta + // formula: -norm * grad_output * delta + MPSGraphTensor* firstCondTensor = [mpsGraph negativeWithTensor:normGradOutputDeltaTensor name:nil]; + // second condition: (input - target) >= delta + // formula: norm * grad_output * delta + MPSGraphTensor* secondCondTensor = normGradOutputDeltaTensor; + + // third condition: (input - target) within -delta to delta + // formula: norm * (input - target) * grad_output + MPSGraphTensor* thirdCondTensor = [mpsGraph multiplicationWithPrimaryTensor:normGradOutputTensor + secondaryTensor:diffTensor + name:nil]; + + MPSGraphTensor* secondThirdTensor = + [mpsGraph selectWithPredicateTensor:[mpsGraph greaterThanOrEqualToWithPrimaryTensor:diffTensor + secondaryTensor:deltaTensor + name:nil] + truePredicateTensor:secondCondTensor + falsePredicateTensor:thirdCondTensor + name:nil]; + MPSGraphTensor* outputTensor = [mpsGraph + selectWithPredicateTensor:[mpsGraph + lessThanOrEqualToWithPrimaryTensor:diffTensor + secondaryTensor:[mpsGraph negativeWithTensor:deltaTensor + name:nil] + name:nil] + truePredicateTensor:firstCondTensor + falsePredicateTensor:secondThirdTensor + name:nil]; + + newCachedGraph->gradOutputTensor_ = gradOutputTensor; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->targetTensor_ = targetTensor; + newCachedGraph->outputTensor_ = outputTensor; + } + return newCachedGraph; + })); } - return grad_input; + + Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, new_grad_output); + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); + Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor_, target); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, grad_input); + + NSDictionary* feeds = @{ + gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData() + }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; + + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } + return grad_input; } // MSELoss -TORCH_IMPL_FUNC(mse_loss_out_mps) ( - const Tensor& input, const Tensor& target, int64_t reduction, const Tensor& output) { - string op_name = __func__; - using namespace mps; - TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes") - TORCH_CHECK(output.is_mps()); - - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* inputTensor = nil; - MPSGraphTensor* targetTensor = nil; - MPSGraphTensor* outputTensor = nil; - }; - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); +TORCH_IMPL_FUNC(mse_loss_out_mps)(const Tensor& input, const Tensor& target, int64_t reduction, const Tensor& output) { + string op_name = __func__; + using namespace mps; + TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes") + TORCH_CHECK(output.is_mps()); + + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor = nil; + MPSGraphTensor* targetTensor = nil; + MPSGraphTensor* outputTensor = nil; + }; + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - @autoreleasepool { - string key = op_name + reductionToString(reduction) + getTensorsStringKey({input, target}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); - newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target); - - MPSGraphTensor *diffTensor = [mpsGraph subtractionWithPrimaryTensor: newCachedGraph->inputTensor - secondaryTensor: newCachedGraph->targetTensor - name: nil]; - MPSGraphTensor *diffSquareTensor = [mpsGraph squareWithTensor: diffTensor - name: nil]; - newCachedGraph->outputTensor = reduceTensor(diffSquareTensor, reduction, mpsGraph, input.sizes().size()); - } - return newCachedGraph; - })); + @autoreleasepool { + string key = op_name + reductionToString(reduction) + getTensorsStringKey({input, target}); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); + newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target); + + MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:newCachedGraph->inputTensor + secondaryTensor:newCachedGraph->targetTensor + name:nil]; + MPSGraphTensor* diffSquareTensor = [mpsGraph squareWithTensor:diffTensor name:nil]; + newCachedGraph->outputTensor = reduceTensor(diffSquareTensor, reduction, mpsGraph, input.sizes().size()); } - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input); - Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output); - - NSDictionary* feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - - runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + return newCachedGraph; + })); } + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input); + Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output); + + NSDictionary* feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + targetPlaceholder.getMPSGraphTensor() : targetPlaceholder.getMPSGraphTensorData() + }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; + + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + } } -Tensor& mse_loss_backward_out_mps(const Tensor& grad_output, const Tensor& input, - const Tensor& target, int64_t reduction, Tensor& grad_input) -{ - return mps::mse_loss_backward_out_impl(grad_output, input, target, reduction, grad_input, __func__); +Tensor& mse_loss_backward_out_mps(const Tensor& grad_output, + const Tensor& input, + const Tensor& target, + int64_t reduction, + Tensor& grad_input) { + return mps::mse_loss_backward_out_impl(grad_output, input, target, reduction, grad_input, __func__); } -Tensor mse_loss_backward_mps(const Tensor& grad_output, const Tensor& input, - const Tensor& target, int64_t reduction) -{ - Tensor grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - return mps::mse_loss_backward_out_impl(grad_output, input, target, reduction, grad_input, __func__); +Tensor mse_loss_backward_mps(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) { + Tensor grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + return mps::mse_loss_backward_out_impl(grad_output, input, target, reduction, grad_input, __func__); } // BCELoss -Tensor& binary_cross_entropy_out_mps(const Tensor& input, const Tensor& target, - const c10::optional& weight_opt, int64_t reduction, Tensor& loss) -{ - return mps::BCELoss::bce_loss_out_impl(input, target, weight_opt, reduction, loss, c10::nullopt, __func__); +Tensor& binary_cross_entropy_out_mps(const Tensor& input, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction, + Tensor& loss) { + return mps::BCELoss::bce_loss_out_impl(input, target, weight_opt, reduction, loss, c10::nullopt, __func__); } -Tensor binary_cross_entropy_mps(const Tensor& input, const Tensor& target, - const c10::optional& weight_opt, int64_t reduction) -{ - Tensor loss = at::empty_like(input); - return mps::BCELoss::bce_loss_out_impl(input, target, weight_opt, reduction, loss, c10::nullopt, __func__); +Tensor binary_cross_entropy_mps(const Tensor& input, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction) { + Tensor loss = at::empty_like(input); + return mps::BCELoss::bce_loss_out_impl(input, target, weight_opt, reduction, loss, c10::nullopt, __func__); } -Tensor& binary_cross_entropy_backward_out_mps(const Tensor& grad_output, const Tensor& input, - const Tensor& target, const c10::optional& weight_opt, - int64_t reduction, Tensor& grad_input) -{ - return mps::BCELoss::bce_loss_out_impl(input, target, weight_opt, reduction, grad_input, grad_output, __func__); +Tensor& binary_cross_entropy_backward_out_mps(const Tensor& grad_output, + const Tensor& input, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction, + Tensor& grad_input) { + return mps::BCELoss::bce_loss_out_impl(input, target, weight_opt, reduction, grad_input, grad_output, __func__); } -Tensor binary_cross_entropy_backward_mps(const Tensor& grad_output, const Tensor& input, const Tensor& target, - const c10::optional& weight_opt, int64_t reduction) -{ - Tensor grad_input = at::empty_like(input); - return mps::BCELoss::bce_loss_out_impl(input, target, weight_opt, reduction, grad_input, grad_output, __func__); +Tensor binary_cross_entropy_backward_mps(const Tensor& grad_output, + const Tensor& input, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction) { + Tensor grad_input = at::empty_like(input); + return mps::BCELoss::bce_loss_out_impl(input, target, weight_opt, reduction, grad_input, grad_output, __func__); } // SmoothL1Loss -TORCH_IMPL_FUNC(smooth_l1_loss_out_mps)( - const Tensor& input, - const Tensor& target, - int64_t reduction, - double beta, - const Tensor& result) { - mps::smooth_l1_loss_template( - input, target, reduction, beta, result); +TORCH_IMPL_FUNC(smooth_l1_loss_out_mps) +(const Tensor& input, const Tensor& target, int64_t reduction, double beta, const Tensor& result) { + mps::smooth_l1_loss_template(input, target, reduction, beta, result); } -Tensor& smooth_l1_loss_backward_out_mps( - const Tensor& grad_output, - const Tensor& input, - const Tensor& target, - int64_t reduction, - double beta, - Tensor& grad_input) { - - mps::smooth_l1_loss_backward_impl( - grad_output, input, target, reduction, beta, grad_input); +Tensor& smooth_l1_loss_backward_out_mps(const Tensor& grad_output, + const Tensor& input, + const Tensor& target, + int64_t reduction, + double beta, + Tensor& grad_input) { + mps::smooth_l1_loss_backward_impl(grad_output, input, target, reduction, beta, grad_input); return grad_input; } @@ -1342,21 +1250,12 @@ Tensor binary_cross_entropy_backward_mps(const Tensor& grad_output, const Tensor int64_t reduction, int64_t ignore_index, const Tensor& total_weight, - const Tensor& grad_input -) -{ - const Tensor& weight = weight_opt.getTensorRef(); - - mps::nllnd_loss_backward_impl((Tensor &)grad_input, - grad_output, - self, - target, - weight, - reduction, - ignore_index, - total_weight, - false); - return; + const Tensor& grad_input) { + const Tensor& weight = weight_opt.getTensorRef(); + + mps::nllnd_loss_backward_impl( + (Tensor&)grad_input, grad_output, self, target, weight, reduction, ignore_index, total_weight, false); + return; } TORCH_IMPL_FUNC(nll_loss_forward_out_mps) @@ -1367,38 +1266,25 @@ Tensor binary_cross_entropy_backward_mps(const Tensor& grad_output, const Tensor int64_t ignore_index, const Tensor& output, const Tensor& total_weight) { + const Tensor& weight = weight_opt.getTensorRef(); - const Tensor& weight = weight_opt.getTensorRef(); - - mps::nllnd_loss_forward_impl((Tensor &)output, - (Tensor &)total_weight, - self, - target, - weight, - reduction, - ignore_index, - false); + mps::nllnd_loss_forward_impl( + (Tensor&)output, (Tensor&)total_weight, self, target, weight, reduction, ignore_index, false); - return; + return; } -inline void check_inputs_nll_loss2d( - const Tensor& input, - const Tensor& target, - const Tensor& weight) { - TORCH_CHECK( - target.dim() == 3, - "only batches of spatial targets supported (3D tensors)" - " but got targets of dimension: ", - target.dim()); - TORCH_CHECK( - input.dim() == 4, - "only batches of spatial inputs supported (4D tensors), " - "but got input of dimension: ", - input.dim()); - TORCH_CHECK( - !weight.defined() || weight.numel() == input.size(1), - "weight tensor should be defined either for all or no classes"); +inline void check_inputs_nll_loss2d(const Tensor& input, const Tensor& target, const Tensor& weight) { + TORCH_CHECK(target.dim() == 3, + "only batches of spatial targets supported (3D tensors)" + " but got targets of dimension: ", + target.dim()); + TORCH_CHECK(input.dim() == 4, + "only batches of spatial inputs supported (4D tensors), " + "but got input of dimension: ", + input.dim()); + TORCH_CHECK(!weight.defined() || weight.numel() == input.size(1), + "weight tensor should be defined either for all or no classes"); const int64_t input0 = input.size(0); const int64_t input2 = input.size(2); @@ -1406,150 +1292,114 @@ inline void check_inputs_nll_loss2d( const int64_t target0 = target.size(0); const int64_t target1 = target.size(1); const int64_t target2 = target.size(2); - TORCH_CHECK( - input0 == target0 && input2 == target1 && input3 == target2, - "size mismatch (got input: ", - input.sizes(), - " , target: ", - target.sizes()); + TORCH_CHECK(input0 == target0 && input2 == target1 && input3 == target2, + "size mismatch (got input: ", + input.sizes(), + " , target: ", + target.sizes()); } - -void nll_loss2d_forward_out_mps_template( - Tensor& output, - Tensor& total_weight, - const Tensor& input, - const Tensor& target, - const Tensor& weight, - int64_t reduction, - int64_t ignore_index) { +void nll_loss2d_forward_out_mps_template(Tensor& output, + Tensor& total_weight, + const Tensor& input, + const Tensor& target, + const Tensor& weight, + int64_t reduction, + int64_t ignore_index) { check_inputs_nll_loss2d(input, target, weight); total_weight.resize_({}); - mps::nllnd_loss_forward_impl(output, - total_weight, - input, - target, - weight, - reduction, - ignore_index, - true); + mps::nllnd_loss_forward_impl(output, total_weight, input, target, weight, reduction, ignore_index, true); - return; + return; } std::tuple nll_loss2d_forward_out_mps(const Tensor& self, - const Tensor& target, const c10::optional& weight_opt, - int64_t reduction, - int64_t ignore_index, - Tensor& output, - Tensor& total_weight) { + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction, + int64_t ignore_index, + Tensor& output, + Tensor& total_weight) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - nll_loss2d_forward_out_mps_template( - output, total_weight, self, target, weight, reduction, ignore_index); + nll_loss2d_forward_out_mps_template(output, total_weight, self, target, weight, reduction, ignore_index); return std::tuple(output, total_weight); } -std::tuple nll_loss2d_forward_mps( - const Tensor& self, - const Tensor& target, const c10::optional& weight_opt, - int64_t reduction, - int64_t ignore_index) { +std::tuple nll_loss2d_forward_mps(const Tensor& self, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction, + int64_t ignore_index) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; auto output = at::empty({0}, self.options()); auto total_weight = at::empty({0}, self.options()); - at::native::nll_loss2d_forward_out_mps( - self, target, weight, reduction, ignore_index, output, total_weight); + at::native::nll_loss2d_forward_out_mps(self, target, weight, reduction, ignore_index, output, total_weight); return std::make_tuple(output, total_weight); } -void nll_loss2d_backward_out_mps_template( - Tensor& grad_input, - const Tensor& grad_output, - const Tensor& input, - const Tensor& target, - const Tensor& weight, - int64_t reduction, - int64_t ignore_index, - const Tensor& total_weight) { +void nll_loss2d_backward_out_mps_template(Tensor& grad_input, + const Tensor& grad_output, + const Tensor& input, + const Tensor& target, + const Tensor& weight, + int64_t reduction, + int64_t ignore_index, + const Tensor& total_weight) { check_inputs_nll_loss2d(input, target, weight); grad_input.resize_as_(input); grad_input.zero_(); TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous"); - TORCH_CHECK( - total_weight.numel() == 1, - "expected total_weight to be a single element tensor, got: ", - total_weight.sizes(), - " (", - total_weight.numel(), - " elements)"); - - mps::nllnd_loss_backward_impl(grad_input, - grad_output, - input, - target, - weight, - reduction, - ignore_index, - total_weight, - true); + TORCH_CHECK(total_weight.numel() == 1, + "expected total_weight to be a single element tensor, got: ", + total_weight.sizes(), + " (", + total_weight.numel(), + " elements)"); - return; + mps::nllnd_loss_backward_impl( + grad_input, grad_output, input, target, weight, reduction, ignore_index, total_weight, true); + + return; } Tensor& nll_loss2d_backward_out_mps(const Tensor& grad_output, - const Tensor& self, - const Tensor& target, const c10::optional& weight_opt, - int64_t reduction, - int64_t ignore_index, - const Tensor& total_weight, - Tensor& grad_input) { + const Tensor& self, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction, + int64_t ignore_index, + const Tensor& total_weight, + Tensor& grad_input) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; nll_loss2d_backward_out_mps_template( - grad_input, - grad_output, - self, - target, - weight, - reduction, - ignore_index, - total_weight); + grad_input, grad_output, self, target, weight, reduction, ignore_index, total_weight); return grad_input; } -Tensor nll_loss2d_backward_mps( - const Tensor& grad_output, - const Tensor& self, - const Tensor& target, const c10::optional& weight_opt, - int64_t reduction, - int64_t ignore_index, - const Tensor& total_weight) { - +Tensor nll_loss2d_backward_mps(const Tensor& grad_output, + const Tensor& self, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction, + int64_t ignore_index, + const Tensor& total_weight) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; auto grad_input = at::zeros_like(self); - nll_loss2d_backward_out_mps( - grad_output, - self, - target, - weight, - reduction, - ignore_index, - total_weight, - grad_input); + nll_loss2d_backward_out_mps(grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); return grad_input; } - } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm index ef557933fbb058..7a1f61bee32ec2 100644 --- a/aten/src/ATen/native/mps/operations/Normalization.mm +++ b/aten/src/ATen/native/mps/operations/Normalization.mm @@ -2,76 +2,73 @@ #include #include -#include #include +#include #include -#include #include #include +#include #include namespace at::native { void get_shapes(MPSShape* input_shape_readonly, - NSMutableArray* &input_shape, - NSMutableArray* &new_mean_shape, - NSMutableArray* &axes, - int num_input_dims, c10::MemoryFormat memory_format, + NSMutableArray*& input_shape, + NSMutableArray*& new_mean_shape, + NSMutableArray*& axes, + int num_input_dims, + c10::MemoryFormat memory_format, bool isBackward) { // Modify the shape - if(memory_format == at::MemoryFormat::Contiguous) { - for(int i = 0; i < num_input_dims; i++) + if (memory_format == at::MemoryFormat::Contiguous) { + for (int i = 0; i < num_input_dims; i++) input_shape[i] = input_shape_readonly[i]; - } - else { // ChannelsLast + } else { // ChannelsLast auto num_channels = input_shape_readonly[1]; input_shape[0] = input_shape_readonly[0]; - for(int i = 1; i < num_input_dims-1; i++) - input_shape[i] = input_shape_readonly[i+1]; - input_shape[num_input_dims-1] = num_channels; + for (int i = 1; i < num_input_dims - 1; i++) + input_shape[i] = input_shape_readonly[i + 1]; + input_shape[num_input_dims - 1] = num_channels; } // Mean shape should remain unchanged in backward - if(memory_format == at::MemoryFormat::Contiguous || isBackward) { + if (memory_format == at::MemoryFormat::Contiguous || isBackward) { new_mean_shape[0] = @1; new_mean_shape[1] = input_shape_readonly[1]; - for(int i = 2; i < num_input_dims; i++) + for (int i = 2; i < num_input_dims; i++) new_mean_shape[i] = @1; - } - else if(memory_format == at::MemoryFormat::ChannelsLast) { - for(int i = 0; i < num_input_dims-1; i++) + } else if (memory_format == at::MemoryFormat::ChannelsLast) { + for (int i = 0; i < num_input_dims - 1; i++) new_mean_shape[i] = @1; - new_mean_shape[num_input_dims-1] = input_shape[num_input_dims-1]; + new_mean_shape[num_input_dims - 1] = input_shape[num_input_dims - 1]; } // Set axes of reduction - if(memory_format == at::MemoryFormat::Contiguous || isBackward) { - axes[0] = @0; - for(int i = 2; i < num_input_dims; i++) - axes[i-1] = [NSNumber numberWithInt:i]; - } - else { - for(int i = 0; i < num_input_dims-1; i++) - axes[i] = [NSNumber numberWithInt:i]; - } + if (memory_format == at::MemoryFormat::Contiguous || isBackward) { + axes[0] = @0; + for (int i = 2; i < num_input_dims; i++) + axes[i - 1] = [NSNumber numberWithInt:i]; + } else { + for (int i = 0; i < num_input_dims - 1; i++) + axes[i] = [NSNumber numberWithInt:i]; + } } // Inverse standard deviation now becomes variance (without epsilon) -std::tuple batch_norm_mps_out - (const Tensor& self, - const c10::optional& weight_opt, - const c10::optional& bias_opt, - const c10::optional& running_mean_opt, - const c10::optional& running_var_opt, - bool train, double momentum, double epsilon, - Tensor& output, - Tensor& save_mean, - Tensor& save_var) { - +std::tuple batch_norm_mps_out(const Tensor& self, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + const c10::optional& running_mean_opt, + const c10::optional& running_var_opt, + bool train, + double momentum, + double epsilon, + Tensor& output, + Tensor& save_mean, + Tensor& save_var) { namespace native_mps = at::native::mps; - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public native_mps::MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* weightTensor_ = nil; MPSGraphTensor* biasTensor_ = nil; @@ -98,13 +95,13 @@ void get_shapes(MPSShape* input_shape_readonly, auto memory_format = self.suggest_memory_format(); if (output.numel() == 0) { - return std::tuple(output, save_mean, save_var);; + return std::tuple(output, save_mean, save_var); + ; } @autoreleasepool { - string mem_format_key; - switch(memory_format) { + switch (memory_format) { case at::MemoryFormat::Contiguous: mem_format_key = "Contiguous"; break; @@ -124,25 +121,26 @@ void get_shapes(MPSShape* input_shape_readonly, // Shape which can be broadcasted with input NSMutableArray* new_mean_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; // Reduction axes - NSMutableArray* axes = [NSMutableArray arrayWithCapacity:(num_input_dims-1)]; + NSMutableArray* axes = [NSMutableArray arrayWithCapacity:(num_input_dims - 1)]; get_shapes(input_shape_readonly, input_shape, new_mean_shape, axes, num_input_dims, memory_format, false); NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = "batch_norm_mps_out:" + mem_format_key + ":" + std::to_string(epsilon) + ":" - + std::to_string(momentum) + ":" + std::to_string(train) + ":" - + std::to_string(has_running_mean) + ":" - + std::to_string(has_weight) + ":" + std::to_string(has_bias) + ":" - + [ns_shape_key UTF8String] + ":" - + native_mps::getTensorsStringKey({ - self, weight_opt.value_or(Tensor()), bias_opt.value_or(Tensor()), running_mean_opt.value_or(Tensor()), running_var_opt.value_or(Tensor())}); + string key = "batch_norm_mps_out:" + mem_format_key + ":" + std::to_string(epsilon) + ":" + + std::to_string(momentum) + ":" + std::to_string(train) + ":" + std::to_string(has_running_mean) + ":" + + std::to_string(has_weight) + ":" + std::to_string(has_bias) + ":" + [ns_shape_key UTF8String] + ":" + + native_mps::getTensorsStringKey({self, + weight_opt.value_or(Tensor()), + bias_opt.value_or(Tensor()), + running_mean_opt.value_or(Tensor()), + running_var_opt.value_or(Tensor())}); auto input_mps_dtype = native_mps::getMPSDataType(self); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); // Dim where channels are located int channelsDim; - if(memory_format == at::MemoryFormat::Contiguous) + if (memory_format == at::MemoryFormat::Contiguous) channelsDim = 1; else channelsDim = num_input_dims - 1; @@ -153,127 +151,121 @@ void get_shapes(MPSShape* input_shape_readonly, executeGatherOp = false; } - if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = native_mps::make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); + newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_mps_dtype, input_shape); - MPSGraphTensor* weightTensor = nil; - // Should have shape of mean - if(has_weight) - weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(weight_opt.value()), new_mean_shape); - MPSGraphTensor* biasTensor = nil; - if(has_bias) - biasTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(bias_opt.value()), new_mean_shape); - MPSGraphTensor* runningMeanTensor = nil; - MPSGraphTensor* runningVarTensor = nil; - if(has_running_mean) { - runningMeanTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(running_mean_opt.value()), new_mean_shape); - runningVarTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(running_var_opt.value()), new_mean_shape); - } + MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_mps_dtype, input_shape); + MPSGraphTensor* weightTensor = nil; + // Should have shape of mean + if (has_weight) + weightTensor = native_mps::mpsGraphRankedPlaceHolder( + mpsGraph, native_mps::getMPSDataType(weight_opt.value()), new_mean_shape); + MPSGraphTensor* biasTensor = nil; + if (has_bias) + biasTensor = native_mps::mpsGraphRankedPlaceHolder( + mpsGraph, native_mps::getMPSDataType(bias_opt.value()), new_mean_shape); + MPSGraphTensor* runningMeanTensor = nil; + MPSGraphTensor* runningVarTensor = nil; + if (has_running_mean) { + runningMeanTensor = native_mps::mpsGraphRankedPlaceHolder( + mpsGraph, native_mps::getMPSDataType(running_mean_opt.value()), new_mean_shape); + runningVarTensor = native_mps::mpsGraphRankedPlaceHolder( + mpsGraph, native_mps::getMPSDataType(running_var_opt.value()), new_mean_shape); + } - // Mean and inv std tensors to be saved and returned - MPSGraphTensor* saveMeanTensor = nil; - MPSGraphTensor* saveVarTensor = nil; + // Mean and inv std tensors to be saved and returned + MPSGraphTensor* saveMeanTensor = nil; + MPSGraphTensor* saveVarTensor = nil; - // Running stats inplace update - MPSGraphTensor* runningMeanInplaceUpdate = nil; - MPSGraphTensor* runningVarInplaceUpdate = nil; - - MPSGraphTensor* updatedRunningMeanTensor = nil; - MPSGraphTensor* updatedRunningVarTensor = nil; - MPSGraphTensor *scaledInverseSqrtVariance = nil; - - /* - If train: - If has_running_mean: - Update the running stats to be stored into save_mean and save_var, - AND to be used in current batchnorm computation - Else: - Just calculate the var using batch variance - If not train: - Check if running mean exists (maybe do this check before making graph) - Copy the running mean into the mean to be saved - Calculate the save_var directly from the running variance - - Compute the batch norm output and stats to be saved - */ - MPSGraphTensor *varTensor = nil; - - if(train) { - // Compute mean and variance of the current batch - MPSGraphTensor* batchMeanTensor = [mpsGraph meanOfTensor:inputTensor - axes:axes - name:nil]; - MPSGraphTensor* batchVarianceTensor = [mpsGraph varianceOfTensor:inputTensor - axes:axes - name:nil]; - varTensor = batchVarianceTensor; - if(has_running_mean) { - // TODO: This is not the formula used in PyTorch, is this OK? Seems more robust - // float besselCorrectionTerm = float(N) / std::max(N - 1.0f, 1.0f); - float besselCorrectionTerm = float(N) / float(N - 1.0f); - MPSGraphTensor* besselConstantTensor = [mpsGraph constantWithScalar:(double)besselCorrectionTerm - shape:@[@1] - dataType:input_mps_dtype]; - MPSGraphTensor* unbiasedVarianceTensor = [mpsGraph multiplicationWithPrimaryTensor:batchVarianceTensor - secondaryTensor:besselConstantTensor - name:nil]; - MPSGraphTensor* momentumTensor = [mpsGraph constantWithScalar:(double)momentum - shape:@[@1] + // Running stats inplace update + MPSGraphTensor* runningMeanInplaceUpdate = nil; + MPSGraphTensor* runningVarInplaceUpdate = nil; + + MPSGraphTensor* updatedRunningMeanTensor = nil; + MPSGraphTensor* updatedRunningVarTensor = nil; + MPSGraphTensor* scaledInverseSqrtVariance = nil; + + /* + If train: + If has_running_mean: + Update the running stats to be stored into save_mean and save_var, + AND to be used in current batchnorm computation + Else: + Just calculate the var using batch variance + If not train: + Check if running mean exists (maybe do this check before making graph) + Copy the running mean into the mean to be saved + Calculate the save_var directly from the running variance + + Compute the batch norm output and stats to be saved + */ + MPSGraphTensor* varTensor = nil; + + if (train) { + // Compute mean and variance of the current batch + MPSGraphTensor* batchMeanTensor = [mpsGraph meanOfTensor:inputTensor axes:axes name:nil]; + MPSGraphTensor* batchVarianceTensor = [mpsGraph varianceOfTensor:inputTensor axes:axes name:nil]; + varTensor = batchVarianceTensor; + if (has_running_mean) { + // TODO: This is not the formula used in PyTorch, is this OK? Seems more robust + // float besselCorrectionTerm = float(N) / std::max(N - 1.0f, 1.0f); + float besselCorrectionTerm = float(N) / float(N - 1.0f); + MPSGraphTensor* besselConstantTensor = [mpsGraph constantWithScalar:(double)besselCorrectionTerm + shape:@[ @1 ] + dataType:input_mps_dtype]; + MPSGraphTensor* unbiasedVarianceTensor = [mpsGraph multiplicationWithPrimaryTensor:batchVarianceTensor + secondaryTensor:besselConstantTensor + name:nil]; + MPSGraphTensor* momentumTensor = [mpsGraph constantWithScalar:(double)momentum + shape:@[ @1 ] + dataType:input_mps_dtype]; + MPSGraphTensor* oneMinusMomentum = [mpsGraph constantWithScalar:(double)(1.0 - momentum) + shape:@[ @1 ] dataType:input_mps_dtype]; - MPSGraphTensor* oneMinusMomentum = [mpsGraph constantWithScalar:(double)(1.0 - momentum) - shape:@[@1] - dataType:input_mps_dtype]; - // Compute updated running mean - MPSGraphTensor* scaledBatchMean = [mpsGraph multiplicationWithPrimaryTensor:batchMeanTensor - secondaryTensor:momentumTensor + // Compute updated running mean + MPSGraphTensor* scaledBatchMean = [mpsGraph multiplicationWithPrimaryTensor:batchMeanTensor + secondaryTensor:momentumTensor + name:nil]; + MPSGraphTensor* scaledRunningMean = [mpsGraph multiplicationWithPrimaryTensor:runningMeanTensor + secondaryTensor:oneMinusMomentum name:nil]; - MPSGraphTensor* scaledRunningMean = [mpsGraph multiplicationWithPrimaryTensor:runningMeanTensor - secondaryTensor:oneMinusMomentum - name:nil]; - updatedRunningMeanTensor = [mpsGraph additionWithPrimaryTensor:scaledBatchMean - secondaryTensor:scaledRunningMean - name:nil]; - // Compute updated running var - MPSGraphTensor* scaledCorrectedBatchVar = [mpsGraph multiplicationWithPrimaryTensor:unbiasedVarianceTensor - secondaryTensor:momentumTensor - name:nil]; - MPSGraphTensor* scaledRunningVar = [mpsGraph multiplicationWithPrimaryTensor:runningVarTensor - secondaryTensor:oneMinusMomentum - name:nil]; - updatedRunningVarTensor = [mpsGraph additionWithPrimaryTensor:scaledCorrectedBatchVar - secondaryTensor:scaledRunningVar - name:nil]; + updatedRunningMeanTensor = [mpsGraph additionWithPrimaryTensor:scaledBatchMean + secondaryTensor:scaledRunningMean + name:nil]; + // Compute updated running var + MPSGraphTensor* scaledCorrectedBatchVar = [mpsGraph multiplicationWithPrimaryTensor:unbiasedVarianceTensor + secondaryTensor:momentumTensor + name:nil]; + MPSGraphTensor* scaledRunningVar = [mpsGraph multiplicationWithPrimaryTensor:runningVarTensor + secondaryTensor:oneMinusMomentum + name:nil]; + updatedRunningVarTensor = [mpsGraph additionWithPrimaryTensor:scaledCorrectedBatchVar + secondaryTensor:scaledRunningVar + name:nil]; } // Update saved mean and inverse std tensor - MPSGraphTensor *epsilonTensor = [mpsGraph constantWithScalar:(double)epsilon - shape:@[@1] + MPSGraphTensor* epsilonTensor = [mpsGraph constantWithScalar:(double)epsilon + shape:@[ @1 ] dataType:input_mps_dtype]; - MPSGraphTensor *varianceEps = [mpsGraph additionWithPrimaryTensor:batchVarianceTensor + MPSGraphTensor* varianceEps = [mpsGraph additionWithPrimaryTensor:batchVarianceTensor secondaryTensor:epsilonTensor name:@"varianceEps"]; - MPSGraphTensor *sqrtVariance = [mpsGraph squareRootWithTensor:varianceEps - name:@"sqrtVariance"]; - scaledInverseSqrtVariance = [mpsGraph reciprocalWithTensor:sqrtVariance - name:nil]; + MPSGraphTensor* sqrtVariance = [mpsGraph squareRootWithTensor:varianceEps name:@"sqrtVariance"]; + scaledInverseSqrtVariance = [mpsGraph reciprocalWithTensor:sqrtVariance name:nil]; // Update saved mean and inverse std tensor saveMeanTensor = batchMeanTensor; saveVarTensor = scaledInverseSqrtVariance; - } - else { // Test + } else { // Test TORCH_CHECK(has_running_mean); - saveMeanTensor = [mpsGraph identityWithTensor:runningMeanTensor - name:nil]; - saveVarTensor = [mpsGraph identityWithTensor:runningVarTensor - name:nil]; + saveMeanTensor = [mpsGraph identityWithTensor:runningMeanTensor name:nil]; + saveVarTensor = [mpsGraph identityWithTensor:runningVarTensor name:nil]; varTensor = saveVarTensor; } @@ -287,20 +279,16 @@ Check if running mean exists (maybe do this check before making graph) name:nil]; // Reshape saved mean and var to fit output - saveMeanTensor = [mpsGraph reshapeTensor:saveMeanTensor - withShape:@[new_mean_shape[channelsDim]] - name:nil]; - saveVarTensor = [mpsGraph reshapeTensor:saveVarTensor - withShape:@[new_mean_shape[channelsDim]] - name:nil]; + saveMeanTensor = [mpsGraph reshapeTensor:saveMeanTensor withShape:@[ new_mean_shape[channelsDim] ] name:nil]; + saveVarTensor = [mpsGraph reshapeTensor:saveVarTensor withShape:@[ new_mean_shape[channelsDim] ] name:nil]; - if(train && has_running_mean) { + if (train && has_running_mean) { // Running stats inplace update runningMeanInplaceUpdate = [mpsGraph reshapeTensor:updatedRunningMeanTensor - withShape:@[input_shape[channelsDim]] + withShape:@[ input_shape[channelsDim] ] name:nil]; runningVarInplaceUpdate = [mpsGraph reshapeTensor:updatedRunningVarTensor - withShape:@[input_shape[channelsDim]] + withShape:@[ input_shape[channelsDim] ] name:nil]; } @@ -317,175 +305,170 @@ Check if running mean exists (maybe do this check before making graph) } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, self, input_shape, executeGatherOp); auto weightPlaceholder = native_mps::Placeholder(); - if(has_weight) + if (has_weight) weightPlaceholder = native_mps::Placeholder(cachedGraph->weightTensor_, weight_opt.value(), new_mean_shape); auto biasPlaceholder = native_mps::Placeholder(); - if(has_bias) + if (has_bias) biasPlaceholder = native_mps::Placeholder(cachedGraph->biasTensor_, bias_opt.value(), new_mean_shape); auto runningMeanPlaceholder = native_mps::Placeholder(); auto runningVarPlaceholder = native_mps::Placeholder(); - if(has_running_mean) { - runningMeanPlaceholder = native_mps::Placeholder(cachedGraph->runningMeanTensor_, running_mean_opt.value(), new_mean_shape); - runningVarPlaceholder = native_mps::Placeholder(cachedGraph->runningVarTensor_, running_var_opt.value(), new_mean_shape); + if (has_running_mean) { + runningMeanPlaceholder = + native_mps::Placeholder(cachedGraph->runningMeanTensor_, running_mean_opt.value(), new_mean_shape); + runningVarPlaceholder = + native_mps::Placeholder(cachedGraph->runningVarTensor_, running_var_opt.value(), new_mean_shape); } auto runningMeanInplaceUpdatePlaceholder = native_mps::Placeholder(); auto runningVarInplaceUpdatePlaceholder = native_mps::Placeholder(); - if(train && has_running_mean) { - runningMeanInplaceUpdatePlaceholder = native_mps::Placeholder(cachedGraph->runningMeanInplaceUpdate_, running_mean_opt.value()); - runningVarInplaceUpdatePlaceholder = native_mps::Placeholder(cachedGraph->runningVarInplaceUpdate_, running_var_opt.value()); + if (train && has_running_mean) { + runningMeanInplaceUpdatePlaceholder = + native_mps::Placeholder(cachedGraph->runningMeanInplaceUpdate_, running_mean_opt.value()); + runningVarInplaceUpdatePlaceholder = + native_mps::Placeholder(cachedGraph->runningVarInplaceUpdate_, running_var_opt.value()); } auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output, input_shape, false); auto saveMeanPlaceholder = native_mps::Placeholder(cachedGraph->saveMeanTensor_, save_mean); auto saveVarPlaceholder = native_mps::Placeholder(cachedGraph->saveVarTensor_, save_var); - NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; + NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); - if(has_weight) + if (has_weight) feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData(); - if(has_bias) + if (has_bias) feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData(); - if(has_running_mean) { + if (has_running_mean) { feeds[runningMeanPlaceholder.getMPSGraphTensor()] = runningMeanPlaceholder.getMPSGraphTensorData(); feeds[runningVarPlaceholder.getMPSGraphTensor()] = runningVarPlaceholder.getMPSGraphTensorData(); } - NSMutableDictionary *results = [[NSMutableDictionary new] autorelease]; + NSMutableDictionary* results = [[NSMutableDictionary new] autorelease]; results[outputPlaceholder.getMPSGraphTensor()] = outputPlaceholder.getMPSGraphTensorData(); results[saveMeanPlaceholder.getMPSGraphTensor()] = saveMeanPlaceholder.getMPSGraphTensorData(); results[saveVarPlaceholder.getMPSGraphTensor()] = saveVarPlaceholder.getMPSGraphTensorData(); // If train and has_running_mean, add updated running mean to the output - if(train && has_running_mean) { - results[runningMeanInplaceUpdatePlaceholder.getMPSGraphTensor()] = runningMeanInplaceUpdatePlaceholder.getMPSGraphTensorData(); - results[runningVarInplaceUpdatePlaceholder.getMPSGraphTensor()] = runningVarInplaceUpdatePlaceholder.getMPSGraphTensorData(); + if (train && has_running_mean) { + results[runningMeanInplaceUpdatePlaceholder.getMPSGraphTensor()] = + runningMeanInplaceUpdatePlaceholder.getMPSGraphTensorData(); + results[runningVarInplaceUpdatePlaceholder.getMPSGraphTensor()] = + runningVarInplaceUpdatePlaceholder.getMPSGraphTensorData(); } native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } - if(!train) { + if (!train) { save_mean.resize_({0}); save_var.resize_({0}); } return std::tuple(output, save_mean, save_var); } -std::tuple batch_norm_mps - (const Tensor& self, - const c10::optional& weight_opt, - const c10::optional& bias_opt, - const c10::optional& running_mean_opt, - const c10::optional& running_var_opt, - bool train, - double momentum, - double epsilon) { - +std::tuple batch_norm_mps(const Tensor& self, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + const c10::optional& running_mean_opt, + const c10::optional& running_var_opt, + bool train, + double momentum, + double epsilon) { const auto memory_format = self.suggest_memory_format(); - auto output = at::native::empty_mps( - self.sizes(), - self.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - memory_format); + auto output = + at::native::empty_mps(self.sizes(), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, memory_format); int64_t n_input = self.size(1); - auto save_mean = at::native::empty_mps( - {n_input}, - self.scalar_type(), - // TODO: Accumulate type? - // at::toAccumulateType(self.scalar_type(), /*is_cuda=*/false), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); - auto save_var = at::native::empty_mps( - {n_input}, - self.scalar_type(), - // TODO: Accumulate type? - // at::toAccumulateType(self.scalar_type(), /*is_cuda=*/false), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); - - at::native::batch_norm_mps_out( - self, - weight_opt, - bias_opt, - running_mean_opt, - running_var_opt, - train, - momentum, - epsilon, - output, - save_mean, - save_var); + auto save_mean = at::native::empty_mps({n_input}, + self.scalar_type(), + // TODO: Accumulate type? + // at::toAccumulateType(self.scalar_type(), /*is_cuda=*/false), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + auto save_var = at::native::empty_mps({n_input}, + self.scalar_type(), + // TODO: Accumulate type? + // at::toAccumulateType(self.scalar_type(), /*is_cuda=*/false), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + + at::native::batch_norm_mps_out(self, + weight_opt, + bias_opt, + running_mean_opt, + running_var_opt, + train, + momentum, + epsilon, + output, + save_mean, + save_var); return std::make_tuple(output, save_mean, save_var); } -std::tuple _batch_norm_legit_mps - (const Tensor& self, - const c10::optional& weight_opt, - const c10::optional& bias_opt, - Tensor& running_mean, - Tensor& running_var, - bool train, - double momentum, - double epsilon) { - +std::tuple _batch_norm_legit_mps(const Tensor& self, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + Tensor& running_mean, + Tensor& running_var, + bool train, + double momentum, + double epsilon) { return batch_norm_mps(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon); } -std::tuple _batch_norm_legit_no_stats_mps - (const Tensor& self, - const c10::optional& weight_opt, - const c10::optional& bias_opt, - bool train, - double momentum, - double epsilon) { - +std::tuple _batch_norm_legit_no_stats_mps(const Tensor& self, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + bool train, + double momentum, + double epsilon) { return batch_norm_mps(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon); } -std::tuple _batch_norm_legit_mps_out - (const Tensor& self, - const c10::optional& weight_opt, - const c10::optional& bias_opt, - Tensor& running_mean, - Tensor& running_var, - bool train, double momentum, double epsilon, - Tensor& output, - Tensor& save_mean, - Tensor& save_var) { - return batch_norm_mps_out(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon, output, save_mean, save_var); +std::tuple _batch_norm_legit_mps_out(const Tensor& self, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + Tensor& running_mean, + Tensor& running_var, + bool train, + double momentum, + double epsilon, + Tensor& output, + Tensor& save_mean, + Tensor& save_var) { + return batch_norm_mps_out( + self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon, output, save_mean, save_var); } -std::tuple _batch_norm_legit_no_stats_mps_out - (const Tensor& self, - const c10::optional& weight_opt, - const c10::optional& bias_opt, - bool train, double momentum, double epsilon, - Tensor& output, - Tensor& save_mean, - Tensor& save_var) { - return batch_norm_mps_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_var); +std::tuple _batch_norm_legit_no_stats_mps_out(const Tensor& self, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + bool train, + double momentum, + double epsilon, + Tensor& output, + Tensor& save_mean, + Tensor& save_var) { + return batch_norm_mps_out( + self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_var); } string get_mem_string(c10::MemoryFormat memory_format) { string mem_format_key; - switch(memory_format) { + switch (memory_format) { case at::MemoryFormat::Contiguous: mem_format_key = "Contiguous"; break; @@ -500,18 +483,16 @@ string get_mem_string(c10::MemoryFormat memory_format) { } // Batch norm backward -std::tuple batch_norm_backward_mps - (const Tensor& grad_out, - const Tensor& input, - const c10::optional& weight_opt, - const c10::optional& running_mean_opt, - const c10::optional& running_var_opt, - const c10::optional& save_mean_opt, - const c10::optional& save_var_opt, - bool train, - double epsilon, - std::array grad_input_mask) { - +std::tuple batch_norm_backward_mps(const Tensor& grad_out, + const Tensor& input, + const c10::optional& weight_opt, + const c10::optional& running_mean_opt, + const c10::optional& running_var_opt, + const c10::optional& save_mean_opt, + const c10::optional& save_var_opt, + bool train, + double epsilon, + std::array grad_input_mask) { Tensor grad_input; Tensor grad_weight; Tensor grad_bias; @@ -519,12 +500,8 @@ string get_mem_string(c10::MemoryFormat memory_format) { const auto memory_format = input.suggest_memory_format(); if (grad_input_mask[0]) { - grad_input = at::native::empty_mps(input.sizes(), - input.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - memory_format); + grad_input = + at::native::empty_mps(input.sizes(), input.scalar_type(), c10::nullopt, kMPS, c10::nullopt, memory_format); } // Assuming that if grad_input_mask of weight is 1, then the weight is available if (grad_input_mask[1]) { @@ -547,9 +524,8 @@ string get_mem_string(c10::MemoryFormat memory_format) { namespace native_mps = at::native::mps; // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public native_mps::MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* gradOutputTensor_ = nil; MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* weightTensor_ = nil; @@ -580,9 +556,8 @@ string get_mem_string(c10::MemoryFormat memory_format) { } @autoreleasepool { - string mem_format_key; - switch(memory_format) { + switch (memory_format) { case at::MemoryFormat::Contiguous: mem_format_key = "Contiguous"; break; @@ -599,24 +574,21 @@ string get_mem_string(c10::MemoryFormat memory_format) { // Broadcast with input NSMutableArray* new_mean_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; // Reduction axes - NSMutableArray* axes = [NSMutableArray arrayWithCapacity:(num_input_dims-1)]; + NSMutableArray* axes = [NSMutableArray arrayWithCapacity:(num_input_dims - 1)]; get_shapes(input_shape_readonly, input_shape, new_mean_shape, axes, num_input_dims, memory_format, true); NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = "batch_norm_backward_mps:" + mem_format_key + ":" + std::to_string(epsilon) + ":" - + std::to_string(train) + ":" - + std::to_string(has_running_mean) + ":" - + std::to_string(has_weight) + ":" - + [ns_shape_key UTF8String] + ":" + native_mps::getMPSTypeString(input); + string key = "batch_norm_backward_mps:" + mem_format_key + ":" + std::to_string(epsilon) + ":" + + std::to_string(train) + ":" + std::to_string(has_running_mean) + ":" + std::to_string(has_weight) + ":" + + [ns_shape_key UTF8String] + ":" + native_mps::getMPSTypeString(input); auto input_mps_dtype = native_mps::getMPSDataType(input); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = native_mps::make_mps_graph(); @@ -625,25 +597,32 @@ string get_mem_string(c10::MemoryFormat memory_format) { // NCHW - Channels dim is 1 int channelsDim = 1; - MPSGraphTensor* inputTensorOriginal = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_mps_dtype, input_shape); + MPSGraphTensor* inputTensorOriginal = + native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_mps_dtype, input_shape); // Shape is the ORIGINAL NCHW shape - MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(grad_out), input_shape_readonly); + MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder( + mpsGraph, native_mps::getMPSDataType(grad_out), input_shape_readonly); MPSGraphTensor* weightTensor = nil; - if(has_weight) - weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(weight_opt.value()), new_mean_shape); + if (has_weight) + weightTensor = native_mps::mpsGraphRankedPlaceHolder( + mpsGraph, native_mps::getMPSDataType(weight_opt.value()), new_mean_shape); MPSGraphTensor* runningMeanTensor = nil; MPSGraphTensor* runningVarTensor = nil; - if(has_running_mean) { - runningMeanTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(running_mean_opt.value()), new_mean_shape); - runningVarTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(running_var_opt.value()), new_mean_shape); + if (has_running_mean) { + runningMeanTensor = native_mps::mpsGraphRankedPlaceHolder( + mpsGraph, native_mps::getMPSDataType(running_mean_opt.value()), new_mean_shape); + runningVarTensor = native_mps::mpsGraphRankedPlaceHolder( + mpsGraph, native_mps::getMPSDataType(running_var_opt.value()), new_mean_shape); } // Mean and inv std tensors to be saved and returned MPSGraphTensor* saveMeanTensor = nil; MPSGraphTensor* saveVarTensor = nil; - if(has_save_mean) { - saveMeanTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(save_mean_opt.value()), new_mean_shape); - saveVarTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(save_var_opt.value()), new_mean_shape); + if (has_save_mean) { + saveMeanTensor = native_mps::mpsGraphRankedPlaceHolder( + mpsGraph, native_mps::getMPSDataType(save_mean_opt.value()), new_mean_shape); + saveVarTensor = native_mps::mpsGraphRankedPlaceHolder( + mpsGraph, native_mps::getMPSDataType(save_var_opt.value()), new_mean_shape); } MPSGraphTensor* gradInputTensor = nil; @@ -651,7 +630,7 @@ string get_mem_string(c10::MemoryFormat memory_format) { MPSGraphTensor* gradBiasTensor = nil; MPSGraphTensor* inputTensor = nil; - if(memory_format == at::MemoryFormat::Contiguous) + if (memory_format == at::MemoryFormat::Contiguous) inputTensor = inputTensorOriginal; else { // Reshape/transpose the input as needed @@ -661,30 +640,24 @@ string get_mem_string(c10::MemoryFormat memory_format) { auto C = input_shape[3]; inputTensor = [mpsGraph reshapeTensor:inputTensorOriginal - withShape:@[N, ([NSNumber numberWithInt:[H intValue]* [W intValue]]), C] - name:nil]; - inputTensor = [mpsGraph transposeTensor:inputTensor - dimension:1 - withDimension:2 - name:nil]; - inputTensor = [mpsGraph reshapeTensor:inputTensor - withShape:@[N, C, H, W] + withShape:@[ N, ([NSNumber numberWithInt:[H intValue] * [W intValue]]), C ] name:nil]; + inputTensor = [mpsGraph transposeTensor:inputTensor dimension:1 withDimension:2 name:nil]; + inputTensor = [mpsGraph reshapeTensor:inputTensor withShape:@[ N, C, H, W ] name:nil]; } - if(train) { + if (train) { // Use save_mean and save_var - MPSGraphTensor *epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon dataType:input_mps_dtype]; - MPSGraphTensor *revertSaveVarTensor = saveVarTensor; - revertSaveVarTensor = [mpsGraph reciprocalWithTensor: revertSaveVarTensor - name: nil]; - revertSaveVarTensor = [mpsGraph multiplicationWithPrimaryTensor: revertSaveVarTensor - secondaryTensor: revertSaveVarTensor - name: nil]; - revertSaveVarTensor = [mpsGraph subtractionWithPrimaryTensor: revertSaveVarTensor - secondaryTensor: epsilonTensor - name: nil]; - if(grad_input_mask[1]) { + MPSGraphTensor* epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon dataType:input_mps_dtype]; + MPSGraphTensor* revertSaveVarTensor = saveVarTensor; + revertSaveVarTensor = [mpsGraph reciprocalWithTensor:revertSaveVarTensor name:nil]; + revertSaveVarTensor = [mpsGraph multiplicationWithPrimaryTensor:revertSaveVarTensor + secondaryTensor:revertSaveVarTensor + name:nil]; + revertSaveVarTensor = [mpsGraph subtractionWithPrimaryTensor:revertSaveVarTensor + secondaryTensor:epsilonTensor + name:nil]; + if (grad_input_mask[1]) { gradWeightTensor = [mpsGraph normalizationGammaGradientWithIncomingGradientTensor:gradOutputTensor sourceTensor:inputTensor meanTensor:saveMeanTensor @@ -693,13 +666,13 @@ string get_mem_string(c10::MemoryFormat memory_format) { epsilon:(float)epsilon name:nil]; } - if(grad_input_mask[2]) { + if (grad_input_mask[2]) { gradBiasTensor = [mpsGraph normalizationBetaGradientWithIncomingGradientTensor:gradOutputTensor sourceTensor:inputTensor reductionAxes:axes name:nil]; } - if(grad_input_mask[0]) { + if (grad_input_mask[0]) { gradInputTensor = [mpsGraph normalizationGradientWithIncomingGradientTensor:gradOutputTensor sourceTensor:inputTensor meanTensor:saveMeanTensor @@ -708,63 +681,53 @@ string get_mem_string(c10::MemoryFormat memory_format) { gammaGradientTensor:gradWeightTensor betaGradientTensor:gradBiasTensor reductionAxes:axes - epsilon:(float) epsilon + epsilon:(float)epsilon name:nil]; } - } - else { + } else { // Use running mean and running var MPSGraphTensor* rsqrtTensor = nil; MPSGraphTensor* epsilonTensor = nil; - if(grad_input_mask[1]) { - epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon - shape:@[@1] - dataType:input_mps_dtype]; + if (grad_input_mask[1]) { + epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon shape:@[ @1 ] dataType:input_mps_dtype]; MPSGraphTensor* xMinusMean = [mpsGraph subtractionWithPrimaryTensor:inputTensor secondaryTensor:runningMeanTensor name:nil]; MPSGraphTensor* varianceEpsTensor = [mpsGraph additionWithPrimaryTensor:runningVarTensor secondaryTensor:epsilonTensor name:nil]; - rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor - name:nil]; + rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor name:nil]; MPSGraphTensor* bnForwardTensor = [mpsGraph multiplicationWithPrimaryTensor:xMinusMean secondaryTensor:rsqrtTensor name:nil]; MPSGraphTensor* gradBnMulTensor = [mpsGraph multiplicationWithPrimaryTensor:bnForwardTensor secondaryTensor:gradOutputTensor name:nil]; - gradWeightTensor = [mpsGraph reductionSumWithTensor:gradBnMulTensor - axes:axes - name:nil]; + gradWeightTensor = [mpsGraph reductionSumWithTensor:gradBnMulTensor axes:axes name:nil]; } - if(grad_input_mask[2]) { + if (grad_input_mask[2]) { gradBiasTensor = [mpsGraph normalizationBetaGradientWithIncomingGradientTensor:gradOutputTensor sourceTensor:inputTensor reductionAxes:axes name:nil]; } - if(grad_input_mask[0]) { - + if (grad_input_mask[0]) { MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 shape:input_shape_readonly dataType:input_mps_dtype]; - if(!epsilonTensor) - epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon - shape:@[@1] - dataType:input_mps_dtype]; - if(!rsqrtTensor) { + if (!epsilonTensor) + epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon shape:@[ @1 ] dataType:input_mps_dtype]; + if (!rsqrtTensor) { MPSGraphTensor* varianceEpsTensor = [mpsGraph additionWithPrimaryTensor:runningVarTensor - secondaryTensor:epsilonTensor - name:nil]; - rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor - name:nil]; + secondaryTensor:epsilonTensor + name:nil]; + rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor name:nil]; } gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:unitTensor secondaryTensor:rsqrtTensor name:nil]; - if(has_weight) + if (has_weight) gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:gradInputTensor secondaryTensor:weightTensor name:nil]; @@ -774,20 +737,20 @@ string get_mem_string(c10::MemoryFormat memory_format) { } } - if(grad_input_mask[1]) { + if (grad_input_mask[1]) { gradWeightTensor = [mpsGraph reshapeTensor:gradWeightTensor - withShape:@[input_shape_readonly[channelsDim]] + withShape:@[ input_shape_readonly[channelsDim] ] name:nil]; } - if(grad_input_mask[2]) { + if (grad_input_mask[2]) { gradBiasTensor = [mpsGraph reshapeTensor:gradBiasTensor - withShape:@[input_shape_readonly[channelsDim]] + withShape:@[ input_shape_readonly[channelsDim] ] name:nil]; } MPSGraphTensor* gradInputTensorFinal = nil; - if(memory_format == at::MemoryFormat::Contiguous) + if (memory_format == at::MemoryFormat::Contiguous) gradInputTensorFinal = gradInputTensor; else { // Reshape/transpose the input as needed @@ -796,16 +759,12 @@ string get_mem_string(c10::MemoryFormat memory_format) { auto W = input_shape[2]; auto C = input_shape[3]; - gradInputTensorFinal = [mpsGraph reshapeTensor:gradInputTensor - withShape:@[N, C, ([NSNumber numberWithInt:[H intValue]* [W intValue]])] - name:nil]; - gradInputTensorFinal = [mpsGraph transposeTensor:gradInputTensorFinal - dimension:1 - withDimension:2 - name:nil]; - gradInputTensorFinal = [mpsGraph reshapeTensor:gradInputTensorFinal - withShape:@[N, H, W, C] - name:nil]; + gradInputTensorFinal = + [mpsGraph reshapeTensor:gradInputTensor + withShape:@[ N, C, ([NSNumber numberWithInt:[H intValue] * [W intValue]]) ] + name:nil]; + gradInputTensorFinal = [mpsGraph transposeTensor:gradInputTensorFinal dimension:1 withDimension:2 name:nil]; + gradInputTensorFinal = [mpsGraph reshapeTensor:gradInputTensorFinal withShape:@[ N, H, W, C ] name:nil]; } newCachedGraph->gradOutputTensor_ = gradOutputTensor; @@ -821,75 +780,76 @@ string get_mem_string(c10::MemoryFormat memory_format) { } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input, input_shape); - auto gradOutputPlaceholder = native_mps::Placeholder(cachedGraph->gradOutputTensor_, grad_out, input_shape_readonly); + auto gradOutputPlaceholder = + native_mps::Placeholder(cachedGraph->gradOutputTensor_, grad_out, input_shape_readonly); auto weightPlaceholder = native_mps::Placeholder(); - if(has_weight) + if (has_weight) weightPlaceholder = native_mps::Placeholder(cachedGraph->weightTensor_, weight_opt.value(), new_mean_shape); auto runningMeanPlaceholder = native_mps::Placeholder(); auto runningVarPlaceholder = native_mps::Placeholder(); - if(has_running_mean) { - runningMeanPlaceholder = native_mps::Placeholder(cachedGraph->runningMeanTensor_, running_mean_opt.value(), new_mean_shape); - runningVarPlaceholder = native_mps::Placeholder(cachedGraph->runningVarTensor_, running_var_opt.value(), new_mean_shape); + if (has_running_mean) { + runningMeanPlaceholder = + native_mps::Placeholder(cachedGraph->runningMeanTensor_, running_mean_opt.value(), new_mean_shape); + runningVarPlaceholder = + native_mps::Placeholder(cachedGraph->runningVarTensor_, running_var_opt.value(), new_mean_shape); } auto saveMeanPlaceholder = native_mps::Placeholder(); auto saveVarPlaceholder = native_mps::Placeholder(); - if(has_save_mean) { - saveMeanPlaceholder = native_mps::Placeholder(cachedGraph->saveMeanTensor_, save_mean_opt.value(), new_mean_shape); + if (has_save_mean) { + saveMeanPlaceholder = + native_mps::Placeholder(cachedGraph->saveMeanTensor_, save_mean_opt.value(), new_mean_shape); saveVarPlaceholder = native_mps::Placeholder(cachedGraph->saveVarTensor_, save_var_opt.value(), new_mean_shape); } auto gradInputPlaceholder = native_mps::Placeholder(); - if(grad_input_mask[0]) + if (grad_input_mask[0]) gradInputPlaceholder = native_mps::Placeholder(cachedGraph->gradInputTensor_, grad_input, input_shape); auto gradWeightPlaceholder = native_mps::Placeholder(); - if(grad_input_mask[1]) + if (grad_input_mask[1]) gradWeightPlaceholder = native_mps::Placeholder(cachedGraph->gradWeightTensor_, grad_weight); - auto gradBiasPlaceholder = native_mps::Placeholder();; - if(grad_input_mask[2]) + auto gradBiasPlaceholder = native_mps::Placeholder(); + ; + if (grad_input_mask[2]) gradBiasPlaceholder = native_mps::Placeholder(cachedGraph->gradBiasTensor_, grad_bias); - NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; + NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData(); - if(has_weight) + if (has_weight) feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData(); - if(has_running_mean) { + if (has_running_mean) { feeds[runningMeanPlaceholder.getMPSGraphTensor()] = runningMeanPlaceholder.getMPSGraphTensorData(); feeds[runningVarPlaceholder.getMPSGraphTensor()] = runningVarPlaceholder.getMPSGraphTensorData(); } - if(has_save_mean) { + if (has_save_mean) { feeds[saveMeanPlaceholder.getMPSGraphTensor()] = saveMeanPlaceholder.getMPSGraphTensorData(); feeds[saveVarPlaceholder.getMPSGraphTensor()] = saveVarPlaceholder.getMPSGraphTensorData(); } - NSMutableDictionary *results = [[NSMutableDictionary new] autorelease]; - if(grad_input_mask[0]) + NSMutableDictionary* results = [[NSMutableDictionary new] autorelease]; + if (grad_input_mask[0]) results[gradInputPlaceholder.getMPSGraphTensor()] = gradInputPlaceholder.getMPSGraphTensorData(); - if(grad_input_mask[1]) + if (grad_input_mask[1]) results[gradWeightPlaceholder.getMPSGraphTensor()] = gradWeightPlaceholder.getMPSGraphTensorData(); - if(grad_input_mask[2]) + if (grad_input_mask[2]) results[gradBiasPlaceholder.getMPSGraphTensor()] = gradBiasPlaceholder.getMPSGraphTensorData(); native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } return std::make_tuple(grad_input, grad_weight, grad_bias); - } // Layer norm forward for MPS -std::tuple layer_norm_mps( - const Tensor& input, - IntArrayRef normalized_shape, - const c10::optional& weight_opt, - const c10::optional& bias_opt, - double eps) { - +std::tuple layer_norm_mps(const Tensor& input, + IntArrayRef normalized_shape, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + double eps) { c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); @@ -910,9 +870,14 @@ string get_mem_string(c10::MemoryFormat memory_format) { // entire channel/plane with the affine option, Layer Normalization applies // per-element scale and bias. E.g. For input {N, C, H, W}, weight for // batchnorm has shape {C} while weight for layernorm has shape {H, W} or {W}. - auto outputs = at::native_batch_norm( - input_reshaped, /*weight=*/{}, /*bias=*/{}, /*running_mean=*/{}, - /*running_var=*/{}, /*training=*/true, /*momentum=*/0, eps); + auto outputs = at::native_batch_norm(input_reshaped, + /*weight=*/{}, + /*bias=*/{}, + /*running_mean=*/{}, + /*running_var=*/{}, + /*training=*/true, + /*momentum=*/0, + eps); at::Tensor out = std::get<0>(outputs); out = out.view(input_shape); if (weight.defined() && bias.defined()) { @@ -938,21 +903,17 @@ string get_mem_string(c10::MemoryFormat memory_format) { return std::make_tuple(out, mean, variance); } -std::tuple layer_norm_backward_mps( - const Tensor& grad_out, - const Tensor& input, - IntArrayRef normalized_shape, - const Tensor& mean, - const Tensor& rstd, - const c10::optional& weight_opt /* optional */, - const c10::optional& bias_opt /* optional */, - std::array grad_input_mask) { - - c10::MaybeOwned weight_maybe_owned = - at::borrow_from_optional_tensor(weight_opt); +std::tuple layer_norm_backward_mps(const Tensor& grad_out, + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& mean, + const Tensor& rstd, + const c10::optional& weight_opt /* optional */, + const c10::optional& bias_opt /* optional */, + std::array grad_input_mask) { + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - c10::MaybeOwned bias_maybe_owned = - at::borrow_from_optional_tensor(bias_opt); + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); const Tensor& bias = *bias_maybe_owned; auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias); @@ -967,54 +928,47 @@ string get_mem_string(c10::MemoryFormat memory_format) { Tensor grad_weight; Tensor grad_bias; if (grad_input_mask[0]) { - grad_input = at::native::empty_like( - *X, - c10::nullopt /* dtype */, - c10::nullopt /* layout */, - kMPS /* device */, - c10::nullopt /* pin_memory */, - at::MemoryFormat::Contiguous); + grad_input = at::native::empty_like(*X, + c10::nullopt /* dtype */, + c10::nullopt /* layout */, + kMPS /* device */, + c10::nullopt /* pin_memory */, + at::MemoryFormat::Contiguous); } if (grad_input_mask[1]) { - grad_weight = M > 0 ? at::native::empty_like( - *gamma, - c10::nullopt /* dtype */, - c10::nullopt /* layout */, - kMPS /* device */, - c10::nullopt /* pin_memory */, - at::MemoryFormat::Contiguous) - : at::native::zeros_like( - *gamma, - c10::nullopt /* dtype */, - c10::nullopt /* layout */, - kMPS /* device */, - c10::nullopt /* pin_memory */, - at::MemoryFormat::Contiguous); + grad_weight = M > 0 ? at::native::empty_like(*gamma, + c10::nullopt /* dtype */, + c10::nullopt /* layout */, + kMPS /* device */, + c10::nullopt /* pin_memory */, + at::MemoryFormat::Contiguous) + : at::native::zeros_like(*gamma, + c10::nullopt /* dtype */, + c10::nullopt /* layout */, + kMPS /* device */, + c10::nullopt /* pin_memory */, + at::MemoryFormat::Contiguous); } if (grad_input_mask[2]) { - grad_bias = M > 0 ? at::native::empty_like( - *beta, - c10::nullopt /* dtype */, - c10::nullopt /* layout */, - kMPS /* device */, - c10::nullopt /* pin_memory */, - at::MemoryFormat::Contiguous) - : at::native::zeros_like( - *beta, - c10::nullopt /* dtype */, - c10::nullopt /* layout */, - kMPS /* device */, - c10::nullopt /* pin_memory */, - at::MemoryFormat::Contiguous); + grad_bias = M > 0 ? at::native::empty_like(*beta, + c10::nullopt /* dtype */, + c10::nullopt /* layout */, + kMPS /* device */, + c10::nullopt /* pin_memory */, + at::MemoryFormat::Contiguous) + : at::native::zeros_like(*beta, + c10::nullopt /* dtype */, + c10::nullopt /* layout */, + kMPS /* device */, + c10::nullopt /* pin_memory */, + at::MemoryFormat::Contiguous); } if (M > 0) { - namespace native_mps = at::native::mps; // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public native_mps::MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* gradOutputTensor_ = nil; MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* weightTensor_ = nil; @@ -1038,7 +992,6 @@ string get_mem_string(c10::MemoryFormat memory_format) { // const auto memory_format = input.suggest_memory_format(); @autoreleasepool { - MPSShape* input_shape = mps::getMPSShape(*X); MPSShape* gamma_shape = mps::getMPSShape(normalized_shape); @@ -1047,51 +1000,50 @@ string get_mem_string(c10::MemoryFormat memory_format) { NSMutableArray* gamma_axes = [NSMutableArray arrayWithCapacity:num_channel_dims]; - for(int i = 0; i < num_channel_dims; i++) + for (int i = 0; i < num_channel_dims; i++) gamma_axes[i] = [NSNumber numberWithInt:i]; // Axes along which to reduce to get "batch norm" gradient // This will be applied on shape [1, M, -1] NSMutableArray* bn_axes = [NSMutableArray arrayWithCapacity:num_normalized_dims]; - for(int i = 0; i < num_normalized_dims; i++) - bn_axes[i] = [NSNumber numberWithInt:(1+1+i)]; + for (int i = 0; i < num_normalized_dims; i++) + bn_axes[i] = [NSNumber numberWithInt:(1 + 1 + i)]; // Shape of input to do "batch norm" backward // This is [1, M, -1] - NSMutableArray* bn_shape = [NSMutableArray arrayWithCapacity:(num_normalized_dims+2)]; + NSMutableArray* bn_shape = [NSMutableArray arrayWithCapacity:(num_normalized_dims + 2)]; bn_shape[0] = [NSNumber numberWithInt:1]; bn_shape[1] = [NSNumber numberWithInt:M]; - for(int i = 0; i < num_normalized_dims; i++) - bn_shape[i+2] = input_shape[i+num_channel_dims]; + for (int i = 0; i < num_normalized_dims; i++) + bn_shape[i + 2] = input_shape[i + num_channel_dims]; // Shape of mean to do "batch norm" backward // This is [1, M, [1,1,1..1]] - NSMutableArray* bn_mean_shape = [NSMutableArray arrayWithCapacity:(num_normalized_dims+2)]; + NSMutableArray* bn_mean_shape = + [NSMutableArray arrayWithCapacity:(num_normalized_dims + 2)]; bn_mean_shape[0] = [NSNumber numberWithInt:1]; bn_mean_shape[1] = [NSNumber numberWithInt:M]; - for(int i = 0; i < num_normalized_dims; i++) - bn_mean_shape[i+2] = [NSNumber numberWithInt:1]; + for (int i = 0; i < num_normalized_dims; i++) + bn_mean_shape[i + 2] = [NSNumber numberWithInt:1]; // Shape of gamma to multiply with "batch norm" backward // This is [1, 1, -1] - NSMutableArray* bn_gamma_shape = [NSMutableArray arrayWithCapacity:(num_normalized_dims+2)]; + NSMutableArray* bn_gamma_shape = + [NSMutableArray arrayWithCapacity:(num_normalized_dims + 2)]; bn_gamma_shape[0] = [NSNumber numberWithInt:1]; bn_gamma_shape[1] = [NSNumber numberWithInt:1]; - for(int i = 0; i < num_normalized_dims; i++) - bn_gamma_shape[i+2] = input_shape[i+num_channel_dims]; + for (int i = 0; i < num_normalized_dims; i++) + bn_gamma_shape[i + 2] = input_shape[i + num_channel_dims]; - string key = "layer_norm_backward_mps:" - + std::to_string(has_weight) + ":" - + native_mps::getArrayRefString(normalized_shape) + ":" - + native_mps::getArrayRefString((*X).sizes()) + ":" - + native_mps::getMPSTypeString(*X); + string key = "layer_norm_backward_mps:" + std::to_string(has_weight) + ":" + + native_mps::getArrayRefString(normalized_shape) + ":" + native_mps::getArrayRefString((*X).sizes()) + ":" + + native_mps::getMPSTypeString(*X); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^native_mps::MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = native_mps::make_mps_graph(); @@ -1100,7 +1052,7 @@ string get_mem_string(c10::MemoryFormat memory_format) { MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, *X); MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, *dOut); MPSGraphTensor* weightTensor = nil; - if(has_weight) + if (has_weight) weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, *gamma); // Mean and inv std tensors to be saved and returned @@ -1111,7 +1063,7 @@ string get_mem_string(c10::MemoryFormat memory_format) { MPSGraphTensor* gradWeightTensor = nil; MPSGraphTensor* gradBiasTensor = nil; - if(grad_input_mask[1]) { + if (grad_input_mask[1]) { MPSGraphTensor* xMinusMean = [mpsGraph subtractionWithPrimaryTensor:inputTensor secondaryTensor:meanTensor name:nil]; @@ -1121,46 +1073,31 @@ string get_mem_string(c10::MemoryFormat memory_format) { MPSGraphTensor* gradBnMulTensor = [mpsGraph multiplicationWithPrimaryTensor:bnForwardTensor secondaryTensor:gradOutputTensor name:nil]; - gradWeightTensor = [mpsGraph reductionSumWithTensor:gradBnMulTensor - axes:gamma_axes - name:nil]; + gradWeightTensor = [mpsGraph reductionSumWithTensor:gradBnMulTensor axes:gamma_axes name:nil]; } - if(grad_input_mask[2]) { - gradBiasTensor = [mpsGraph reductionSumWithTensor:gradOutputTensor - axes:gamma_axes - name:nil]; + if (grad_input_mask[2]) { + gradBiasTensor = [mpsGraph reductionSumWithTensor:gradOutputTensor axes:gamma_axes name:nil]; } - if(grad_input_mask[0]) { - + if (grad_input_mask[0]) { // Reshape input to [1, M, -1] // Reshape mean and rstd to [1, M, -1] // Reshape gamma to [1, 1, -1] (-1 has N dims) - MPSGraphTensor* bnInputTensor = [mpsGraph reshapeTensor:inputTensor - withShape:bn_shape - name:nil]; + MPSGraphTensor* bnInputTensor = [mpsGraph reshapeTensor:inputTensor withShape:bn_shape name:nil]; MPSGraphTensor* bnGradOutputTensor = [mpsGraph reshapeTensor:gradOutputTensor withShape:bn_shape name:nil]; // Do this at the end - if(has_weight) { - MPSGraphTensor* bnGammaTensor = [mpsGraph reshapeTensor:weightTensor - withShape:bn_gamma_shape - name:nil]; + if (has_weight) { + MPSGraphTensor* bnGammaTensor = [mpsGraph reshapeTensor:weightTensor withShape:bn_gamma_shape name:nil]; bnGradOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:bnGradOutputTensor secondaryTensor:bnGammaTensor name:nil]; } - MPSGraphTensor* bnMeanTensor = [mpsGraph reshapeTensor:meanTensor - withShape:bn_mean_shape - name:nil]; - MPSGraphTensor* bnRstdTensor = [mpsGraph reshapeTensor:rstdTensor - withShape:bn_mean_shape - name:nil]; + MPSGraphTensor* bnMeanTensor = [mpsGraph reshapeTensor:meanTensor withShape:bn_mean_shape name:nil]; + MPSGraphTensor* bnRstdTensor = [mpsGraph reshapeTensor:rstdTensor withShape:bn_mean_shape name:nil]; - MPSGraphTensor* mulTensor = [mpsGraph constantWithScalar:N - shape:@[@1] - dataType:MPSDataTypeInt32]; + MPSGraphTensor* mulTensor = [mpsGraph constantWithScalar:N shape:@[ @1 ] dataType:MPSDataTypeInt32]; MPSGraphTensor* numberToReduceTensor = mulTensor; @@ -1168,8 +1105,7 @@ string get_mem_string(c10::MemoryFormat memory_format) { toType:bnInputTensor.dataType name:@"cast2Tensor"]; - MPSGraphTensor* sizeReciprocalTensor = [mpsGraph reciprocalWithTensor:cast2Tensor - name:nil]; + MPSGraphTensor* sizeReciprocalTensor = [mpsGraph reciprocalWithTensor:cast2Tensor name:nil]; // TODO: Reduce redundant computation MPSGraphTensor* xMinusMean = [mpsGraph subtractionWithPrimaryTensor:bnInputTensor @@ -1184,13 +1120,9 @@ string get_mem_string(c10::MemoryFormat memory_format) { secondaryTensor:normalizedTensor name:nil]; - MPSGraphTensor* gammaGradient = [mpsGraph reductionSumWithTensor:bnGradMulTensor - axes:bn_axes - name:nil]; + MPSGraphTensor* gammaGradient = [mpsGraph reductionSumWithTensor:bnGradMulTensor axes:bn_axes name:nil]; - MPSGraphTensor* betaGradient = [mpsGraph reductionSumWithTensor:bnGradOutputTensor - axes:bn_axes - name:nil]; + MPSGraphTensor* betaGradient = [mpsGraph reductionSumWithTensor:bnGradOutputTensor axes:bn_axes name:nil]; MPSGraphTensor* gradient1 = [mpsGraph multiplicationWithPrimaryTensor:bnGradOutputTensor secondaryTensor:bnRstdTensor @@ -1201,15 +1133,14 @@ string get_mem_string(c10::MemoryFormat memory_format) { name:nil]; // reverseVariance is square of rstd - MPSGraphTensor* reverseVariance = [mpsGraph squareWithTensor:bnRstdTensor - name:nil]; + MPSGraphTensor* reverseVariance = [mpsGraph squareWithTensor:bnRstdTensor name:nil]; MPSGraphTensor* gradient2_2 = [mpsGraph multiplicationWithPrimaryTensor:gammaGradient secondaryTensor:reverseVariance name:nil]; MPSGraphTensor* gradient2 = [mpsGraph multiplicationWithPrimaryTensor:gradient2_1 - secondaryTensor:gradient2_2 - name:nil]; + secondaryTensor:gradient2_2 + name:nil]; MPSGraphTensor* gradient3_1 = [mpsGraph multiplicationWithPrimaryTensor:sizeReciprocalTensor secondaryTensor:betaGradient @@ -1227,21 +1158,14 @@ string get_mem_string(c10::MemoryFormat memory_format) { secondaryTensor:gradient3 name:nil]; - gradInputTensor = [mpsGraph reshapeTensor:gradient - withShape:input_shape - name:nil]; - + gradInputTensor = [mpsGraph reshapeTensor:gradient withShape:input_shape name:nil]; } - if(grad_input_mask[1]) { - gradWeightTensor = [mpsGraph reshapeTensor:gradWeightTensor - withShape:gamma_shape - name:nil]; + if (grad_input_mask[1]) { + gradWeightTensor = [mpsGraph reshapeTensor:gradWeightTensor withShape:gamma_shape name:nil]; } - if(grad_input_mask[2]) { - gradBiasTensor = [mpsGraph reshapeTensor:gradBiasTensor - withShape:gamma_shape - name:nil]; + if (grad_input_mask[2]) { + gradBiasTensor = [mpsGraph reshapeTensor:gradBiasTensor withShape:gamma_shape name:nil]; } newCachedGraph->gradOutputTensor_ = gradOutputTensor; @@ -1255,50 +1179,48 @@ string get_mem_string(c10::MemoryFormat memory_format) { } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, *X); auto gradOutputPlaceholder = native_mps::Placeholder(cachedGraph->gradOutputTensor_, *dOut); auto weightPlaceholder = native_mps::Placeholder(); - if(has_weight) + if (has_weight) weightPlaceholder = native_mps::Placeholder(cachedGraph->weightTensor_, *gamma); auto saveMeanPlaceholder = native_mps::Placeholder(cachedGraph->meanTensor_, mean); auto saveVarPlaceholder = native_mps::Placeholder(cachedGraph->rstdTensor_, rstd); auto gradInputPlaceholder = native_mps::Placeholder(); - if(grad_input_mask[0]) + if (grad_input_mask[0]) gradInputPlaceholder = native_mps::Placeholder(cachedGraph->gradInputTensor_, grad_input); auto gradWeightPlaceholder = native_mps::Placeholder(); - if(grad_input_mask[1]) + if (grad_input_mask[1]) gradWeightPlaceholder = native_mps::Placeholder(cachedGraph->gradWeightTensor_, grad_weight); - auto gradBiasPlaceholder = native_mps::Placeholder();; - if(grad_input_mask[2]) + auto gradBiasPlaceholder = native_mps::Placeholder(); + ; + if (grad_input_mask[2]) gradBiasPlaceholder = native_mps::Placeholder(cachedGraph->gradBiasTensor_, grad_bias); - NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; + NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData(); - if(has_weight) + if (has_weight) feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData(); feeds[saveMeanPlaceholder.getMPSGraphTensor()] = saveMeanPlaceholder.getMPSGraphTensorData(); feeds[saveVarPlaceholder.getMPSGraphTensor()] = saveVarPlaceholder.getMPSGraphTensorData(); - NSMutableDictionary *results = [[NSMutableDictionary new] autorelease]; - if(grad_input_mask[0]) + NSMutableDictionary* results = [[NSMutableDictionary new] autorelease]; + if (grad_input_mask[0]) results[gradInputPlaceholder.getMPSGraphTensor()] = gradInputPlaceholder.getMPSGraphTensorData(); - if(grad_input_mask[1]) + if (grad_input_mask[1]) results[gradWeightPlaceholder.getMPSGraphTensor()] = gradWeightPlaceholder.getMPSGraphTensorData(); - if(grad_input_mask[2]) + if (grad_input_mask[2]) results[gradBiasPlaceholder.getMPSGraphTensor()] = gradBiasPlaceholder.getMPSGraphTensorData(); native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); - - } - + } } return std::make_tuple(std::move(grad_input), std::move(grad_weight), std::move(grad_bias)); - } } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/Pad.mm b/aten/src/ATen/native/mps/operations/Pad.mm index d152dbe5eef905..d6294c78e80ca1 100644 --- a/aten/src/ATen/native/mps/operations/Pad.mm +++ b/aten/src/ATen/native/mps/operations/Pad.mm @@ -7,15 +7,18 @@ namespace mps { // Pad operations (1D/2D/3D forward and backward) -Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef padding, +Tensor& pad_out_template(Tensor& output, + const Tensor& input_, + IntArrayRef padding, const c10::optional& grad_output_opt, - MPSGraphPaddingMode mode, double constantValue, const string op_name) -{ - const int padding_size = (int) padding.size(); + MPSGraphPaddingMode mode, + double constantValue, + const string op_name) { + const int padding_size = (int)padding.size(); int padding_dim = padding_size / 2; // either 1D, 2D, or 3D - TORCH_CHECK(padding_size == 2 || padding_size == 4 || padding_size == 6, - "invalid padding argument of size ", padding_size); + TORCH_CHECK( + padding_size == 2 || padding_size == 4 || padding_size == 6, "invalid padding argument of size ", padding_size); const Tensor& grad_output_ = *(at::borrow_from_optional_tensor(grad_output_opt)); const bool is_backward_pass = grad_output_.defined(); @@ -23,8 +26,13 @@ int64_t nbatch = 1; int64_t ndims = input_.ndimension(); - TORCH_CHECK(ndims >= (int64_t)padding_dim, "Length of pad should be no more than twice the number of " - "dimensions of the input. Pad length is ", padding_size, "while the input has ", ndims, "dimensions."); + TORCH_CHECK(ndims >= (int64_t)padding_dim, + "Length of pad should be no more than twice the number of " + "dimensions of the input. Pad length is ", + padding_size, + "while the input has ", + ndims, + "dimensions."); // number of input dims with ConstantPad could be less than 2 int dim_w = padding_dim; @@ -35,8 +43,9 @@ if (!is_backward_pass && mode != MPSGraphPaddingModeConstant && ndims > padding_dim) { bool valid_dims = input_.size(1) != 0 && input_.size(padding_dim) != 0; TORCH_CHECK((ndims == 1 + padding_dim && valid_dims) || - (ndims == 2 + padding_dim && valid_dims && input_.size(1 + padding_dim) != 0), - "3D or 4D (batch mode) tensor expected for input, but got: ", input_); + (ndims == 2 + padding_dim && valid_dims && input_.size(1 + padding_dim) != 0), + "3D or 4D (batch mode) tensor expected for input, but got: ", + input_); } if (ndims == padding_dim) { @@ -59,11 +68,11 @@ int64_t pad_t = padding_size > 2 ? padding[2] : 0; int64_t pad_b = padding_size > 2 ? padding[3] : 0; int64_t pad_front = padding_size > 4 ? padding[4] : 0; - int64_t pad_back = padding_size > 4 ? padding[5] : 0; + int64_t pad_back = padding_size > 4 ? padding[5] : 0; int64_t nplane = input_.size(dim_slices); int64_t input_w = input_.size(dim_w); - int64_t output_w = input_w + pad_l + pad_r; + int64_t output_w = input_w + pad_l + pad_r; int64_t input_h = padding_dim > 1 ? input_.size(dim_h) : 0; int64_t output_h = padding_dim > 1 ? input_h + pad_t + pad_b : 0; int64_t input_d = padding_dim > 2 ? input_.size(dim_d) : 0; @@ -73,8 +82,15 @@ if (!is_backward_pass) { TORCH_CHECK(output_w >= 1 || output_h >= padding_dim - 1, - "input (H: ", input_h, ", W: ", input_w, ") is too small. Calculated " - "output H: ", output_h, " W: ", output_w); + "input (H: ", + input_h, + ", W: ", + input_w, + ") is too small. Calculated " + "output H: ", + output_h, + " W: ", + output_w); std::vector outputSizes; if (mode == MPSGraphPaddingModeConstant) { @@ -83,7 +99,7 @@ auto ori_padding_dim = padding_size / 2; auto l_diff = ndims - ori_padding_dim; - for (size_t i = 0; i < (size_t)l_diff; i ++) { + for (size_t i = 0; i < (size_t)l_diff; i++) { outputSizes.emplace_back(input_sizes[i]); } for (const auto i : c10::irange((size_t)ori_padding_dim)) { @@ -94,21 +110,39 @@ } else { // these checks aren't relevant for constant pad TORCH_CHECK(pad_l < input_w && pad_r < input_w, - "Argument #4: Padding size should be less than the corresponding " - "input dimension, but got: padding (", pad_l, ", ", pad_r, - ") at dimension ", dim_w, " of input ", ndims); + "Argument #4: Padding size should be less than the corresponding " + "input dimension, but got: padding (", + pad_l, + ", ", + pad_r, + ") at dimension ", + dim_w, + " of input ", + ndims); if (padding_dim > 1) { TORCH_CHECK(pad_t < input_h && pad_b < input_h, - "Argument #6: Padding size should be less than the corresponding " - "input dimension, but got: padding (", pad_t, ", ", pad_b, - ") at dimension ", dim_h, " of input ", ndims); + "Argument #6: Padding size should be less than the corresponding " + "input dimension, but got: padding (", + pad_t, + ", ", + pad_b, + ") at dimension ", + dim_h, + " of input ", + ndims); } if (padding_dim > 2) { TORCH_CHECK(pad_front < input_d && pad_back < input_d, - "Argument #8: Padding size should be less than the corresponding " - "input dimension, but got: padding (", pad_front, ", ", pad_back, - ") at dimension ", dim_d, " of input ", ndims); + "Argument #8: Padding size should be less than the corresponding " + "input dimension, but got: padding (", + pad_front, + ", ", + pad_back, + ") at dimension ", + dim_d, + " of input ", + ndims); } outputSizes.insert(outputSizes.begin(), output_w); if (padding_dim >= 2) @@ -133,10 +167,16 @@ input = input_.contiguous(); } else { TORCH_CHECK(output_w == grad_output_.size(dim_w), - "gradOutput width unexpected. Expected: ", output_w, ", Got: ", grad_output_.size(dim_w)); + "gradOutput width unexpected. Expected: ", + output_w, + ", Got: ", + grad_output_.size(dim_w)); if (padding_dim > 1) { TORCH_CHECK(output_h == grad_output_.size(dim_h), - "gradOutput height unexpected. Expected: ", output_h, ", Got: ", grad_output_.size(dim_h)); + "gradOutput height unexpected. Expected: ", + output_h, + ", Got: ", + grad_output_.size(dim_h)); } output.resize_as_(input); if (output.numel() == 0 || grad_output_.numel() == 0) @@ -153,11 +193,11 @@ std::vector stridesVec(ndims, @(1)); for (int64_t pdim = 0; pdim < padding_size / 2; pdim++) { - const int64_t leftIdx = pdim * 2; + const int64_t leftIdx = pdim * 2; const int64_t rightIdx = pdim * 2 + 1; const int64_t padIdx = ndims - pdim - 1; - leftPadVec [padIdx] = @(padding[leftIdx]); + leftPadVec[padIdx] = @(padding[leftIdx]); rightPadVec[padIdx] = @(padding[rightIdx]); // workaround for negative padding issue in backward pass if (is_backward_pass) { @@ -171,7 +211,7 @@ endsVec[padIdx] = @(input.size(padIdx) + padding[rightIdx]); endMask &= ~(1U << padIdx); } - // workaround for the right padding bug in Monterey + // workaround for the right padding bug in Monterey } else if (!is_macos_13_or_newer()) { if (padding[rightIdx] == 1 && padding[leftIdx] == 0) { rightPadVec[padIdx] = @(2); @@ -180,8 +220,8 @@ } } } - MPSShape *leftPadding = [NSArray arrayWithObjects:leftPadVec.data() count:ndims]; - MPSShape *rightPadding = [NSArray arrayWithObjects:rightPadVec.data() count:ndims]; + MPSShape* leftPadding = [NSArray arrayWithObjects:leftPadVec.data() count:ndims]; + MPSShape* rightPadding = [NSArray arrayWithObjects:rightPadVec.data() count:ndims]; MPSDataType dataType = getMPSScalarType(input.scalar_type()); // workaround for Bool type assert with Constant padding @@ -190,20 +230,20 @@ } struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) { } + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor *inputTensor = nil, *outputTensor = nil; - MPSGraphTensor *gradOutputTensor = nil; + MPSGraphTensor* gradOutputTensor = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" + - getArrayRefString(padding) + "]:" + std::to_string(constantValue); + string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" + getArrayRefString(padding) + + "]:" + std::to_string(constantValue); CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); @@ -211,7 +251,7 @@ const bool needsSlice = startMask != dims_mask || endMask != dims_mask; if (!is_backward_pass) { - MPSGraphTensor *padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor + MPSGraphTensor* padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor withPaddingMode:mode leftPadding:leftPadding rightPadding:rightPadding @@ -219,36 +259,39 @@ name:nil]; // workaround for the right padding bug in Monterey if (needsSlice) { - newCachedGraph->outputTensor = [mpsGraph sliceTensor:padTensor - starts:[NSArray arrayWithObjects:startsVec.data() count:ndims] - ends:[NSArray arrayWithObjects:endsVec.data() count:ndims] - strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims] - startMask:startMask - endMask:endMask - squeezeMask:0 - name:nil]; + newCachedGraph->outputTensor = + [mpsGraph sliceTensor:padTensor + starts:[NSArray arrayWithObjects:startsVec.data() count:ndims] + ends:[NSArray arrayWithObjects:endsVec.data() count:ndims] + strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims] + startMask:startMask + endMask:endMask + squeezeMask:0 + name:nil]; } else { newCachedGraph->outputTensor = padTensor; } } else { newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(grad_output)); - MPSGraphTensor *padGradTensor = [mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor - sourceTensor:newCachedGraph->inputTensor - paddingMode:mode - leftPadding:leftPadding - rightPadding:rightPadding - name:nil]; + MPSGraphTensor* padGradTensor = + [mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor + sourceTensor:newCachedGraph->inputTensor + paddingMode:mode + leftPadding:leftPadding + rightPadding:rightPadding + name:nil]; // workaround for negative padding issue with padGradientWithIncomingGradientTensor() if (needsSlice) { - newCachedGraph->outputTensor = [mpsGraph sliceGradientTensor:padGradTensor - fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor name:nil] - starts:[NSArray arrayWithObjects:startsVec.data() count:ndims] - ends:[NSArray arrayWithObjects:endsVec.data() count:ndims] - strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims] - startMask:startMask - endMask:endMask - squeezeMask:0 - name:nil]; + newCachedGraph->outputTensor = + [mpsGraph sliceGradientTensor:padGradTensor + fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor name:nil] + starts:[NSArray arrayWithObjects:startsVec.data() count:ndims] + ends:[NSArray arrayWithObjects:endsVec.data() count:ndims] + strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims] + startMask:startMask + endMask:endMask + squeezeMask:0 + name:nil]; } else { newCachedGraph->outputTensor = padGradTensor; } @@ -257,19 +300,19 @@ return newCachedGraph; }); } - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input, nullptr, true, dataType); + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input, nullptr, true, dataType); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output, nullptr, true, dataType); - Placeholder gradOutputPlaceholder = !is_backward_pass ? Placeholder() : - Placeholder(cachedGraph->gradOutputTensor, grad_output, nullptr, true, dataType); + Placeholder gradOutputPlaceholder = !is_backward_pass + ? Placeholder() + : Placeholder(cachedGraph->gradOutputTensor, grad_output, nullptr, true, dataType); - NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; + NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); if (is_backward_pass) { feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData(); } - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); } return output; @@ -278,123 +321,156 @@ // 1D Reflection and Replication Padding TORCH_IMPL_FUNC(reflection_pad1d_out_mps) -(const Tensor& input, IntArrayRef padding, const Tensor& output) -{ - mps::pad_out_template(const_cast(output), input, padding, c10::nullopt, - MPSGraphPaddingModeReflect, 0.0, "reflection_pad1d_out_mps"); +(const Tensor& input, IntArrayRef padding, const Tensor& output) { + mps::pad_out_template(const_cast(output), + input, + padding, + c10::nullopt, + MPSGraphPaddingModeReflect, + 0.0, + "reflection_pad1d_out_mps"); } TORCH_IMPL_FUNC(reflection_pad1d_backward_out_mps) -(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) -{ +(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) { grad_input.resize_as_(input).zero_(); - mps::pad_out_template(const_cast(grad_input), input, padding, grad_output, - MPSGraphPaddingModeReflect, 0.0, "reflection_pad1d_backward_out_mps"); + mps::pad_out_template(const_cast(grad_input), + input, + padding, + grad_output, + MPSGraphPaddingModeReflect, + 0.0, + "reflection_pad1d_backward_out_mps"); } TORCH_IMPL_FUNC(replication_pad1d_out_mps) -(const Tensor& input, IntArrayRef padding, const Tensor& output) -{ - mps::pad_out_template(const_cast(output), input, padding, c10::nullopt, - MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad1d_out_mps"); +(const Tensor& input, IntArrayRef padding, const Tensor& output) { + mps::pad_out_template(const_cast(output), + input, + padding, + c10::nullopt, + MPSGraphPaddingModeClampToEdge, + 0.0, + "replication_pad1d_out_mps"); } TORCH_IMPL_FUNC(replication_pad1d_backward_out_mps) -(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) -{ +(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) { grad_input.resize_as_(input).zero_(); - mps::pad_out_template(const_cast(grad_input), input, padding, grad_output, - MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad1d_backward_out_mps"); + mps::pad_out_template(const_cast(grad_input), + input, + padding, + grad_output, + MPSGraphPaddingModeClampToEdge, + 0.0, + "replication_pad1d_backward_out_mps"); } // 2D Reflection and Replication Padding -Tensor& reflection_pad2d_out_mps(const Tensor& input, IntArrayRef padding, Tensor& output) -{ +Tensor& reflection_pad2d_out_mps(const Tensor& input, IntArrayRef padding, Tensor& output) { return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__); } -Tensor reflection_pad2d_mps(const Tensor& input, IntArrayRef padding) -{ +Tensor reflection_pad2d_mps(const Tensor& input, IntArrayRef padding) { Tensor output = at::empty({0}, input.options()); return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__); } -Tensor& reflection_pad2d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input) -{ +Tensor& reflection_pad2d_backward_out_mps(const Tensor& grad_output, + const Tensor& input, + IntArrayRef padding, + Tensor& grad_input) { grad_input.resize_as_(input).zero_(); return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__); } -Tensor reflection_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) -{ +Tensor reflection_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) { auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__); } TORCH_IMPL_FUNC(replication_pad2d_out_mps) -(const Tensor& input, IntArrayRef padding, const Tensor& output) -{ - mps::pad_out_template(const_cast(output), input, padding, c10::nullopt, - MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad2d_out_mps"); +(const Tensor& input, IntArrayRef padding, const Tensor& output) { + mps::pad_out_template(const_cast(output), + input, + padding, + c10::nullopt, + MPSGraphPaddingModeClampToEdge, + 0.0, + "replication_pad2d_out_mps"); } -Tensor& replication_pad2d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input) -{ +Tensor& replication_pad2d_backward_out_mps(const Tensor& grad_output, + const Tensor& input, + IntArrayRef padding, + Tensor& grad_input) { grad_input.resize_as_(input).zero_(); return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__); } -Tensor replication_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) -{ +Tensor replication_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) { auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__); } // 3D Reflection and Replication Padding TORCH_IMPL_FUNC(reflection_pad3d_out_mps) -(const Tensor& input, IntArrayRef padding, const Tensor& output) -{ - mps::pad_out_template(const_cast(output), input, padding, c10::nullopt, - MPSGraphPaddingModeReflect, 0.0, "reflection_pad3d_out_mps"); +(const Tensor& input, IntArrayRef padding, const Tensor& output) { + mps::pad_out_template(const_cast(output), + input, + padding, + c10::nullopt, + MPSGraphPaddingModeReflect, + 0.0, + "reflection_pad3d_out_mps"); } TORCH_IMPL_FUNC(reflection_pad3d_backward_out_mps) -(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) -{ +(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) { grad_input.resize_as_(input).zero_(); - mps::pad_out_template(const_cast(grad_input), input, padding, grad_output, - MPSGraphPaddingModeReflect, 0.0, "reflection_pad3d_backward_out_mps"); + mps::pad_out_template(const_cast(grad_input), + input, + padding, + grad_output, + MPSGraphPaddingModeReflect, + 0.0, + "reflection_pad3d_backward_out_mps"); } TORCH_IMPL_FUNC(replication_pad3d_out_mps) -(const Tensor& input, IntArrayRef padding, const Tensor& output) -{ - mps::pad_out_template(const_cast(output), input, padding, c10::nullopt, - MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad3d_out_mps"); +(const Tensor& input, IntArrayRef padding, const Tensor& output) { + mps::pad_out_template(const_cast(output), + input, + padding, + c10::nullopt, + MPSGraphPaddingModeClampToEdge, + 0.0, + "replication_pad3d_out_mps"); } -Tensor& replication_pad3d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input) -{ +Tensor& replication_pad3d_backward_out_mps(const Tensor& grad_output, + const Tensor& input, + IntArrayRef padding, + Tensor& grad_input) { grad_input.resize_as_(input).zero_(); return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__); } -Tensor replication_pad3d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) -{ +Tensor replication_pad3d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) { auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__); } // backward pass is exlicitly handled in autograd by negating the "pad" argument -Tensor constant_pad_nd_mps(const Tensor& self, IntArrayRef pad, const Scalar& value) -{ +Tensor constant_pad_nd_mps(const Tensor& self, IntArrayRef pad, const Scalar& value) { if (pad.size() > 6) { TORCH_WARN_ONCE("MPS: The constant padding of more than 3 dimensions is not currently supported natively. ", "It uses View Ops default implementation to run. This may have performance implications."); return at::native::constant_pad_nd(self, pad, value); } Tensor output = at::empty({0}, self.options()); - return mps::pad_out_template(output, self, pad, c10::nullopt, MPSGraphPaddingModeConstant, value.toDouble(), __func__); + return mps::pad_out_template( + output, self, pad, c10::nullopt, MPSGraphPaddingModeConstant, value.toDouble(), __func__); } } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/PointwiseOps.mm b/aten/src/ATen/native/mps/operations/PointwiseOps.mm index 034786e3be3fe4..67900284742c35 100644 --- a/aten/src/ATen/native/mps/operations/PointwiseOps.mm +++ b/aten/src/ATen/native/mps/operations/PointwiseOps.mm @@ -7,27 +7,25 @@ namespace mps { void addc_mul_div_out_mps(const Tensor& self, - const Tensor& tensor1, - const Tensor& tensor2, - const Scalar& value_opt, // default value = 1.0 - const Tensor& output, - const bool is_div, - const string op_name) -{ + const Tensor& tensor1, + const Tensor& tensor2, + const Scalar& value_opt, // default value = 1.0 + const Tensor& output, + const bool is_div, + const string op_name) { if (value_opt.toDouble() == 0.0) { output.copy_(self); return; } - if(output.numel() == 0) { + if (output.numel() == 0) { return; } MPSStream* mpsStream = getCurrentMPSStream(); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor *inputTensor = nil, *outputTensor = nil; MPSGraphTensor *firstTensor = nil, *secondTensor = nil, *valueTensor = nil; }; @@ -39,42 +37,43 @@ void addc_mul_div_out_mps(const Tensor& self, CachedGraph* cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - - CachedGraph* newCachedGraph = nil; - ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), c10::promoteTypes(tensor1.scalar_type(), tensor2.scalar_type())); - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - newCachedGraph->firstTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor1); - newCachedGraph->secondTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor2); - newCachedGraph->valueTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), @[@1]); - - // the tensor to be optionally multiplied by value_scalar - MPSGraphTensor *multiplicandTensor = nil; - auto firstTensor = castMPSTensor(mpsGraph, newCachedGraph->firstTensor, common_dtype); - auto secondTensor = castMPSTensor(mpsGraph, newCachedGraph->secondTensor, common_dtype); - if (is_div) { - multiplicandTensor = [mpsGraph divisionWithPrimaryTensor:firstTensor - secondaryTensor:secondTensor - name:nil]; - } else { - multiplicandTensor = [mpsGraph multiplicationWithPrimaryTensor:firstTensor - secondaryTensor:secondTensor - name:nil]; - } - // the tensor to be added to input_tensor - MPSGraphTensor *addendTensor = [mpsGraph multiplicationWithPrimaryTensor:multiplicandTensor - secondaryTensor:castMPSTensor(mpsGraph, newCachedGraph->valueTensor, common_dtype) - name:nil]; - auto outputTensor = [mpsGraph additionWithPrimaryTensor:castMPSTensor(mpsGraph, newCachedGraph->inputTensor, common_dtype) - secondaryTensor:addendTensor - name:nil]; - newCachedGraph->outputTensor = castMPSTensor(mpsGraph, outputTensor, output.scalar_type()); + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + ScalarType common_dtype = + c10::promoteTypes(self.scalar_type(), c10::promoteTypes(tensor1.scalar_type(), tensor2.scalar_type())); + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + newCachedGraph->firstTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor1); + newCachedGraph->secondTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor2); + newCachedGraph->valueTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), @[ @1 ]); + + // the tensor to be optionally multiplied by value_scalar + MPSGraphTensor* multiplicandTensor = nil; + auto firstTensor = castMPSTensor(mpsGraph, newCachedGraph->firstTensor, common_dtype); + auto secondTensor = castMPSTensor(mpsGraph, newCachedGraph->secondTensor, common_dtype); + if (is_div) { + multiplicandTensor = [mpsGraph divisionWithPrimaryTensor:firstTensor secondaryTensor:secondTensor name:nil]; + } else { + multiplicandTensor = [mpsGraph multiplicationWithPrimaryTensor:firstTensor + secondaryTensor:secondTensor + name:nil]; } - return newCachedGraph; + // the tensor to be added to input_tensor + MPSGraphTensor* addendTensor = [mpsGraph + multiplicationWithPrimaryTensor:multiplicandTensor + secondaryTensor:castMPSTensor(mpsGraph, newCachedGraph->valueTensor, common_dtype) + name:nil]; + auto outputTensor = + [mpsGraph additionWithPrimaryTensor:castMPSTensor(mpsGraph, newCachedGraph->inputTensor, common_dtype) + secondaryTensor:addendTensor + name:nil]; + newCachedGraph->outputTensor = castMPSTensor(mpsGraph, outputTensor, output.scalar_type()); + } + return newCachedGraph; }); } @@ -93,9 +92,8 @@ void addc_mul_div_out_mps(const Tensor& self, cachedGraph->valueTensor : getMPSGraphTensorFromScalar(mpsStream, value_scalar), }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results); } @@ -105,14 +103,12 @@ void addc_mul_div_out_mps(const Tensor& self, // APIs exposed to at::native scope TORCH_IMPL_FUNC(addcmul_out_mps) -(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output) -{ +(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output) { mps::addc_mul_div_out_mps(self, tensor1, tensor2, value, output, false, "addcmul_out_mps"); } TORCH_IMPL_FUNC(addcdiv_out_mps) -(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output) -{ +(const Tensor& self, const Tensor& tensor1, const Tensor& tensor2, const Scalar& value, const Tensor& output) { mps::addc_mul_div_out_mps(self, tensor1, tensor2, value, output, true, "addcdiv_out_mps"); } diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index ff26ff83518c4f..d366e637603552 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -1,14 +1,13 @@ // Copyright © 2022 Apple Inc. -#include #include +#include namespace at::native { namespace mps { -struct PoolingCachedGraph : public MPSCachedGraph -{ - PoolingCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} +struct PoolingCachedGraph : public MPSCachedGraph { + PoolingCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor = nil; MPSGraphTensor* outputTensor = nil; MPSGraphTensor* indicesTensor = nil; @@ -17,24 +16,30 @@ }; typedef MPSGraphTensor* (^PoolingOpBlock)(PoolingCachedGraph&, MPSGraphPooling2DOpDescriptor*); -#define PoolingOpFn(graph, desc) MPSGraphTensor* (mps::PoolingCachedGraph& graph, MPSGraphPooling2DOpDescriptor* desc) +#define PoolingOpFn(graph, desc) MPSGraphTensor*(mps::PoolingCachedGraph & graph, MPSGraphPooling2DOpDescriptor * desc) // Pooling ops (1D/2D forward and backward Max and Average pooling) -static void pool2d_template(const Tensor& input, const Tensor& output, +static void pool2d_template(const Tensor& input, + const Tensor& output, const c10::optional& indices_opt, const c10::optional& grad_output_opt, - IntArrayRef kernel_size, IntArrayRef stride, - IntArrayRef padding, IntArrayRef dilation, - bool ceil_mode, bool count_include_pad, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode, + bool count_include_pad, const c10::optional divisor_override, - PoolingOpBlock poolingBlock, const c10::string& op_name) -{ + PoolingOpBlock poolingBlock, + const c10::string& op_name) { if (input.numel() == 0) { return; } if (!is_macos_13_or_newer()) { TORCH_CHECK(input.scalar_type() != ScalarType::Long, - "MPS: ", op_name, " op with int64 input is supported natively starting from macOS 13.0."); + "MPS: ", + op_name, + " op with int64 input is supported natively starting from macOS 13.0."); } const int64_t ndims = input.ndimension(); const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt)); @@ -48,14 +53,18 @@ static void pool2d_template(const Tensor& input, const Tensor& output, // be incompatible with the PyTorch's global NCHW layout. const auto memory_format = has_indices ? MemoryFormat::Contiguous : suggested_memory_format; - TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2, op_name, - ": kernel_size must either be a single int, or a tuple of two ints") - TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2, op_name, - ": stride must either be omitted, a single int, or a tuple of two ints") - TORCH_CHECK(padding.size() == 1 || padding.size() == 2, op_name, - ": padding must be either be a single int, or a tuple of two ints"); - TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2, op_name, - ": dilation must be either a single int, or a tuple of two ints"); + TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2, + op_name, + ": kernel_size must either be a single int, or a tuple of two ints") + TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2, + op_name, + ": stride must either be omitted, a single int, or a tuple of two ints") + TORCH_CHECK(padding.size() == 1 || padding.size() == 2, + op_name, + ": padding must be either be a single int, or a tuple of two ints"); + TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2, + op_name, + ": dilation must be either a single int, or a tuple of two ints"); if (suggested_memory_format == at::MemoryFormat::ChannelsLast) { TORCH_CHECK(ndims == 4, "non-empty 4D (batch mode) tensor expected for input with channels_last layout"); @@ -80,8 +89,21 @@ static void pool2d_template(const Tensor& input, const Tensor& output, const int64_t outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode); const int64_t outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode); - pool2d_shape_check(input, kH, kW, dH, dW, padH, padW, dilationH, dilationW, - nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format); + pool2d_shape_check(input, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + memory_format); auto output_memory_format = output.suggest_memory_format(); // the output and indices are 'empty', so we could avoid unnecessary gatherView on empty tensors @@ -90,7 +112,7 @@ static void pool2d_template(const Tensor& input, const Tensor& output, indices.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous); } if (output.numel() == 0) { - std::vector outputSizes {nInputPlane, outputHeight, outputWidth}; + std::vector outputSizes{nInputPlane, outputHeight, outputWidth}; if (ndims == 4) { outputSizes.insert(outputSizes.begin(), nbatch); } @@ -111,56 +133,57 @@ static void pool2d_template(const Tensor& input, const Tensor& output, MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - string key = op_name + getTensorsStringKey({input, indices, grad_output}) + ":K[" + - getArrayRefString(kernel_size) + "]:S[" + getArrayRefString(stride) + "]:P[" + - getArrayRefString(padding) + "]:D[" + getArrayRefString(dilation) + "]" + - (ceil_mode ? ":ceil" : "") + (count_include_pad ? ":include_pad" : "") + - (has_divisor ? ":divisor" : "") + ":" + - (suggested_memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW"); + string key = op_name + getTensorsStringKey({input, indices, grad_output}) + ":K[" + getArrayRefString(kernel_size) + + "]:S[" + getArrayRefString(stride) + "]:P[" + getArrayRefString(padding) + "]:D[" + + getArrayRefString(dilation) + "]" + (ceil_mode ? ":ceil" : "") + (count_include_pad ? ":include_pad" : "") + + (has_divisor ? ":divisor" : "") + ":" + + (suggested_memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW"); MPSShape* inputShape = getMPSShape(input, memory_format); MPSShape* gradOutputShape = is_backward_pass ? getMPSShape(grad_output, memory_format) : nullptr; PoolingCachedGraph* cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - PoolingCachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + PoolingCachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new PoolingCachedGraph(mpsGraph); - MPSGraphPooling2DOpDescriptor* desc = [MPSGraphPooling2DOpDescriptor - descriptorWithKernelWidth: kW - kernelHeight: kH - strideInX: dW - strideInY: dH - dilationRateInX: dilationW - dilationRateInY: dilationH - paddingLeft: padW - paddingRight: ceil_mode ? padW * dW : padW - paddingTop: padH - paddingBottom: ceil_mode ? padH * dH : padH - paddingStyle: MPSGraphPaddingStyleExplicit - dataLayout: memory_format == MemoryFormat::ChannelsLast ? - MPSGraphTensorNamedDataLayoutNHWC : - MPSGraphTensorNamedDataLayoutNCHW]; + MPSGraphPooling2DOpDescriptor* desc = + [MPSGraphPooling2DOpDescriptor descriptorWithKernelWidth:kW + kernelHeight:kH + strideInX:dW + strideInY:dH + dilationRateInX:dilationW + dilationRateInY:dilationH + paddingLeft:padW + paddingRight:ceil_mode ? padW * dW : padW + paddingTop:padH + paddingBottom:ceil_mode ? padH * dH : padH + paddingStyle:MPSGraphPaddingStyleExplicit + dataLayout:memory_format == MemoryFormat::ChannelsLast + ? MPSGraphTensorNamedDataLayoutNHWC + : MPSGraphTensorNamedDataLayoutNCHW]; desc.ceilMode = (padW == 0 && padH == 0) ? ceil_mode : false; if (has_indices) { desc.returnIndicesMode = MPSGraphPoolingReturnIndicesGlobalFlatten2D; desc.returnIndicesDataType = MPSDataTypeInt32; } - newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input.scalar_type()), inputShape); + newCachedGraph->inputTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input.scalar_type()), inputShape); if (is_backward_pass) { - newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output.scalar_type()), gradOutputShape); + newCachedGraph->gradOutputTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output.scalar_type()), gradOutputShape); } if (has_divisor) { - newCachedGraph->divisorTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeFloat32, @[@1]); + newCachedGraph->divisorTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeFloat32, @[ @1 ]); } MPSGraphTensor* outputTensor = poolingBlock(*newCachedGraph, desc); // with desc.dataLayout = NHWC (i.e., ChannelsLast), the results need to be converted back to NCHW - newCachedGraph->outputTensor = memory_format == MemoryFormat::ChannelsLast ? - convertNHWCtoNCHW(mpsGraph, outputTensor) : outputTensor; + newCachedGraph->outputTensor = + memory_format == MemoryFormat::ChannelsLast ? convertNHWCtoNCHW(mpsGraph, outputTensor) : outputTensor; } return newCachedGraph; }); @@ -168,14 +191,16 @@ static void pool2d_template(const Tensor& input, const Tensor& output, MPSStream* mpsStream = getCurrentMPSStream(); // in case of ChannelsLast we don't perform gather() in placeholder to avoid implicit conversion to NCHW - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input, inputShape, memory_format != MemoryFormat::ChannelsLast); - Placeholder gradOutputPlaceholder = !is_backward_pass ? Placeholder() : - Placeholder(cachedGraph->gradOutputTensor, grad_output, - gradOutputShape, memory_format != MemoryFormat::ChannelsLast); + Placeholder inputPlaceholder = + Placeholder(cachedGraph->inputTensor, input, inputShape, memory_format != MemoryFormat::ChannelsLast); + Placeholder gradOutputPlaceholder = !is_backward_pass + ? Placeholder() + : Placeholder( + cachedGraph->gradOutputTensor, grad_output, gradOutputShape, memory_format != MemoryFormat::ChannelsLast); Placeholder indicesPlaceholder = has_indices ? Placeholder(cachedGraph->indicesTensor, indices) : Placeholder(); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output); - NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; - NSMutableDictionary *results = [[NSMutableDictionary new] autorelease]; + NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; + NSMutableDictionary* results = [[NSMutableDictionary new] autorelease]; feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); results[outputPlaceholder.getMPSGraphTensor()] = outputPlaceholder.getMPSGraphTensorData(); @@ -192,7 +217,7 @@ static void pool2d_template(const Tensor& input, const Tensor& output, } MPSScalar divisor_scalar; if (cachedGraph->divisorTensor) { - const float divisor = float(kH * kW) / (float) divisor_override.value(); + const float divisor = float(kH * kW) / (float)divisor_override.value(); divisor_scalar = getMPSScalar(divisor, ScalarType::Float); feeds[cachedGraph->divisorTensor] = getMPSGraphTensorFromScalar(mpsStream, divisor_scalar); } @@ -205,14 +230,17 @@ static void pool2d_template(const Tensor& input, const Tensor& output, } } -static void avg_pool2d_template(const Tensor& input, const Tensor& output, +static void avg_pool2d_template(const Tensor& input, + const Tensor& output, const c10::optional& grad_output_opt, - IntArrayRef kernel_size, IntArrayRef stride, - IntArrayRef padding, IntArrayRef dilation, - bool ceil_mode, bool count_include_pad, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode, + bool count_include_pad, const c10::optional divisor_override, - const c10::string& op_name) -{ + const c10::string& op_name) { const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt)); const bool is_backward_pass = grad_output.defined(); const bool use_divisor = divisor_override.has_value() && divisor_override.value() != 0; @@ -226,12 +254,21 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output, "not supported on MPS backend. ", "Falling back on CPU. This may have performance implications."); if (!is_backward_pass) { - const_cast(output) = at::avg_pool2d(input.to("cpu"), kernel_size, stride, padding, ceil_mode, - count_include_pad, divisor_override).clone().to("mps"); + const_cast(output) = + at::avg_pool2d(input.to("cpu"), kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + .clone() + .to("mps"); } else { - const_cast(output) = at::avg_pool2d_backward(grad_output.to("cpu"), input.to("cpu"), - kernel_size, stride, padding, ceil_mode, count_include_pad, - divisor_override).clone().to("mps"); + const_cast(output) = at::avg_pool2d_backward(grad_output.to("cpu"), + input.to("cpu"), + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override) + .clone() + .to("mps"); } return; } @@ -239,7 +276,7 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output, mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { MPSGraph* mpsGraph = cachedGraph.graph(); const int64_t ndims = input.ndimension(); - MPSShape *paddingShape = nil; + MPSShape* paddingShape = nil; MPSGraphTensor* paddedTensor = cachedGraph.inputTensor; // workaround for issue #103039644: mismatching MPS vs. CPU results @@ -249,14 +286,14 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output, std::vector padVec(ndims, @(0)); padVec[ndims - 1] = @(padding.size() == 1 ? padding[0] : padding[1]); padVec[ndims - 2] = @(ndims > 3 ? padding[0] : 0); - paddingShape = [NSArray arrayWithObjects: padVec.data() count:ndims]; - paddedTensor = [mpsGraph padTensor: cachedGraph.inputTensor - withPaddingMode: MPSGraphPaddingModeZero - leftPadding: paddingShape - rightPadding: paddingShape - constantValue: 0.0 - name: nil]; - paddedTensor = [mpsGraph identityWithTensor: paddedTensor name: nil]; + paddingShape = [NSArray arrayWithObjects:padVec.data() count:ndims]; + paddedTensor = [mpsGraph padTensor:cachedGraph.inputTensor + withPaddingMode:MPSGraphPaddingModeZero + leftPadding:paddingShape + rightPadding:paddingShape + constantValue:0.0 + name:nil]; + paddedTensor = [mpsGraph identityWithTensor:paddedTensor name:nil]; } else { desc.includeZeroPadToAverage = count_include_pad; } @@ -265,35 +302,33 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output, } if (!is_backward_pass) { - MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DWithSourceTensor: paddedTensor - descriptor: desc - name: nil]; + MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DWithSourceTensor:paddedTensor descriptor:desc name:nil]; if (cachedGraph.divisorTensor) { // workaround: custom divisor isn't supported by MPS backend, so we scale manually - return [mpsGraph multiplicationWithPrimaryTensor: avgPoolTensor - secondaryTensor: cachedGraph.divisorTensor - name: nil]; + return [mpsGraph multiplicationWithPrimaryTensor:avgPoolTensor + secondaryTensor:cachedGraph.divisorTensor + name:nil]; } else { return avgPoolTensor; } } else { // backward pass MPSGraphTensor* scaledGradTensor = cachedGraph.gradOutputTensor; if (cachedGraph.divisorTensor) { - scaledGradTensor = [mpsGraph multiplicationWithPrimaryTensor: cachedGraph.gradOutputTensor - secondaryTensor: cachedGraph.divisorTensor - name: nil]; + scaledGradTensor = [mpsGraph multiplicationWithPrimaryTensor:cachedGraph.gradOutputTensor + secondaryTensor:cachedGraph.divisorTensor + name:nil]; } - MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DGradientWithGradientTensor: scaledGradTensor - sourceTensor: paddedTensor - descriptor: desc - name: nil]; + MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DGradientWithGradientTensor:scaledGradTensor + sourceTensor:paddedTensor + descriptor:desc + name:nil]; if (explicit_padding) { - return [mpsGraph padGradientWithIncomingGradientTensor: avgPoolTensor - sourceTensor: cachedGraph.inputTensor - paddingMode: MPSGraphPaddingModeZero - leftPadding: paddingShape - rightPadding: paddingShape - name: nil]; + return [mpsGraph padGradientWithIncomingGradientTensor:avgPoolTensor + sourceTensor:cachedGraph.inputTensor + paddingMode:MPSGraphPaddingModeZero + leftPadding:paddingShape + rightPadding:paddingShape + name:nil]; } else { return avgPoolTensor; @@ -301,137 +336,199 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output, } }; - pool2d_template(input, output, c10::nullopt, grad_output_opt, kernel_size, stride, - padding, {1, 1}, ceil_mode, count_include_pad, divisor_override, - pooling_op_block, op_name); + pool2d_template(input, + output, + c10::nullopt, + grad_output_opt, + kernel_size, + stride, + padding, + {1, 1}, + ceil_mode, + count_include_pad, + divisor_override, + pooling_op_block, + op_name); } } // namespace mps -Tensor mps_max_pool2d( - const Tensor& input, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - bool ceil_mode) { - +Tensor mps_max_pool2d(const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode) { Tensor output = at::empty({0}, input.options(), MemoryFormat::Contiguous); mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { MPSGraph* mpsGraph = cachedGraph.graph(); - return [mpsGraph maxPooling2DWithSourceTensor: cachedGraph.inputTensor - descriptor: desc - name: nil]; + return [mpsGraph maxPooling2DWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil]; }; - mps::pool2d_template(input, output, c10::nullopt, c10::nullopt, kernel_size, stride, - padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d"); + mps::pool2d_template(input, + output, + c10::nullopt, + c10::nullopt, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + false, + c10::nullopt, + pooling_op_block, + "max_pool2d"); return output; } -Tensor mps_max_pool2d_backward( - const Tensor& grad_output, - const Tensor& input, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - bool ceil_mode) { - +Tensor mps_max_pool2d_backward(const Tensor& grad_output, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode) { Tensor grad_input = at::empty(input.sizes(), input.options(), MemoryFormat::Contiguous); mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { MPSGraph* mpsGraph = cachedGraph.graph(); - return [mpsGraph maxPooling2DGradientWithGradientTensor: cachedGraph.gradOutputTensor - sourceTensor: cachedGraph.inputTensor - descriptor: desc - name: nil]; + return [mpsGraph maxPooling2DGradientWithGradientTensor:cachedGraph.gradOutputTensor + sourceTensor:cachedGraph.inputTensor + descriptor:desc + name:nil]; }; - mps::pool2d_template(input, grad_input, c10::nullopt, grad_output, kernel_size, stride, - padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d_backward"); + mps::pool2d_template(input, + grad_input, + c10::nullopt, + grad_output, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + false, + c10::nullopt, + pooling_op_block, + "max_pool2d_backward"); return grad_input; } -TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)( - const Tensor& input, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - bool ceil_mode, - const Tensor& output, - const Tensor& indices) { - +TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps) +(const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode, + const Tensor& output, + const Tensor& indices) { auto indices_memory_format = indices.suggest_memory_format(); mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { MPSGraph* mpsGraph = cachedGraph.graph(); - NSArray* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor: cachedGraph.inputTensor - descriptor: desc - name: nil]; + NSArray* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor:cachedGraph.inputTensor + descriptor:desc + name:nil]; cachedGraph.indicesTensor = mps::castMPSTensor(mpsGraph, poolOutputs[1], ScalarType::Long); return poolOutputs[0]; }; - mps::pool2d_template(input, output, indices, c10::nullopt, kernel_size, stride, - padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d_indices"); + mps::pool2d_template(input, + output, + indices, + c10::nullopt, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + false, + c10::nullopt, + pooling_op_block, + "max_pool2d_indices"); if (indices_memory_format == MemoryFormat::ChannelsLast) { const_cast(indices) = indices.to(MemoryFormat::ChannelsLast); } } -TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)( - const Tensor& grad_output, - const Tensor& input, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - bool ceil_mode, - const Tensor& indices, - const Tensor& grad_input) { - +TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps) +(const Tensor& grad_output, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode, + const Tensor& indices, + const Tensor& grad_input) { mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { MPSGraph* mpsGraph = cachedGraph.graph(); - return [mpsGraph maxPooling2DGradientWithGradientTensor: cachedGraph.gradOutputTensor - sourceTensor: cachedGraph.inputTensor - descriptor: desc - name: nil]; + return [mpsGraph maxPooling2DGradientWithGradientTensor:cachedGraph.gradOutputTensor + sourceTensor:cachedGraph.inputTensor + descriptor:desc + name:nil]; }; - mps::pool2d_template(input, grad_input, indices, grad_output, kernel_size, stride, - padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d_indices_backward"); + mps::pool2d_template(input, + grad_input, + indices, + grad_output, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + false, + c10::nullopt, + pooling_op_block, + "max_pool2d_indices_backward"); } -TORCH_IMPL_FUNC(avg_pool2d_out_mps) ( - const Tensor& input, - int64_t kH, - int64_t kW, - int64_t dH, - int64_t dW, - int64_t padH, - int64_t padW, - bool ceil_mode, - bool count_include_pad, - c10::optional divisor_override, - const Tensor& output) { - - mps::avg_pool2d_template(input, output, c10::nullopt, {kH, kW}, {dH, dW}, {padH, padW}, - {1, 1}, ceil_mode, count_include_pad, divisor_override, "avg_pool2d"); +TORCH_IMPL_FUNC(avg_pool2d_out_mps) +(const Tensor& input, + int64_t kH, + int64_t kW, + int64_t dH, + int64_t dW, + int64_t padH, + int64_t padW, + bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override, + const Tensor& output) { + mps::avg_pool2d_template(input, + output, + c10::nullopt, + {kH, kW}, + {dH, dW}, + {padH, padW}, + {1, 1}, + ceil_mode, + count_include_pad, + divisor_override, + "avg_pool2d"); } -TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps) ( - const Tensor& gradOutput, - const Tensor& input, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - bool ceil_mode, - bool count_include_pad, - c10::optional divisor_override, - const Tensor& gradInput) { - - mps::avg_pool2d_template(input, gradInput, gradOutput, kernel_size, stride, padding, - {1, 1}, ceil_mode, count_include_pad, divisor_override, "avg_pool2d_backward"); +TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps) +(const Tensor& gradOutput, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override, + const Tensor& gradInput) { + mps::avg_pool2d_template(input, + gradInput, + gradOutput, + kernel_size, + stride, + padding, + {1, 1}, + ceil_mode, + count_include_pad, + divisor_override, + "avg_pool2d_backward"); } } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/RangeFactories.mm b/aten/src/ATen/native/mps/operations/RangeFactories.mm index db826e8c4536c9..6442c715c5d77b 100644 --- a/aten/src/ATen/native/mps/operations/RangeFactories.mm +++ b/aten/src/ATen/native/mps/operations/RangeFactories.mm @@ -1,9 +1,9 @@ // Copyright © 2022 Apple Inc. #include +#include #include #include -#include #include #include #include @@ -15,37 +15,38 @@ namespace { struct RangeCachedGraph : public mps::MPSCachedGraph { API_AVAILABLE(macosx(12.3)) - RangeCachedGraph(MPSGraph *mpsGraph, MPSDataType dataType, int32_t shapeVal, bool needsClamp = false, bool startLessEnd = false): MPSCachedGraph(mpsGraph) { + RangeCachedGraph(MPSGraph* mpsGraph, + MPSDataType dataType, + int32_t shapeVal, + bool needsClamp = false, + bool startLessEnd = false) + : MPSCachedGraph(mpsGraph) { @autoreleasepool { auto shapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:&shapeVal length:sizeof(int32_t)] - shape: @[@1] + shape:@[ @1 ] dataType:MPSDataTypeInt32]; - auto coordsTensor = [mpsGraph coordinateAlongAxis:0 - withShapeTensor:shapeTensor - name:nil]; + auto coordsTensor = [mpsGraph coordinateAlongAxis:0 withShapeTensor:shapeTensor name:nil]; coordsTensor = [mpsGraph castTensor:coordsTensor toType:dataType name:@"coords"]; - startTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[@1]); - multiplyTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[@1]); + startTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[ @1 ]); + multiplyTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[ @1 ]); auto scaledCoords = [mpsGraph multiplicationWithPrimaryTensor:coordsTensor secondaryTensor:multiplyTensor name:nil]; - outputTensor = [mpsGraph additionWithPrimaryTensor:scaledCoords - secondaryTensor:startTensor - name:nil]; + outputTensor = [mpsGraph additionWithPrimaryTensor:scaledCoords secondaryTensor:startTensor name:nil]; if (needsClamp) { - endTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[@1]); + endTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, dataType, @[ @1 ]); outputTensor = [mpsGraph clampWithTensor:outputTensor - minValueTensor: startLessEnd? startTensor : endTensor - maxValueTensor: startLessEnd? endTensor : startTensor - name: nil]; + minValueTensor:startLessEnd ? startTensor : endTensor + maxValueTensor:startLessEnd ? endTensor : startTensor + name:nil]; } } } - MPSGraphTensor *startTensor = nil; - MPSGraphTensor *endTensor = nil; - MPSGraphTensor *multiplyTensor = nil; - MPSGraphTensor *outputTensor = nil; + MPSGraphTensor* startTensor = nil; + MPSGraphTensor* endTensor = nil; + MPSGraphTensor* multiplyTensor = nil; + MPSGraphTensor* outputTensor = nil; }; } // anonymous namespace @@ -59,31 +60,37 @@ double size_d; if (std::is_same::value) { - size_d = std::ceil(static_cast(end.to() - start.to()) - / step.to()); + size_d = std::ceil(static_cast(end.to() - start.to()) / step.to()); } else { - size_d = std::ceil(static_cast(end.to() - start.to()) - / step.to()); + size_d = std::ceil(static_cast(end.to() - start.to()) / step.to()); } TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); - TORCH_CHECK(std::isfinite(static_cast(xstart)) && - std::isfinite(static_cast(xend)), - "unsupported range: ", xstart, " -> ", xend); + TORCH_CHECK(std::isfinite(static_cast(xstart)) && std::isfinite(static_cast(xend)), + "unsupported range: ", + xstart, + " -> ", + xend); TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), - "upper bound and larger bound inconsistent with step sign"); + "upper bound and larger bound inconsistent with step sign"); TORCH_CHECK(size_d >= 0 && size_d <= static_cast(std::numeric_limits::max()), - "invalid size, possible overflow?"); + "invalid size, possible overflow?"); int64_t size = static_cast(size_d); int64_t numel = result.numel(); if (numel != size) { - if(numel > 0){ - TORCH_WARN("The number of elements in the out tensor of shape ", result.sizes(), - " is ", numel, " which does not match the computed number of elements ", size, - ". Note that this may occur as a result of rounding error. " - "The out tensor will be resized to a tensor of shape (", size, ",)."); + if (numel > 0) { + TORCH_WARN("The number of elements in the out tensor of shape ", + result.sizes(), + " is ", + numel, + " which does not match the computed number of elements ", + size, + ". Note that this may occur as a result of rounding error. " + "The out tensor will be resized to a tensor of shape (", + size, + ",)."); } result.resize_({size}); } @@ -100,28 +107,27 @@ auto mpsDataType = getMPSDataType(result); @autoreleasepool { string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size); - auto cachedGraph = static_cast(cache_->LookUp(key)); + auto cachedGraph = static_cast(cache_->LookUp(key)); if (!cachedGraph) { - auto *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph *() { + auto* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { auto mpsGraph = make_mps_graph(); return new RangeCachedGraph(mpsGraph, mpsDataType, size); }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, r); - NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; + NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; MPSScalar startScalar = getMPSScalar(start, result.scalar_type()); feeds[cachedGraph->startTensor] = getMPSGraphTensorFromScalar(stream, startScalar); MPSScalar stepScalar = getMPSScalar(step, result.scalar_type()); feeds[cachedGraph->multiplyTensor] = getMPSGraphTensorFromScalar(stream, stepScalar); - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } - if(!is_contiguous) { + if (!is_contiguous) { result.copy_(r); } }); @@ -139,22 +145,22 @@ // double size_d = ((xend - xstart) / xstep) + 1; double size_d; if (std::is_same::value) { - size_d = static_cast(end.to() - start.to()) - / step.to() + 1; + size_d = static_cast(end.to() - start.to()) / step.to() + 1; } else { - size_d = static_cast(end.to() - start.to()) - / step.to() + 1; + size_d = static_cast(end.to() - start.to()) / step.to() + 1; } TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); - TORCH_CHECK(std::isfinite(static_cast(xstart)) && - std::isfinite(static_cast(xend)), - "unsupported range: ", xstart, " -> ", xend); + TORCH_CHECK(std::isfinite(static_cast(xstart)) && std::isfinite(static_cast(xend)), + "unsupported range: ", + xstart, + " -> ", + xend); TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), - "upper bound and larger bound inconsistent with step sign"); + "upper bound and larger bound inconsistent with step sign"); TORCH_CHECK(size_d >= 0 && size_d <= static_cast(std::numeric_limits::max()), - "invalid size, possible overflow?"); + "invalid size, possible overflow?"); int64_t size = static_cast(size_d); @@ -171,28 +177,27 @@ auto mpsDataType = getMPSDataType(result); @autoreleasepool { string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size); - auto cachedGraph = static_cast(cache_->LookUp(key)); + auto cachedGraph = static_cast(cache_->LookUp(key)); if (!cachedGraph) { - auto *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph *() { + auto* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { auto mpsGraph = make_mps_graph(); return new RangeCachedGraph(mpsGraph, mpsDataType, size); }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, r); - NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; + NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; MPSScalar startScalar = getMPSScalar(start, result.scalar_type()); feeds[cachedGraph->startTensor] = getMPSGraphTensorFromScalar(stream, startScalar); MPSScalar stepScalar = getMPSScalar(step, result.scalar_type()); feeds[cachedGraph->multiplyTensor] = getMPSGraphTensorFromScalar(stream, stepScalar); - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } - if(!is_contiguous) { + if (!is_contiguous) { result.copy_(r); } }); @@ -222,28 +227,30 @@ bool start_less_end = (start.to() <= end.to()); @autoreleasepool { - string key = "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + to_string(steps) + to_string(start_less_end); - RangeCachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + string key = + "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + to_string(steps) + to_string(start_less_end); + RangeCachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - RangeCachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + RangeCachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new RangeCachedGraph(mpsGraph, MPSDataTypeFloat32, steps, true, start_less_end); - if(getMPSDataType(result) != MPSDataTypeFloat32) { - newCachedGraph->outputTensor = [mpsGraph castTensor:newCachedGraph->outputTensor toType:getMPSDataType(result) name:@"output"]; + if (getMPSDataType(result) != MPSDataTypeFloat32) { + newCachedGraph->outputTensor = [mpsGraph castTensor:newCachedGraph->outputTensor + toType:getMPSDataType(result) + name:@"output"]; } } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } - NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; + NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; auto multiply = (end.to() - start.to()) / ((double)steps - 1.0f); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, r); @@ -255,9 +262,8 @@ MPSScalar multiplyScalar = getMPSScalar(multiply, ScalarType::Float); feeds[cachedGraph->multiplyTensor] = getMPSGraphTensorFromScalar(stream, multiplyScalar); - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 4ac4afccd02a0e..3e5c3d7ecd7d09 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -2,25 +2,23 @@ #include #include -#include #include +#include #include -#include -#include #include -#include +#include #include +#include #include +#include namespace at::native { typedef MPSGraphTensor* (^NormOpBlock)(mps::MPSBinaryCachedGraph*, MPSGraphTensor*, MPSGraphTensor*); -#define NormOpFn(graph, primary, secondary) MPSGraphTensor* (mps::MPSBinaryCachedGraph* graph, MPSGraphTensor* primary, MPSGraphTensor* secondary) +#define NormOpFn(graph, primary, secondary) \ + MPSGraphTensor*(mps::MPSBinaryCachedGraph * graph, MPSGraphTensor * primary, MPSGraphTensor * secondary) -enum StdVarType { - STANDARD_VARIANCE, - STANDARD_DEVIATION -}; +enum StdVarType { STANDARD_VARIANCE, STANDARD_DEVIATION }; enum MPSReductionType { MAX, @@ -37,13 +35,12 @@ using namespace mps; -void set_apparent_shapes(NSMutableArray * &apparent_out_shape, - NSMutableArray * &apparent_in_shape, +void set_apparent_shapes(NSMutableArray*& apparent_out_shape, + NSMutableArray*& apparent_in_shape, int64_t num_reduce_dims, int64_t num_output_dims, IntArrayRef& input_shape, - NSMutableArray * &axes) { - + NSMutableArray*& axes) { if (num_reduce_dims == 0) { /* Output shape becomes a one * Input shape becomes flattened @@ -77,7 +74,7 @@ void set_apparent_shapes(NSMutableArray * &apparent_out_shape, } // Helper function to set the axes of reduction -void set_axes(NSMutableArray * &axes, +void set_axes(NSMutableArray*& axes, int64_t num_reduce_dims, OptionalIntArrayRef opt_dim, int64_t num_input_dims) { @@ -97,11 +94,10 @@ void set_axes(NSMutableArray * &axes, // Helper function to prepare axes and tensor shapes void set_axes_and_shapes(const Tensor& input_t, OptionalIntArrayRef opt_dims, - NSMutableArray * &axes, - NSMutableArray * &apparent_input_shape, - NSMutableArray * &apparent_output_shape, - NSMutableArray * &output_shape) { - + NSMutableArray*& axes, + NSMutableArray*& apparent_input_shape, + NSMutableArray*& apparent_output_shape, + NSMutableArray*& output_shape) { IntArrayRef input_shape = input_t.sizes(); int64_t num_input_dims = input_shape.size(); @@ -114,12 +110,7 @@ void set_axes_and_shapes(const Tensor& input_t, set_axes(axes, num_reduce_dims, opt_dims, input_shape.size()); // Shapes - set_apparent_shapes(apparent_output_shape, - apparent_input_shape, - num_reduce_dims, - num_output_dims, - input_shape, - axes); + set_apparent_shapes(apparent_output_shape, apparent_input_shape, num_reduce_dims, num_output_dims, input_shape, axes); // Squeeze dims for output shape output_shape = [NSMutableArray arrayWithCapacity:0]; @@ -130,14 +121,13 @@ void set_axes_and_shapes(const Tensor& input_t, } } -void reduction_out_mps( - const Tensor& input_t, - OptionalIntArrayRef opt_dim, - bool keepdim, - c10::optional dtype, - const Tensor& output_t, - MPSReductionType reduction_type, - const std::string& func_name) { +void reduction_out_mps(const Tensor& input_t, + OptionalIntArrayRef opt_dim, + bool keepdim, + c10::optional dtype, + const Tensor& output_t, + MPSReductionType reduction_type, + const std::string& func_name) { bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, func_name); @@ -147,14 +137,14 @@ void reduction_out_mps( for (const auto dim_val : dim) { auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size()); TORCH_CHECK(wrap_dim < (input_shape.size() == 0 ? input_t.numel() : input_shape.size()), - func_name+": reduction dim must be in the range of input shape") + func_name + ": reduction dim must be in the range of input shape") } } - NSMutableArray *axes = nil; - NSMutableArray *apparent_input_shape = nil; - NSMutableArray *apparent_output_shape = nil; - NSMutableArray *output_shape = nil; + NSMutableArray* axes = nil; + NSMutableArray* apparent_input_shape = nil; + NSMutableArray* apparent_output_shape = nil; + NSMutableArray* output_shape = nil; set_axes_and_shapes(input_t, opt_dim, axes, apparent_input_shape, apparent_output_shape, output_shape); NSArray* wrappedAxes = mps::getTensorAxes(input_t, opt_dim); @@ -163,8 +153,7 @@ void reduction_out_mps( if (output_t.numel() == 0 || input_t.numel() == 0) { if (reduction_type == MPSReductionType::PROD) { output_t.fill_(1); - } - else if (reduction_type == MPSReductionType::SUM) { + } else if (reduction_type == MPSReductionType::SUM) { output_t.zero_(); } return; @@ -173,20 +162,15 @@ void reduction_out_mps( @autoreleasepool { std::string dtype_str = dtype.has_value() ? mps::getMPSTypeString(dtype.value()) : ""; NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","]; - string key = func_name + ":" + - string([ns_key UTF8String]) + ":" + - getTensorsStringKey(input_t) + ":" + - std::to_string(keepdim) + ":" + - std::to_string(reduction_type) + ":" + - getTensorsStringKey(output_t) + ":" + - dtype_str; + string key = func_name + ":" + string([ns_key UTF8String]) + ":" + getTensorsStringKey(input_t) + ":" + + std::to_string(keepdim) + ":" + std::to_string(reduction_type) + ":" + getTensorsStringKey(output_t) + ":" + + dtype_str; using CachedGraph = MPSUnaryCachedGraph; auto cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -197,11 +181,11 @@ void reduction_out_mps( MPSGraphTensor* castInputTensor = inputTensor; MPSDataType inputCastType = MPSDataTypeInvalid; if (dtype.has_value() && - (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt || - (dtype.value() == kLong && macOS13_3_plus))) { + (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt || + (dtype.value() == kLong && macOS13_3_plus))) { inputCastType = getMPSDataType(dtype.value()); } else if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat && - (inputScalarType != kLong || !macOS13_3_plus)) { + (inputScalarType != kLong || !macOS13_3_plus)) { inputCastType = getMPSDataType(kFloat); } else if (!is_macos_13_or_newer() && inputScalarType == kHalf) { inputCastType = getMPSDataType(kFloat); @@ -214,60 +198,41 @@ void reduction_out_mps( MPSGraphTensor* castOutputTensor = nil; if (reduction_type == MPSReductionType::SUM) { - castOutputTensor = [mpsGraph reductionSumWithTensor:castInputTensor - axes:wrappedAxes - name:nil]; + castOutputTensor = [mpsGraph reductionSumWithTensor:castInputTensor axes:wrappedAxes name:nil]; } else if (reduction_type == MPSReductionType::PROD) { - castOutputTensor = [mpsGraph reductionProductWithTensor:castInputTensor - axes:wrappedAxes - name:nil]; + castOutputTensor = [mpsGraph reductionProductWithTensor:castInputTensor axes:wrappedAxes name:nil]; } else if (reduction_type == MPSReductionType::MEAN) { - castOutputTensor = [mpsGraph meanOfTensor:castInputTensor - axes:wrappedAxes - name:nil]; + castOutputTensor = [mpsGraph meanOfTensor:castInputTensor axes:wrappedAxes name:nil]; } else if (reduction_type == MPSReductionType::COUNT_NONZERO) { - MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0 - dataType:castInputTensor.dataType]; + MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0 dataType:castInputTensor.dataType]; MPSGraphTensor* nonZeros = [mpsGraph notEqualWithPrimaryTensor:castInputTensor secondaryTensor:zeros name:nil]; - castOutputTensor = [mpsGraph reductionSumWithTensor:nonZeros - axes:wrappedAxes - name:nil]; + castOutputTensor = [mpsGraph reductionSumWithTensor:nonZeros axes:wrappedAxes name:nil]; } else if (reduction_type == MPSReductionType::AMAX) { - castOutputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor - axes:wrappedAxes - name:nil]; + castOutputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor axes:wrappedAxes name:nil]; } else if (reduction_type == MPSReductionType::AMIN) { - castOutputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor - axes:wrappedAxes - name:nil]; + castOutputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor axes:wrappedAxes name:nil]; } else if (reduction_type == MPSReductionType::TRACE) { - MPSGraphTensor *bandPartWithTensor = [mpsGraph bandPartWithTensor:castInputTensor + MPSGraphTensor* bandPartWithTensor = [mpsGraph bandPartWithTensor:castInputTensor numLower:0 numUpper:0 name:nil]; - castOutputTensor = [mpsGraph reductionSumWithTensor:bandPartWithTensor - axes:@[@0, @1] - name:nil]; + castOutputTensor = [mpsGraph reductionSumWithTensor:bandPartWithTensor axes:@[ @0, @1 ] name:nil]; } else if (reduction_type == MPSReductionType::NANSUM) { // Create a 0 tensor of the same shape as inputTensor - MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0.0 - dataType:castInputTensor.dataType]; + MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0.0 dataType:castInputTensor.dataType]; // Find NaNs - MPSGraphTensor* nanMask = [mpsGraph isNaNWithTensor:castInputTensor - name:nil]; + MPSGraphTensor* nanMask = [mpsGraph isNaNWithTensor:castInputTensor name:nil]; // Replace NaNs with 0 MPSGraphTensor* nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask truePredicateTensor:zeros falsePredicateTensor:castInputTensor name:nil]; // Sum - castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced - axes:wrappedAxes - name:nil]; + castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced axes:wrappedAxes name:nil]; } MPSGraphTensor* outputTensor = castOutputTensor; @@ -284,35 +249,32 @@ void reduction_out_mps( auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t); auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, apparent_output_shape); - NSDictionary *feeds = @{ + NSDictionary* feeds = @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), }; - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } } -TORCH_IMPL_FUNC(sum_out_mps)( - const Tensor& input_t, - OptionalIntArrayRef opt_dim, - bool keepdim, - c10::optional dtype, - const Tensor& output_t) { - +TORCH_IMPL_FUNC(sum_out_mps) +(const Tensor& input_t, + OptionalIntArrayRef opt_dim, + bool keepdim, + c10::optional dtype, + const Tensor& output_t) { reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps"); } -Tensor& nansum_out_mps( - const Tensor& self, - OptionalIntArrayRef dim, - bool keepdim, - c10::optional opt_dtype, - Tensor& result) { +Tensor& nansum_out_mps(const Tensor& self, + OptionalIntArrayRef dim, + bool keepdim, + c10::optional opt_dtype, + Tensor& result) { TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "nansum does not support complex inputs"); - if (c10::isIntegralType(self.scalar_type(), true)){ + if (c10::isIntegralType(self.scalar_type(), true)) { return at::sum_out(result, self, dim, keepdim, opt_dtype); } ScalarType dtype = get_dtype_from_result(result, opt_dtype); @@ -322,11 +284,7 @@ void reduction_out_mps( return result; } -Tensor nansum_mps( - const Tensor& self, - OptionalIntArrayRef dim, - bool keepdim, - c10::optional opt_dtype) { +Tensor nansum_mps(const Tensor& self, OptionalIntArrayRef dim, bool keepdim, c10::optional opt_dtype) { ScalarType dtype = get_dtype_from_self(self, opt_dtype, true); Tensor result = create_reduction_result(self, dim, keepdim, dtype); return nansum_out_mps(self, dim, keepdim, dtype, result); @@ -334,76 +292,59 @@ Tensor nansum_mps( Tensor trace_mps_out(const Tensor& self) { Tensor output_t = at::native::empty_mps( - {}, - get_dtype_from_self(self, c10::nullopt, true), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + {}, get_dtype_from_self(self, c10::nullopt, true), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); std::vector dims(self.dim()); std::iota(dims.begin(), dims.end(), 0); - reduction_out_mps(self, IntArrayRef(dims), false, c10::nullopt, const_cast(output_t), MPSReductionType::TRACE, "trace_mps_out"); + reduction_out_mps(self, + IntArrayRef(dims), + false, + c10::nullopt, + const_cast(output_t), + MPSReductionType::TRACE, + "trace_mps_out"); return output_t; } TORCH_IMPL_FUNC(prod_out_mps) - (const Tensor& input_t, - int64_t dim, - bool keepdim, - c10::optional dtype, - const Tensor& output_t) { +(const Tensor& input_t, int64_t dim, bool keepdim, c10::optional dtype, const Tensor& output_t) { int64_t dims[1] = {dim}; reduction_out_mps(input_t, IntArrayRef(dims, 1), keepdim, dtype, output_t, MPSReductionType::PROD, "prod_out_mps"); } -TORCH_IMPL_FUNC(amax_out_mps)( - const Tensor& input_t, - IntArrayRef dim, - bool keepdim, - const Tensor& output_t) { - +TORCH_IMPL_FUNC(amax_out_mps)(const Tensor& input_t, IntArrayRef dim, bool keepdim, const Tensor& output_t) { reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, MPSReductionType::AMAX, "amax_out_mps"); } -TORCH_IMPL_FUNC(amin_out_mps)( - const Tensor& input_t, - IntArrayRef dim, - bool keepdim, - const Tensor& output_t) { - +TORCH_IMPL_FUNC(amin_out_mps)(const Tensor& input_t, IntArrayRef dim, bool keepdim, const Tensor& output_t) { reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, MPSReductionType::AMIN, "amin_out_mps"); } -Tensor prod_mps(const Tensor &self, c10::optional opt_dtype) { +Tensor prod_mps(const Tensor& self, c10::optional opt_dtype) { std::vector dims(self.dim()); std::iota(dims.begin(), dims.end(), 0); Tensor output_t = at::native::empty_mps( - {}, - get_dtype_from_self(self, opt_dtype, true), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + {}, get_dtype_from_self(self, opt_dtype, true), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); - reduction_out_mps(self, IntArrayRef(dims), false, opt_dtype, const_cast(output_t), MPSReductionType::PROD, "prod_mps"); + reduction_out_mps( + self, IntArrayRef(dims), false, opt_dtype, const_cast(output_t), MPSReductionType::PROD, "prod_mps"); return output_t; } -Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){ +Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims) { int64_t shape_size = dims.size() == 0 ? 0 : self.sizes().size() - dims.size(); int64_t out_shape = std::max(shape_size, 0LL); std::vector output_shape(out_shape); std::vector dims_vec = dims.vec(); - std::for_each(dims_vec.begin(), dims_vec.end(), [&](int64_t &n){ n = maybe_wrap_dim(n, self); }); + std::for_each(dims_vec.begin(), dims_vec.end(), [&](int64_t& n) { n = maybe_wrap_dim(n, self); }); if (out_shape != 0) { int out_dim = 0; - for (const auto self_dim: c10::irange((self.sizes().size()))) { + for (const auto self_dim : c10::irange((self.sizes().size()))) { if (std::find(dims_vec.begin(), dims_vec.end(), self_dim) == dims_vec.end()) { output_shape[out_dim++] = (self.sizes()[self_dim]); } @@ -411,40 +352,37 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){ } Tensor output_t = at::native::empty_mps( - IntArrayRef(output_shape), - ScalarType::Long, - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); - reduction_out_mps(self, dims, false, self.scalar_type(), const_cast(output_t), MPSReductionType::COUNT_NONZERO, "count_nonzero_mps"); + IntArrayRef(output_shape), ScalarType::Long, c10::nullopt, kMPS, c10::nullopt, c10::nullopt); + reduction_out_mps(self, + dims, + false, + self.scalar_type(), + const_cast(output_t), + MPSReductionType::COUNT_NONZERO, + "count_nonzero_mps"); return output_t; } -TORCH_IMPL_FUNC(mean_out_mps)( - const Tensor& input_t, - OptionalIntArrayRef opt_dim, - bool keepdim, - c10::optional dtype, - const Tensor& output_t) { - +TORCH_IMPL_FUNC(mean_out_mps) +(const Tensor& input_t, + OptionalIntArrayRef opt_dim, + bool keepdim, + c10::optional dtype, + const Tensor& output_t) { reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::MEAN, "mean_out_mps"); } -void impl_func_norm_mps( - const Tensor& input_tensor, - const Tensor& other_tensor, - const OptionalScalarRef& opt_p, - IntArrayRef dim, - bool keepdim, - c10::optional opt_dtype, - const Tensor& output_t, - bool cdist = false, - c10::optional input_broadcasted_shape = c10::nullopt, - NormOpBlock normOpBlock = nullptr - ) { - +void impl_func_norm_mps(const Tensor& input_tensor, + const Tensor& other_tensor, + const OptionalScalarRef& opt_p, + IntArrayRef dim, + bool keepdim, + c10::optional opt_dtype, + const Tensor& output_t, + bool cdist = false, + c10::optional input_broadcasted_shape = c10::nullopt, + NormOpBlock normOpBlock = nullptr) { if (input_tensor.numel() == 0) { return; } @@ -455,7 +393,7 @@ void impl_func_norm_mps( IntArrayRef input_shape = cdist ? input_broadcasted_shape.value() : input_t.sizes(); - for (const auto dim_val: dim) { + for (const auto dim_val : dim) { auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size()); TORCH_CHECK(wrap_dim < input_shape.size(), "norm_out_mps: reduction dim must be in the range of input shape") } @@ -474,23 +412,18 @@ void impl_func_norm_mps( // For output shape calculation, assume that keepdim is true num_output_dims = num_input_dims; - NSMutableArray *apparent_output_shape = nil; - NSMutableArray *apparent_input_shape = nil; + NSMutableArray* apparent_output_shape = nil; + NSMutableArray* apparent_input_shape = nil; // Reduction axes - NSMutableArray *axes; + NSMutableArray* axes; set_axes(axes, num_reduce_dims, dim, input_shape.size()); - set_apparent_shapes(apparent_output_shape, - apparent_input_shape, - num_reduce_dims, - num_output_dims, - input_shape, - axes); + set_apparent_shapes(apparent_output_shape, apparent_input_shape, num_reduce_dims, num_output_dims, input_shape, axes); NSArray* wrappedAxes = mps::getTensorAxes(input_t, dim); if (cdist) { - apparent_input_shape = [mps::getMPSShape(input_tensor.sizes()) mutableCopy]; + apparent_input_shape = [mps::getMPSShape(input_tensor.sizes()) mutableCopy]; apparent_output_shape = [mps::getMPSShape(output_t.sizes()) mutableCopy]; } @@ -501,16 +434,16 @@ void impl_func_norm_mps( auto stream = at::mps::getCurrentMPSStream(); @autoreleasepool { NSString* ns_key = [[axes valueForKey:@"description"] componentsJoinedByString:@","]; - string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; - string tensor_key = cdist ? getTensorsStringKey({input_tensor, other_tensor}) : getTensorsStringKey({input_t}); - string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + to_string(p) + ":" + keepdim_info; + string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; + string tensor_key = cdist ? getTensorsStringKey({input_tensor, other_tensor}) : getTensorsStringKey({input_t}); + string key = + string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + to_string(p) + ":" + keepdim_info; auto cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - - MPSBinaryCachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + MPSBinaryCachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -521,66 +454,51 @@ void impl_func_norm_mps( newCachedGraph->otherTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, other_tensor); } - MPSGraphTensor* inputTensor = cdist ? normOpBlock(newCachedGraph, newCachedGraph->inputTensor_, newCachedGraph->otherTensor_) : - newCachedGraph->inputTensor_; + MPSGraphTensor* inputTensor = cdist + ? normOpBlock(newCachedGraph, newCachedGraph->inputTensor_, newCachedGraph->otherTensor_) + : newCachedGraph->inputTensor_; if (opt_dtype.has_value()) { - inputTensor = [mpsGraph castTensor:inputTensor - toType:mps_input_dtype - name:@"castInputTensor"]; + inputTensor = [mpsGraph castTensor:inputTensor toType:mps_input_dtype name:@"castInputTensor"]; } - MPSGraphTensor *outputTensor; + MPSGraphTensor* outputTensor; if (pIsZero) { - MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor - name:nil]; - MPSGraphTensor *powerValTensor = [mpsGraph constantWithScalar:p - dataType:mps_input_dtype]; - MPSGraphTensor *powerTensor = [mpsGraph powerWithPrimaryTensor:absoluteTensor - secondaryTensor:powerValTensor - name:nil]; - outputTensor = [mpsGraph reductionSumWithTensor:powerTensor - axes:wrappedAxes - name:nil]; - } - else if (pIsPosInf) { - MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor - name:nil]; - outputTensor = [mpsGraph reductionMaximumWithTensor:absoluteTensor - axes:wrappedAxes - name:nil]; - } - else if (pIsNegInf) { - MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor - name:nil]; - outputTensor = [mpsGraph reductionMinimumWithTensor:absoluteTensor - axes:wrappedAxes - name:nil]; + MPSGraphTensor* absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor name:nil]; + MPSGraphTensor* powerValTensor = [mpsGraph constantWithScalar:p dataType:mps_input_dtype]; + MPSGraphTensor* powerTensor = [mpsGraph powerWithPrimaryTensor:absoluteTensor + secondaryTensor:powerValTensor + name:nil]; + outputTensor = [mpsGraph reductionSumWithTensor:powerTensor axes:wrappedAxes name:nil]; + } else if (pIsPosInf) { + MPSGraphTensor* absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor name:nil]; + outputTensor = [mpsGraph reductionMaximumWithTensor:absoluteTensor axes:wrappedAxes name:nil]; + } else if (pIsNegInf) { + MPSGraphTensor* absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor name:nil]; + outputTensor = [mpsGraph reductionMinimumWithTensor:absoluteTensor axes:wrappedAxes name:nil]; } else { - MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor - name:nil]; + MPSGraphTensor* absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor name:nil]; - MPSGraphTensor *powerValTensor = [mpsGraph constantWithScalar:p - dataType:mps_input_dtype]; + MPSGraphTensor* powerValTensor = [mpsGraph constantWithScalar:p dataType:mps_input_dtype]; - MPSGraphTensor *reciprocalPowerValTensor = [mpsGraph constantWithScalar:reciprocal_p - dataType:mps_input_dtype]; + MPSGraphTensor* reciprocalPowerValTensor = [mpsGraph constantWithScalar:reciprocal_p + dataType:mps_input_dtype]; - MPSGraphTensor *powerTensor = [mpsGraph powerWithPrimaryTensor:absoluteTensor - secondaryTensor:powerValTensor - name:nil]; + MPSGraphTensor* powerTensor = [mpsGraph powerWithPrimaryTensor:absoluteTensor + secondaryTensor:powerValTensor + name:nil]; - MPSGraphTensor *reductionSumTensor = [mpsGraph reductionSumWithTensor:powerTensor - axes:wrappedAxes - name:nil]; + MPSGraphTensor* reductionSumTensor = [mpsGraph reductionSumWithTensor:powerTensor + axes:wrappedAxes + name:nil]; - outputTensor = [mpsGraph powerWithPrimaryTensor:reductionSumTensor - secondaryTensor:reciprocalPowerValTensor - name:nil]; + outputTensor = [mpsGraph powerWithPrimaryTensor:reductionSumTensor + secondaryTensor:reciprocalPowerValTensor + name:nil]; } if (cdist) { - outputTensor= [mpsGraph reshapeTensor:outputTensor withShape:mps::getMPSShape(output_t) name: nil]; + outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:mps::getMPSShape(output_t) name:nil]; } newCachedGraph->outputTensor_ = outputTensor; @@ -593,28 +511,23 @@ void impl_func_norm_mps( auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t); auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, apparent_output_shape); - NSMutableDictionary* feeds =[NSMutableDictionary dictionary]; - feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); + NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; + feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); if (cdist) { otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other_tensor); feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData(); } - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } } TORCH_IMPL_FUNC(norm_out_mps) -(const Tensor& self, - const OptionalScalarRef opt_p, - IntArrayRef dim, - bool keepdim, - const Tensor& result) { +(const Tensor& self, const OptionalScalarRef opt_p, IntArrayRef dim, bool keepdim, const Tensor& result) { impl_func_norm_mps(self, self, opt_p, dim, keepdim, c10::nullopt, result, /*cdist=*/false); } @@ -632,14 +545,25 @@ Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, c1 using namespace mps; TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D"); TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D"); - TORCH_CHECK(x1.size(-1) == x2.size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.size(-1), " X2: ", x2.size(-1)); - TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type()); + TORCH_CHECK(x1.size(-1) == x2.size(-1), + "X1 and X2 must have the same number of columns. X1: ", + x1.size(-1), + " X2: ", + x2.size(-1)); + TORCH_CHECK( + at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type()); auto device1 = x1.device().type(); - TORCH_CHECK(at::isFloatingType(x2.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type()); + TORCH_CHECK( + at::isFloatingType(x2.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type()); auto device2 = x2.device().type(); TORCH_CHECK(p >= 0, "cdist only supports non-negative p values"); TORCH_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2); - TORCH_CHECK(x1.is_mps() && (x1.get_device() == x2.get_device()), "device of X1 (", x1.get_device(), ") must match device of X2 (", x2.get_device(), ")"); + TORCH_CHECK(x1.is_mps() && (x1.get_device() == x2.get_device()), + "device of X1 (", + x1.get_device(), + ") must match device of X2 (", + x2.get_device(), + ")"); int64_t c1 = x1.size(-1); int64_t c2 = x2.size(-1); @@ -652,8 +576,8 @@ Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, c1 int64_t r1 = x1.size(-2); int64_t r2 = x2.size(-2); - //For batch calculation we expand all dimensions(except the last two) to one, with size that equals to product of them. - //The last two dimensions will stay the same + // For batch calculation we expand all dimensions(except the last two) to one, with size that equals to product of + // them. The last two dimensions will stay the same IntArrayRef batch_tensor1(x1.sizes().data(), dim1 - 2); IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2); std::vector expand_batch_portion = infer_size(batch_tensor1, batch_tensor2); @@ -673,14 +597,22 @@ Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, c1 NormOpBlock norm_op_block = ^NormOpFn(cachedGraph, x1Tensor, x2Tensor) { MPSGraph* mpsGraph = cachedGraph->graph(); - MPSGraphTensor* inputBroadcast = [mpsGraph broadcastTensor:x1Tensor toShape:getMPSShape(tensor1_expand_size) name:nil]; - MPSGraphTensor* inputBroadcastReshape = [mpsGraph reshapeTensor:inputBroadcast withShape:getMPSShape(tensor1_view) name:nil]; + MPSGraphTensor* inputBroadcast = [mpsGraph broadcastTensor:x1Tensor + toShape:getMPSShape(tensor1_expand_size) + name:nil]; + MPSGraphTensor* inputBroadcastReshape = [mpsGraph reshapeTensor:inputBroadcast + withShape:getMPSShape(tensor1_view) + name:nil]; - MPSGraphTensor* otherBroadcast = [mpsGraph broadcastTensor:x2Tensor toShape:getMPSShape(tensor2_expand_size) name:nil]; - MPSGraphTensor* otherBroadcastReshape = [mpsGraph reshapeTensor:otherBroadcast withShape:getMPSShape(tensor2_view) name:nil]; + MPSGraphTensor* otherBroadcast = [mpsGraph broadcastTensor:x2Tensor + toShape:getMPSShape(tensor2_expand_size) + name:nil]; + MPSGraphTensor* otherBroadcastReshape = [mpsGraph reshapeTensor:otherBroadcast + withShape:getMPSShape(tensor2_view) + name:nil]; - NSMutableArray *inputArray = [NSMutableArray arrayWithCapacity:tensor1_view[1]]; - NSMutableArray *otherArray = [NSMutableArray arrayWithCapacity:tensor2_view[1]]; + NSMutableArray* inputArray = [NSMutableArray arrayWithCapacity:tensor1_view[1]]; + NSMutableArray* otherArray = [NSMutableArray arrayWithCapacity:tensor2_view[1]]; for (const auto i : c10::irange(tensor2_view[1])) { inputArray[i] = inputBroadcastReshape; @@ -690,27 +622,35 @@ Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, c1 otherArray[i] = otherBroadcastReshape; } - MPSGraphTensor *inputTensorReshaped = [mpsGraph concatTensors:inputArray dimension:1 interleave:YES name:nil]; - MPSGraphTensor *otherTensorReshaped = [mpsGraph concatTensors:otherArray dimension:1 interleave:NO name:nil]; + MPSGraphTensor* inputTensorReshaped = [mpsGraph concatTensors:inputArray dimension:1 interleave:YES name:nil]; + MPSGraphTensor* otherTensorReshaped = [mpsGraph concatTensors:otherArray dimension:1 interleave:NO name:nil]; - - MPSGraphTensor *inputTensorPNorm = [mpsGraph subtractionWithPrimaryTensor: inputTensorReshaped - secondaryTensor: otherTensorReshaped - name: nil]; + MPSGraphTensor* inputTensorPNorm = [mpsGraph subtractionWithPrimaryTensor:inputTensorReshaped + secondaryTensor:otherTensorReshaped + name:nil]; return inputTensorPNorm; }; - c10::optional inputBroadcastSize = c10::make_optional(makeArrayRef(tensor1_view.data(), tensor1_view.size())); - impl_func_norm_mps(x1, x2, OptionalScalarRef(p), makeArrayRef(2), false, c10::nullopt, result, /*cdist=*/true, inputBroadcastSize, norm_op_block); + c10::optional inputBroadcastSize = + c10::make_optional(makeArrayRef(tensor1_view.data(), tensor1_view.size())); + impl_func_norm_mps(x1, + x2, + OptionalScalarRef(p), + makeArrayRef(2), + false, + c10::nullopt, + result, + /*cdist=*/true, + inputBroadcastSize, + norm_op_block); return result; } -Tensor std_var_common_impl_mps( - const Tensor & input_t, - at::OptionalIntArrayRef dim, - const c10::optional& correction, - bool keepdim, - StdVarType stdVarType) { +Tensor std_var_common_impl_mps(const Tensor& input_t, + at::OptionalIntArrayRef dim, + const c10::optional& correction, + bool keepdim, + StdVarType stdVarType) { using CachedGraph = MPSUnaryCachedGraph; IntArrayRef input_shape = input_t.sizes(); @@ -736,9 +676,9 @@ Tensor std_var_common_impl_mps( NSArray* wrappedAxes = getTensorAxes(input_t, dim); int64_t num_output_dims = 0; - NSMutableArray *axes = nil; - NSMutableArray *apparent_output_shape = nil; - NSMutableArray *apparent_input_shape = nil; + NSMutableArray* axes = nil; + NSMutableArray* apparent_output_shape = nil; + NSMutableArray* apparent_input_shape = nil; std::vector output_shape; if ((!keepdim && !use_dim) || (!keepdim && use_dim && dim_value.size() <= 0)) { @@ -763,19 +703,15 @@ Tensor std_var_common_impl_mps( num_output_dims = num_input_dims; set_axes(axes, num_reduce_dims, dim_value, num_input_dims); - set_apparent_shapes(apparent_output_shape, - apparent_input_shape, - num_reduce_dims, - num_output_dims, - input_shape, - axes); + set_apparent_shapes( + apparent_output_shape, apparent_input_shape, num_reduce_dims, num_output_dims, input_shape, axes); - num_output_dims = (num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0; //num_input_dims; + num_output_dims = (num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0; // num_input_dims; unsigned int curr_i = 0; - for (const auto i: c10::irange(num_input_dims)) { + for (const auto i : c10::irange(num_input_dims)) { bool found = false; - for (const auto j: c10::irange(num_reduce_dims)) { + for (const auto j : c10::irange(num_reduce_dims)) { if (i == dim_value[j]) { found = true; break; @@ -801,15 +737,11 @@ Tensor std_var_common_impl_mps( num_output_dims = 0; int64_t num_reduce_dims = 0; set_axes(axes, num_reduce_dims, dim_value, input_shape.size()); - set_apparent_shapes(apparent_output_shape, - apparent_input_shape, - num_reduce_dims, - num_output_dims, - input_shape, - axes); + set_apparent_shapes( + apparent_output_shape, apparent_input_shape, num_reduce_dims, num_output_dims, input_shape, axes); num_output_dims = num_input_dims; - for (const auto i: c10::irange(num_input_dims)) { - output_shape.push_back((int64_t) 1); + for (const auto i : c10::irange(num_input_dims)) { + output_shape.push_back((int64_t)1); correction_n *= input_shape[i]; } // scalar --> vector case [[1.0034567]] @@ -818,16 +750,12 @@ Tensor std_var_common_impl_mps( num_output_dims = num_input_dims; set_axes(axes, num_reduce_dims, dim_value, num_input_dims); - set_apparent_shapes(apparent_output_shape, - apparent_input_shape, - num_reduce_dims, - num_output_dims, - input_shape, - axes); + set_apparent_shapes( + apparent_output_shape, apparent_input_shape, num_reduce_dims, num_output_dims, input_shape, axes); - num_output_dims = num_input_dims;//(num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0; + num_output_dims = num_input_dims; //(num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0; - for(const int i : c10::irange(num_reduce_dims)) { + for (const int i : c10::irange(num_reduce_dims)) { auto wrap_dim = maybe_wrap_dim(dim_value[i], input_shape.size()); correction_n *= input_shape[wrap_dim]; } @@ -837,13 +765,12 @@ Tensor std_var_common_impl_mps( } } - Tensor output_t = at::native::empty_mps( - IntArrayRef(output_shape.data(), num_output_dims), - input_t.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + Tensor output_t = at::native::empty_mps(IntArrayRef(output_shape.data(), num_output_dims), + input_t.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); if (output_t.numel() == 0 || input_t.numel() == 0) { return output_t; @@ -859,89 +786,73 @@ Tensor std_var_common_impl_mps( string bessel_corrected = (use_correction && correction_value) ? "unbiased " : "biased "; string use_dim_info = (use_dim) ? "use_dim=1:" + to_string(dim_value.size()) : "use_dim=0"; string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; - string key = op_key + ":" + - getTensorsStringKey(input_t) + ":" + - use_dim_info + ":" + - keepdim_info + ":" + - string([ns_key UTF8String]) + ":" + - bessel_corrected + ":" + - std::to_string(correction_value); + string key = op_key + ":" + getTensorsStringKey(input_t) + ":" + use_dim_info + ":" + keepdim_info + ":" + + string([ns_key UTF8String]) + ":" + bessel_corrected + ":" + std::to_string(correction_value); auto cachedGraph = cache_->LookUpAs(key); // Initialize once if configuration not found in cache if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); - MPSGraphTensor *outputVarTensor = [mpsGraph varianceOfTensor:inputTensor - axes:wrappedAxes - name:nil]; - MPSGraphTensor *outputTensor = nil; + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); + MPSGraphTensor* outputVarTensor = [mpsGraph varianceOfTensor:inputTensor axes:wrappedAxes name:nil]; + MPSGraphTensor* outputTensor = nil; if (use_correction && correction_value) { - MPSGraphTensor *besselTensor= [mpsGraph constantWithScalar:bessel_correction - dataType:getMPSDataType(input_t)]; - MPSGraphTensor *correctedTensor = [mpsGraph multiplicationWithPrimaryTensor:outputVarTensor - secondaryTensor:besselTensor - name:nil]; - outputTensor = (stdVarType == STANDARD_DEVIATION) ? - [mpsGraph squareRootWithTensor:correctedTensor name:nil] : correctedTensor; + MPSGraphTensor* besselTensor = [mpsGraph constantWithScalar:bessel_correction + dataType:getMPSDataType(input_t)]; + MPSGraphTensor* correctedTensor = [mpsGraph multiplicationWithPrimaryTensor:outputVarTensor + secondaryTensor:besselTensor + name:nil]; + outputTensor = (stdVarType == STANDARD_DEVIATION) ? [mpsGraph squareRootWithTensor:correctedTensor name:nil] + : correctedTensor; } else { - outputTensor = (stdVarType == STANDARD_DEVIATION) ? - [mpsGraph squareRootWithTensor:outputVarTensor name:nil] : outputVarTensor; + outputTensor = (stdVarType == STANDARD_DEVIATION) ? [mpsGraph squareRootWithTensor:outputVarTensor name:nil] + : outputVarTensor; } newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; }); - } + } auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t); auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, apparent_output_shape); - NSDictionary *feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + NSDictionary* feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), }; - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } return output_t; } -Tensor var_mps( - const Tensor & input_t, - at::OptionalIntArrayRef dim, - const c10::optional& correction, - bool keepdim) -{ +Tensor var_mps(const Tensor& input_t, + at::OptionalIntArrayRef dim, + const c10::optional& correction, + bool keepdim) { return std_var_common_impl_mps(input_t, dim, correction, keepdim, STANDARD_VARIANCE); } -Tensor std_mps( - const Tensor & input_t, - at::OptionalIntArrayRef dim, - const c10::optional& correction, - bool keepdim) -{ +Tensor std_mps(const Tensor& input_t, + at::OptionalIntArrayRef dim, + const c10::optional& correction, + bool keepdim) { return std_var_common_impl_mps(input_t, dim, correction, keepdim, STANDARD_DEVIATION); } TORCH_IMPL_FUNC(any_out_mps) - (const Tensor& input_t, - int64_t dim, - bool keepdim, - const Tensor& output_t) -{ +(const Tensor& input_t, int64_t dim, bool keepdim, const Tensor& output_t) { using CachedGraph = MPSUnaryCachedGraph; if (output_t.numel() == 0 || input_t.numel() == 0) { @@ -959,9 +870,9 @@ Tensor std_mps( // If there is no dim argument, the input shape is flattened IntArrayRef input_shape = input_t.sizes(); int64_t num_input_dims = input_shape.size(); - NSMutableArray *apparent_out_shape = nil; + NSMutableArray* apparent_out_shape = nil; apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; - for (const auto i: c10::irange(num_input_dims)) { + for (const auto i : c10::irange(num_input_dims)) { apparent_out_shape[i] = dim_ == i ? @1 : [NSNumber numberWithInt:input_shape[i]]; } @@ -969,12 +880,13 @@ Tensor std_mps( @autoreleasepool { MPSShape* input_t_shape = getMPSShape(input_t); - string key = string("any_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" + getMPSTypeString(input_t); + string key = string("any_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" + + getMPSTypeString(input_t); CachedGraph* cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); @@ -982,15 +894,12 @@ Tensor std_mps( MPSDataType input_type = getMPSDataType(input_t); MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, input_t_shape); - MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); - MPSGraphTensor* castOutputTensor = [mpsGraph reductionOrWithTensor:castInputTensor - axis:dim_ - name:nil]; + MPSGraphTensor* castInputTensor = + castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castOutputTensor = [mpsGraph reductionOrWithTensor:castInputTensor axis:dim_ name:nil]; MPSGraphTensor* outputTensor = castOutputTensor; if (MPSDataTypeBool != [castOutputTensor dataType]) { - outputTensor = [mpsGraph castTensor:castOutputTensor - toType:MPSDataTypeBool - name:@"outputTensor"]; + outputTensor = [mpsGraph castTensor:castOutputTensor toType:MPSDataTypeBool name:@"outputTensor"]; } newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; @@ -1001,11 +910,11 @@ Tensor std_mps( auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t); auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape); - NSDictionary *feeds = @{ + NSDictionary* feeds = @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), }; - NSDictionary *results = @{ + NSDictionary* results = @{ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(), }; @@ -1033,13 +942,12 @@ Tensor std_mps( @autoreleasepool { MPSShape* input_t_shape = getMPSShape(input_t); - string key = string("any_all_out_mps:") + getMPSShapeString(input_t_shape) +":" + getMPSTypeString(input_t); + string key = string("any_all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + getMPSTypeString(input_t); CachedGraph* cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -1047,10 +955,9 @@ Tensor std_mps( MPSDataType input_type = getMPSDataType(input_t); MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, input_t_shape); - MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); - MPSGraphTensor* castOutputTensor = [mpsGraph reductionOrWithTensor:castInputTensor - axes:nil - name:nil]; + MPSGraphTensor* castInputTensor = + castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castOutputTensor = [mpsGraph reductionOrWithTensor:castInputTensor axes:nil name:nil]; MPSGraphTensor* outputTensor = castOutputTensor; if (getMPSDataType(output_t) != [castOutputTensor dataType]) { @@ -1058,7 +965,6 @@ Tensor std_mps( } newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; - } return newCachedGraph; }); @@ -1066,11 +972,11 @@ Tensor std_mps( auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t); auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t); - NSDictionary *feeds = @{ + NSDictionary* feeds = @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), }; - NSDictionary *results = @{ + NSDictionary* results = @{ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(), }; @@ -1079,11 +985,7 @@ Tensor std_mps( } TORCH_IMPL_FUNC(all_out_mps) - (const Tensor& input_t, - int64_t dim, - bool keepdim, - const Tensor& output_t) -{ +(const Tensor& input_t, int64_t dim, bool keepdim, const Tensor& output_t) { using CachedGraph = MPSUnaryCachedGraph; if (output_t.numel() == 0 || input_t.numel() == 0) { @@ -1101,32 +1003,32 @@ Tensor std_mps( // If there is no dim argument, the input shape is flattened IntArrayRef input_shape = input_t.sizes(); int64_t num_input_dims = input_shape.size(); - NSMutableArray *apparent_out_shape = nil; + NSMutableArray* apparent_out_shape = nil; apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; for (const auto i : c10::irange(num_input_dims)) { - apparent_out_shape[i] = dim_ == i ? @1 : [NSNumber numberWithInt:input_shape[i]]; + apparent_out_shape[i] = dim_ == i ? @1 : [NSNumber numberWithInt:input_shape[i]]; } auto stream = at::mps::getCurrentMPSStream(); @autoreleasepool { MPSShape* input_t_shape = getMPSShape(input_t); - string key = string("all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" + getMPSTypeString(input_t); + string key = string("all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" + + getMPSTypeString(input_t); CachedGraph* cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); MPSDataType input_type = getMPSDataType(input_t); MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, input_t_shape); - MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); - MPSGraphTensor* castOutputTensor = [mpsGraph reductionAndWithTensor:castInputTensor - axis:dim_ - name:nil]; + MPSGraphTensor* castInputTensor = + castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castOutputTensor = [mpsGraph reductionAndWithTensor:castInputTensor axis:dim_ name:nil]; MPSGraphTensor* outputTensor = castOutputTensor; if (MPSDataTypeBool != [castOutputTensor dataType]) { outputTensor = castMPSTensor(mpsGraph, castOutputTensor, MPSDataTypeBool); @@ -1140,11 +1042,11 @@ Tensor std_mps( auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t); auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape); - NSDictionary *feeds = @{ + NSDictionary* feeds = @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), }; - NSDictionary *results = @{ + NSDictionary* results = @{ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(), }; @@ -1167,22 +1069,21 @@ Tensor std_mps( @autoreleasepool { MPSShape* input_t_shape = getMPSShape(input_t); - string key = string("all_all_out_mps:") + getMPSShapeString(input_t_shape) +":" + getMPSTypeString(input_t); + string key = string("all_all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + getMPSTypeString(input_t); CachedGraph* cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); MPSDataType input_type = getMPSDataType(input_t); MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, input_t_shape); - MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); - MPSGraphTensor* castOutputTensor = [mpsGraph reductionAndWithTensor:castInputTensor - axes:nil - name:nil]; + MPSGraphTensor* castInputTensor = + castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castOutputTensor = [mpsGraph reductionAndWithTensor:castInputTensor axes:nil name:nil]; MPSGraphTensor* outputTensor = castOutputTensor; if (MPSDataTypeBool != [castOutputTensor dataType]) { outputTensor = castMPSTensor(mpsGraph, castOutputTensor, MPSDataTypeBool); @@ -1190,7 +1091,6 @@ Tensor std_mps( newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; - } return newCachedGraph; }); @@ -1198,11 +1098,11 @@ Tensor std_mps( auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t); auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t); - NSDictionary *feeds = @{ + NSDictionary* feeds = @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), }; - NSDictionary *results = @{ + NSDictionary* results = @{ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(), }; @@ -1213,10 +1113,7 @@ Tensor std_mps( //----------------------------------------------------------------------- // Min and max functions -Tensor min_max_mps - (const Tensor& input_t, - MPSReductionType reduction_type, - const std::string& func_name) { +Tensor min_max_mps(const Tensor& input_t, MPSReductionType reduction_type, const std::string& func_name) { bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max"); @@ -1237,8 +1134,8 @@ Tensor std_mps( CachedGraph* cachedGraph = cache_->LookUpAs(key); // Initialize once if configuration not found in cache if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); @@ -1246,17 +1143,14 @@ Tensor std_mps( MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); MPSGraphTensor* castOutputTensor = nil; - MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castInputTensor = + castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); NSArray* axes = getTensorAxes(input_t); if (reduction_type == MPSReductionType::MAX) { - castOutputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor - axes:axes - name:nil]; - } else if(reduction_type == MPSReductionType::MIN) { - castOutputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor - axes:axes - name:nil]; + castOutputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor axes:axes name:nil]; + } else if (reduction_type == MPSReductionType::MIN) { + castOutputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor axes:axes name:nil]; } MPSGraphTensor* outputTensor = castOutputTensor; @@ -1272,15 +1166,14 @@ Tensor std_mps( } auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t); - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, @[@1]); + auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, @[ @1 ]); - NSDictionary *feeds = @{ + NSDictionary* feeds = @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), }; - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); } @@ -1290,24 +1183,21 @@ Tensor std_mps( // Max entire tensor into scalar result Tensor max_mps(const Tensor& input_t) { - return min_max_mps(input_t, MPSReductionType::MAX, "max_mps"); } // Min entire tensor into scalar result Tensor min_mps(const Tensor& input_t) { - return min_max_mps(input_t, MPSReductionType::MIN, "min_mps"); } -void min_max_out_mps - (const Tensor& input_t, - int64_t dim, - bool keepdim, - const Tensor& output_t, - const Tensor& indices_t, - MPSReductionType reduction_type, - const std::string& func_name) { +void min_max_out_mps(const Tensor& input_t, + int64_t dim, + bool keepdim, + const Tensor& output_t, + const Tensor& indices_t, + MPSReductionType reduction_type, + const std::string& func_name) { bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max_out"); @@ -1321,12 +1211,11 @@ Tensor min_mps(const Tensor& input_t) { } // Derive from MPSCachedGraph - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - MPSGraphTensor *indicesTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + MPSGraphTensor* indicesTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -1337,11 +1226,11 @@ Tensor min_mps(const Tensor& input_t) { // If there is no dim argument, the input shape is flattened IntArrayRef input_shape = input_t.sizes(); int64_t num_input_dims = input_shape.size(); - NSMutableArray *apparent_out_shape = nil; + NSMutableArray* apparent_out_shape = nil; apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; - for (const auto i: c10::irange(num_input_dims)) { - apparent_out_shape[i] = dim_ == i ? @1: [NSNumber numberWithInt:input_shape[i]]; + for (const auto i : c10::irange(num_input_dims)) { + apparent_out_shape[i] = dim_ == i ? @1 : [NSNumber numberWithInt:input_shape[i]]; } auto stream = at::mps::getCurrentMPSStream(); @@ -1351,41 +1240,37 @@ Tensor min_mps(const Tensor& input_t) { CachedGraph* cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); MPSGraphTensor* outputTensor = nil; - MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);; - - if(reduction_type == MPSReductionType::MAX) { - outputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor - axis:(NSInteger)dim_ - name:nil]; - } else if(reduction_type == MPSReductionType::MIN) { - outputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor - axis:(NSInteger)dim_ - name:nil]; + MPSGraphTensor* castInputTensor = + castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + ; + + if (reduction_type == MPSReductionType::MAX) { + outputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil]; + } else if (reduction_type == MPSReductionType::MIN) { + outputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil]; } MPSGraphTensor* argreduceOutTensor = nil; - if(reduction_type == MPSReductionType::MAX) + if (reduction_type == MPSReductionType::MAX) argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor:castInputTensor axis:(NSInteger)dim_ name:@"argmax_out"]; - else if(reduction_type == MPSReductionType::MIN) + else if (reduction_type == MPSReductionType::MIN) argreduceOutTensor = [mpsGraph reductionArgMinimumWithTensor:castInputTensor axis:(NSInteger)dim_ name:@"argmax_out"]; - MPSGraphTensor *indicesTensor = nil; + MPSGraphTensor* indicesTensor = nil; if ([argreduceOutTensor dataType] != MPSDataTypeInt64) { - indicesTensor = [mpsGraph castTensor:argreduceOutTensor - toType:MPSDataTypeInt64 - name:@"cast_out"]; + indicesTensor = [mpsGraph castTensor:argreduceOutTensor toType:MPSDataTypeInt64 name:@"cast_out"]; } if ([outputTensor dataType] != getMPSDataType(output_t)) { @@ -1403,11 +1288,11 @@ Tensor min_mps(const Tensor& input_t) { auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape); auto indicesPlaceholder = Placeholder(cachedGraph->indicesTensor_, indices_t, apparent_out_shape); - NSDictionary *feeds = @{ + NSDictionary* feeds = @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), }; - NSDictionary *results = @{ + NSDictionary* results = @{ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(), indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData() }; @@ -1418,39 +1303,28 @@ Tensor min_mps(const Tensor& input_t) { // Max out with dim TORCH_IMPL_FUNC(max_out_mps) - (const Tensor& input_t, - int64_t dim, - bool keepdim, - const Tensor& output_t, - const Tensor& indices_t) { - - int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); - native::zero_numel_check_dims(input_t, dim_, "max()"); +(const Tensor& input_t, int64_t dim, bool keepdim, const Tensor& output_t, const Tensor& indices_t) { + int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); + native::zero_numel_check_dims(input_t, dim_, "max()"); - min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, MPSReductionType::MAX, "max_out_mps"); + min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, MPSReductionType::MAX, "max_out_mps"); } // Min out with dim TORCH_IMPL_FUNC(min_out_mps) - (const Tensor& input_t, - int64_t dim, - bool keepdim, - const Tensor& output_t, - const Tensor& indices_t) { - - int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); - native::zero_numel_check_dims(input_t, dim_, "min()"); +(const Tensor& input_t, int64_t dim, bool keepdim, const Tensor& output_t, const Tensor& indices_t) { + int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); + native::zero_numel_check_dims(input_t, dim_, "min()"); - min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, MPSReductionType::MIN, "min_out_mps"); + min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, MPSReductionType::MIN, "min_out_mps"); } -void argmax_argmin_out_mps - (const Tensor& input_t, - c10::optional dim, - bool keepdim, - const Tensor& output_t, - MPSReductionType reduction_type, - const std::string& func_name) { +void argmax_argmin_out_mps(const Tensor& input_t, + c10::optional dim, + bool keepdim, + const Tensor& output_t, + MPSReductionType reduction_type, + const std::string& func_name) { using CachedGraph = MPSUnaryCachedGraph; auto cache_ = MPSGraphCache::getInstance(); @@ -1460,22 +1334,22 @@ Tensor min_mps(const Tensor& input_t) { int64_t dim_ = -1; if (dim.has_value()) { - dim_ = maybe_wrap_dim(dim.value(), input_t.dim()); - zero_numel_check_dims(input_t, dim_, reduction_type == MPSReductionType::MAX ? "argmax()" : "argmin()"); + dim_ = maybe_wrap_dim(dim.value(), input_t.dim()); + zero_numel_check_dims(input_t, dim_, reduction_type == MPSReductionType::MAX ? "argmax()" : "argmin()"); } else { - TORCH_CHECK_INDEX( - input_t.numel() != 0, - reduction_type == MPSReductionType::MAX ? "argmax()" : "argmin()" , ": Expected reduction dim to be specified for input.numel() == 0."); - // Since input will be flattened, take argmax or argmin along 0'th dimension - dim_ = 0; + TORCH_CHECK_INDEX(input_t.numel() != 0, + reduction_type == MPSReductionType::MAX ? "argmax()" : "argmin()", + ": Expected reduction dim to be specified for input.numel() == 0."); + // Since input will be flattened, take argmax or argmin along 0'th dimension + dim_ = 0; } // Calculate the output shape according to keepdim=True // If there is no dim argument, the input shape is flattened IntArrayRef input_shape = input_t.sizes(); int64_t num_input_dims = input_shape.size(); - NSMutableArray *apparent_in_shape = nil; - NSMutableArray *apparent_out_shape = nil; + NSMutableArray* apparent_in_shape = nil; + NSMutableArray* apparent_out_shape = nil; if (dim.has_value()) { apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; @@ -1492,7 +1366,7 @@ Tensor min_mps(const Tensor& input_t) { } if (output_t.numel() == 0) { - return; + return; } if (!apparent_in_shape) { @@ -1502,36 +1376,31 @@ Tensor min_mps(const Tensor& input_t) { auto stream = at::mps::getCurrentMPSStream(); @autoreleasepool { NSString* ns_key = [[apparent_in_shape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = func_name + ":" + - to_string(dim_) + ":" + - getTensorsStringKey(input_t) + ":" + - string([ns_key UTF8String]); + string key = + func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + string([ns_key UTF8String]); CachedGraph* cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); auto inputScalarType = input_t.scalar_type(); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(inputScalarType), apparent_in_shape); + MPSGraphTensor* inputTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(inputScalarType), apparent_in_shape); MPSGraphTensor* argreduceOutTensor = nil; MPSGraphTensor* castInputTensor = inputTensor; if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat && - (inputScalarType != kLong || !macOS13_3_plus)) { + (inputScalarType != kLong || !macOS13_3_plus)) { castInputTensor = castMPSTensor(mpsGraph, inputTensor, kFloat); } if (reduction_type == MPSReductionType::MAX) { - argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor: castInputTensor - axis: (NSInteger)dim_ - name: nil]; + argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil]; } else { - argreduceOutTensor = [mpsGraph reductionArgMinimumWithTensor: castInputTensor - axis: (NSInteger)dim_ - name: nil]; + argreduceOutTensor = [mpsGraph reductionArgMinimumWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil]; } MPSGraphTensor* outputTensor = argreduceOutTensor; @@ -1539,10 +1408,11 @@ Tensor min_mps(const Tensor& input_t) { outputTensor = castMPSTensor(mpsGraph, argreduceOutTensor, output_t.scalar_type()); } - MPSGraphTensor* outputClampedTensor = [mpsGraph clampWithTensor: outputTensor - minValueTensor: [mpsGraph constantWithScalar:0 dataType:MPSDataTypeInt64] - maxValueTensor: [mpsGraph constantWithScalar:LLONG_MAX dataType:MPSDataTypeInt64] - name: nil]; + MPSGraphTensor* outputClampedTensor = + [mpsGraph clampWithTensor:outputTensor + minValueTensor:[mpsGraph constantWithScalar:0 dataType:MPSDataTypeInt64] + maxValueTensor:[mpsGraph constantWithScalar:LLONG_MAX dataType:MPSDataTypeInt64] + name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputClampedTensor; @@ -1554,43 +1424,33 @@ Tensor min_mps(const Tensor& input_t) { auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t, apparent_in_shape); auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape); - NSDictionary *feeds = @{ + NSDictionary* feeds = @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), }; - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } } TORCH_IMPL_FUNC(argmax_out_mps) - (const Tensor& input_t, - c10::optional dim, - bool keepdim, - const Tensor& output_t) { - - argmax_argmin_out_mps(input_t, dim, keepdim, output_t, MPSReductionType::MAX, "argmax_out_mps"); +(const Tensor& input_t, c10::optional dim, bool keepdim, const Tensor& output_t) { + argmax_argmin_out_mps(input_t, dim, keepdim, output_t, MPSReductionType::MAX, "argmax_out_mps"); } TORCH_IMPL_FUNC(argmin_out_mps) - (const Tensor& input_t, - c10::optional dim, - bool keepdim, - const Tensor& output_t) { - - argmax_argmin_out_mps(input_t, dim, keepdim, output_t, MPSReductionType::MIN, "argmin_out_mps"); +(const Tensor& input_t, c10::optional dim, bool keepdim, const Tensor& output_t) { + argmax_argmin_out_mps(input_t, dim, keepdim, output_t, MPSReductionType::MIN, "argmin_out_mps"); } // Min/Max with dim -std::tuple min_max_mps( - const Tensor& input_t, - int64_t dim, - bool keepdim, - MPSReductionType reduction_type, - const std::string& func_name) { +std::tuple min_max_mps(const Tensor& input_t, + int64_t dim, + bool keepdim, + MPSReductionType reduction_type, + const std::string& func_name) { int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); native::zero_numel_check_dims(input_t, dim_, "max()"); @@ -1598,7 +1458,7 @@ Tensor min_mps(const Tensor& input_t) { // If there is no dim argument, the input shape is flattened IntArrayRef input_shape = input_t.sizes(); int64_t num_input_dims = input_shape.size(); - NSMutableArray *apparent_out_shape = nil; + NSMutableArray* apparent_out_shape = nil; // Use this if keepdim is false int64_t num_output_dims = num_input_dims - 1; @@ -1608,7 +1468,7 @@ Tensor min_mps(const Tensor& input_t) { apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; // Counter for shape when keepdim is false int out_i = 0; - for (const auto i: c10::irange(num_input_dims)) { + for (const auto i : c10::irange(num_input_dims)) { if (dim_ == i) { apparent_out_shape[i] = @1; vec_apparent_out_shape[i] = 1; @@ -1624,38 +1484,18 @@ Tensor min_mps(const Tensor& input_t) { Tensor indices_t; if (!keepdim) { output_t = at::native::empty_mps( - IntArrayRef(vec_out_shape), - input_t.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + IntArrayRef(vec_out_shape), input_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); indices_t = at::native::empty_mps( - IntArrayRef(vec_out_shape), - ScalarType::Long, - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + IntArrayRef(vec_out_shape), ScalarType::Long, c10::nullopt, kMPS, c10::nullopt, c10::nullopt); } else { output_t = at::native::empty_mps( - IntArrayRef(vec_apparent_out_shape), - input_t.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + IntArrayRef(vec_apparent_out_shape), input_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); indices_t = at::native::empty_mps( - IntArrayRef(vec_apparent_out_shape), - ScalarType::Long, - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + IntArrayRef(vec_apparent_out_shape), ScalarType::Long, c10::nullopt, kMPS, c10::nullopt, c10::nullopt); } if (output_t.numel() == 0 || input_t.numel() == 0) { - return std::tuple{output_t, indices_t}; + return std::tuple{output_t, indices_t}; } min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, reduction_type, func_name); @@ -1675,9 +1515,9 @@ Tensor min_mps(const Tensor& input_t) { // Median of entire tensor into scalar result Tensor median_mps(const Tensor& input_t) { - if (!is_macos_13_or_newer()){ + if (!is_macos_13_or_newer()) { TORCH_WARN_ONCE("MPS: median op is supported natively starting from macOS 13.0. ", - "Falling back on CPU. This may have performace implications."); + "Falling back on CPU. This may have performace implications."); return at::median(input_t.to("cpu")); } @@ -1691,7 +1531,7 @@ Tensor median_mps(const Tensor& input_t) { IntArrayRef input_shape = input_t.sizes(); // calculate total no. of elements in the input tensor to reduce it to one dimension - NSMutableArray *apparent_input_shape = [NSMutableArray arrayWithCapacity:1]; + NSMutableArray* apparent_input_shape = [NSMutableArray arrayWithCapacity:1]; int64_t num_in_elements = c10::multiply_integers(input_shape); apparent_input_shape[0] = [NSNumber numberWithInt:num_in_elements]; @@ -1703,29 +1543,26 @@ Tensor median_mps(const Tensor& input_t) { } @autoreleasepool { - string key = "median_mps:"+ mps::getMPSTypeString(input_t) + mps::getTensorsStringKey(input_t); + string key = "median_mps:" + mps::getMPSTypeString(input_t) + mps::getTensorsStringKey(input_t); CachedGraph* cachedGraph = cache_->LookUpAs(key); // Initialize once if configuration not found in cache if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); - MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); - - auto reshapedTensor = [mpsGraph reshapeTensor: castInputTensor - withShape: @[@-1] - name: nil]; - auto sortedTensor = [mpsGraph sortWithTensor: reshapedTensor - axis: ((NSUInteger) (int)0) - name: nil]; - auto outputTensor = [mpsGraph sliceTensor: sortedTensor - dimension: 0 - start: ((NSUInteger) (int)((num_in_elements+1)/2 ) - 1) - length: 1 - name: nil]; + MPSGraphTensor* castInputTensor = + castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + + auto reshapedTensor = [mpsGraph reshapeTensor:castInputTensor withShape:@[ @-1 ] name:nil]; + auto sortedTensor = [mpsGraph sortWithTensor:reshapedTensor axis:((NSUInteger)(int)0)name:nil]; + auto outputTensor = [mpsGraph sliceTensor:sortedTensor + dimension:0 + start:((NSUInteger)(int)((num_in_elements + 1) / 2) - 1) + length:1 + name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; @@ -1735,15 +1572,14 @@ Tensor median_mps(const Tensor& input_t) { } auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t); - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, @[@1]); + auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, @[ @1 ]); - NSDictionary *feeds = @{ + NSDictionary* feeds = @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), }; - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); } @@ -1751,14 +1587,12 @@ Tensor median_mps(const Tensor& input_t) { return output_t; } - -void median_out_mps( - const Tensor& input_t, - int64_t dim, - bool keepdim, - const Tensor& output_t, - const Tensor& indices_t, - const std::string& func_name) { +void median_out_mps(const Tensor& input_t, + int64_t dim, + bool keepdim, + const Tensor& output_t, + const Tensor& indices_t, + const std::string& func_name) { if (output_t.numel() == 0) { return; } @@ -1770,12 +1604,11 @@ void median_out_mps( } // Derive from MPSCachedGraph - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - MPSGraphTensor *indicesTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + MPSGraphTensor* indicesTensor_ = nil; }; auto cache_ = MPSGraphCache::getInstance(); @@ -1788,7 +1621,7 @@ void median_out_mps( // If there is no dim argument, the input shape is flattened IntArrayRef input_shape = input_t.sizes(); int64_t num_input_dims = input_shape.size(); - NSMutableArray *apparent_out_shape = nil; + NSMutableArray* apparent_out_shape = nil; apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; for (const int i : c10::irange(num_input_dims)) { @@ -1799,39 +1632,36 @@ void median_out_mps( auto stream = at::mps::getCurrentMPSStream(); @autoreleasepool { - string key = func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + getTensorsStringKey(indices_t); + string key = + func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + getTensorsStringKey(indices_t); CachedGraph* cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { auto mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); - MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* castInputTensor = + castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); - MPSGraphTensor * sortedTensor = [mpsGraph sortWithTensor:castInputTensor - axis:((NSUInteger) (int)dim_) - name:nil]; + MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor axis:((NSUInteger)(int)dim_)name:nil]; MPSGraphTensor* outputTensor = [mpsGraph sliceTensor:sortedTensor dimension:dim_ - start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1) + start:((NSUInteger)(int)((dim_total_elements + 1) / 2) - 1) length:1 name:nil]; MPSGraphTensor* argreduceOutTensor = nil; - argreduceOutTensor = [mpsGraph argSortWithTensor:castInputTensor - axis:(NSInteger)dim_ - name:@"argmax_out"]; + argreduceOutTensor = [mpsGraph argSortWithTensor:castInputTensor axis:(NSInteger)dim_ name:@"argmax_out"]; MPSGraphTensor* argOutputTensor = [mpsGraph sliceTensor:argreduceOutTensor - dimension:dim_ - start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1) - length:1 - name:nil]; + dimension:dim_ + start:((NSUInteger)(int)((dim_total_elements + 1) / 2) - 1) + length:1 + name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; @@ -1845,11 +1675,11 @@ void median_out_mps( auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape); auto indicesPlaceholder = Placeholder(cachedGraph->indicesTensor_, indices_t, apparent_out_shape); - NSDictionary *feeds = @{ + NSDictionary* feeds = @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), }; - NSDictionary *results = @{ + NSDictionary* results = @{ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(), indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData() }; @@ -1859,14 +1689,13 @@ void median_out_mps( } // in case mps sortWithTensor do not supported on macOS -std::tuple median_from_cpu( - const Tensor& self, - int64_t dim, - bool keepdim, - Tensor& valuesI, - Tensor& indicesI, - IntArrayRef vec_out_shape, - IntArrayRef vec_apparent_out_shape) { +std::tuple median_from_cpu(const Tensor& self, + int64_t dim, + bool keepdim, + Tensor& valuesI, + Tensor& indicesI, + IntArrayRef vec_out_shape, + IntArrayRef vec_apparent_out_shape) { Tensor values; Tensor indices; if (!keepdim) { @@ -1883,12 +1712,11 @@ void median_out_mps( return std::forward_as_tuple(valuesI, indicesI); } -TORCH_API ::std::tuple median_out_mps( - const at::Tensor & input_t, - int64_t dim, - bool keepdim, - at::Tensor & values, - at::Tensor & indices){ +TORCH_API ::std::tuple median_out_mps(const at::Tensor& input_t, + int64_t dim, + bool keepdim, + at::Tensor& values, + at::Tensor& indices) { bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median_out"); @@ -1899,7 +1727,7 @@ void median_out_mps( // If there is no dim argument, the input shape is flattened IntArrayRef input_shape = input_t.sizes(); int64_t num_input_dims = input_shape.size(); - NSMutableArray *apparent_out_shape = nil; + NSMutableArray* apparent_out_shape = nil; // Use this if keepdim is false int64_t num_output_dims = num_input_dims - 1 < 0 ? 0 : num_input_dims - 1; @@ -1909,7 +1737,7 @@ void median_out_mps( apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; // Counter for shape when keepdim is false int out_i = 0; - for (const auto i: c10::irange(num_input_dims)) { + for (const auto i : c10::irange(num_input_dims)) { if (dim_ == i) { apparent_out_shape[i] = @1; vec_apparent_out_shape[i] = 1; @@ -1923,44 +1751,30 @@ void median_out_mps( if (!keepdim) { values = at::native::empty_mps( - IntArrayRef(vec_out_shape), - input_t.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + IntArrayRef(vec_out_shape), input_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); indices = at::native::empty_mps( - IntArrayRef(vec_out_shape), - ScalarType::Long, - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + IntArrayRef(vec_out_shape), ScalarType::Long, c10::nullopt, kMPS, c10::nullopt, c10::nullopt); } else { values = at::native::empty_mps( - IntArrayRef(vec_apparent_out_shape), - input_t.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + IntArrayRef(vec_apparent_out_shape), input_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); indices = at::native::empty_mps( - IntArrayRef(vec_apparent_out_shape), - ScalarType::Long, - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + IntArrayRef(vec_apparent_out_shape), ScalarType::Long, c10::nullopt, kMPS, c10::nullopt, c10::nullopt); } if (values.numel() == 0 || input_t.numel() == 0) { - return std::tuple{values, indices}; + return std::tuple{values, indices}; } if (!is_macos_13_or_newer()) { TORCH_WARN_ONCE("MPS: median op is supported natively starting from macOS 13.0.", - "Falling back on CPU. This may have performace implications."); - return median_from_cpu(input_t.to("cpu"), dim, keepdim, values, indices, IntArrayRef(vec_out_shape),IntArrayRef(vec_apparent_out_shape)); + "Falling back on CPU. This may have performace implications."); + return median_from_cpu(input_t.to("cpu"), + dim, + keepdim, + values, + indices, + IntArrayRef(vec_out_shape), + IntArrayRef(vec_apparent_out_shape)); } median_out_mps(input_t, dim, keepdim, values, indices, "median_out_mps"); diff --git a/aten/src/ATen/native/mps/operations/Repeat.mm b/aten/src/ATen/native/mps/operations/Repeat.mm index e715ef61245b9b..5953c58fda7e6c 100644 --- a/aten/src/ATen/native/mps/operations/Repeat.mm +++ b/aten/src/ATen/native/mps/operations/Repeat.mm @@ -8,8 +8,8 @@ #include #include #include -#include #include +#include #ifdef __OBJC__ #include @@ -19,8 +19,7 @@ Tensor permute_mps(const Tensor& self, IntArrayRef dims) { auto nDims = self.dim(); - TORCH_CHECK(dims.size() == (size_t)nDims, - "number of dims don't match in permute"); + TORCH_CHECK(dims.size() == (size_t)nDims, "number of dims don't match in permute"); auto oldSizes = self.sizes(); auto oldStrides = self.strides(); DimVector newSizes(nDims); @@ -28,8 +27,7 @@ Tensor permute_mps(const Tensor& self, IntArrayRef dims) { std::vector seen(nDims); for (const auto i : c10::irange(nDims)) { auto dim = maybe_wrap_dim(dims[i], nDims); - TORCH_CHECK(!seen[dim], - "repeated dim in permute"); + TORCH_CHECK(!seen[dim], "repeated dim in permute"); seen[dim] = true; newSizes[i] = oldSizes[dim]; newStrides[i] = oldStrides[dim]; @@ -38,16 +36,14 @@ Tensor permute_mps(const Tensor& self, IntArrayRef dims) { } Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) { - using namespace mps; TORCH_CHECK(repeats.size() >= (size_t)self.dim(), - "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor"); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor"); + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; // Add new leading dimensions to the tensor if the @@ -58,7 +54,7 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) { padded_size.insert(padded_size.end(), self.sizes().begin(), self.sizes().end()); DimVector target_size(repeats.size()); bool zero_tensor = false; - for(const auto idx : c10::irange(repeats.size())) { + for (const auto idx : c10::irange(repeats.size())) { if (repeats[idx] == 0) { zero_tensor = true; } @@ -68,7 +64,7 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) { Tensor expanded_tensor = self.expand(padded_size); Tensor result = at::empty(target_size, self.options()); MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - if(zero_tensor || result.numel() == 0) { + if (zero_tensor || result.numel() == 0) { return result; } @@ -76,50 +72,47 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) { auto inputDataType = getMPSDataType(expanded_tensor); auto outputDataType = getMPSDataType(result); if (!is_macos_13_or_newer()) { - if (expanded_tensor.scalar_type() == kBool) { + if (expanded_tensor.scalar_type() == kBool) { inputDataType = MPSDataTypeInt8; - } - if (result.scalar_type() == kBool) { + } + if (result.scalar_type() == kBool) { outputDataType = MPSDataTypeInt8; - } + } } @autoreleasepool { string key = "repeat_mps:" + getTensorsStringKey(self) + ":" + getArrayRefString(repeats); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(expanded_tensor)); - MPSGraphTensor* outputTensor = [mpsGraph tileTensor:inputTensor - withMultiplier:getMPSShape(repeats) - name:nil]; + MPSGraphTensor* inputTensor = + mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(expanded_tensor)); + MPSGraphTensor* outputTensor = [mpsGraph tileTensor:inputTensor withMultiplier:getMPSShape(repeats) name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder( - cachedGraph->inputTensor_, expanded_tensor, /*mpsShape=*/nil, /*gatherTensorData=*/true, inputDataType); - Placeholder outputPlaceholder = Placeholder( - cachedGraph->outputTensor_, result, /*mpsShape=*/nil, /*gatherTensorData*/false, outputDataType); + cachedGraph->inputTensor_, expanded_tensor, /*mpsShape=*/nil, /*gatherTensorData=*/true, inputDataType); + Placeholder outputPlaceholder = + Placeholder(cachedGraph->outputTensor_, result, /*mpsShape=*/nil, /*gatherTensorData*/ false, outputDataType); - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } @@ -142,51 +135,50 @@ kernel void repeat_interleave(constant {0} * repeat_ptr [[buf }} )METAL_REPEAT"; -static -id compileRepeatInterleaveLib(id device, const std::string& t1) { +static id compileRepeatInterleaveLib(id device, const std::string& t1) { auto key = t1; static std::unordered_map> libMap; auto it = libMap.find(key); if (it != libMap.end()) { return it->second; } - NSError *error = nil; - MTLCompileOptions *options = [[MTLCompileOptions new] autorelease]; - [options setLanguageVersion: MTLLanguageVersion2_3]; - auto rc = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(METAL_REPEAT_INTERLEAVE, t1).c_str()] - options:options - error:&error]; - TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]); - libMap[key] = rc; - return rc; + NSError* error = nil; + MTLCompileOptions* options = [[MTLCompileOptions new] autorelease]; + [options setLanguageVersion:MTLLanguageVersion2_3]; + auto rc = + [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(METAL_REPEAT_INTERLEAVE, t1).c_str()] + options:options + error:&error]; + TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]); + libMap[key] = rc; + return rc; } -static -id getPipelineState(id device, const std::string& t1) { +static id getPipelineState(id device, const std::string& t1) { static std::string kernel = "repeat_interleave"; auto key = kernel + t1; static std::unordered_map> cplMap; auto it = cplMap.find(key); if (it != cplMap.end()) { - return it->second; + return it->second; } - NSError *error = nil; + NSError* error = nil; auto library = compileRepeatInterleaveLib(device, t1); id func = [library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]]; TORCH_CHECK(func != nil, "Can't get kernel ", kernel); auto rc = [device newComputePipelineStateWithFunction:func error:&error]; - TORCH_CHECK(rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]); + TORCH_CHECK( + rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]); cplMap[key] = rc; return rc; } template -void computeRepeatIndices( - index_t* repeat_ptr, - int64_t* cumsum_ptr, - index_t* result_ptr, - int64_t size, - int64_t result_size) { +void computeRepeatIndices(index_t* repeat_ptr, + int64_t* cumsum_ptr, + index_t* result_ptr, + int64_t size, + int64_t result_size) { id repeatBuffer = reinterpret_cast>(repeat_ptr); id cumsumBuffer = reinterpret_cast>(cumsum_ptr); id resultBuffer = reinterpret_cast>(result_ptr); @@ -208,7 +200,7 @@ void computeRepeatIndices( id computeEncoder = [commandBuffer computeCommandEncoder]; id pipelineState = getPipelineState(MPSDevice::getInstance()->device(), scalar_type); - [computeEncoder setComputePipelineState: pipelineState]; + [computeEncoder setComputePipelineState:pipelineState]; [computeEncoder setBuffer:repeatBuffer offset:0 atIndex:0]; [computeEncoder setBuffer:cumsumBuffer offset:0 atIndex:1]; [computeEncoder setBuffer:resultBuffer offset:0 atIndex:2]; @@ -216,7 +208,7 @@ void computeRepeatIndices( MTLSize gridSize = MTLSizeMake(size, 1, 1); NSUInteger threadsPerThreadgroup_ = pipelineState.maxTotalThreadsPerThreadgroup; if (threadsPerThreadgroup_ > size) { - threadsPerThreadgroup_ = size; + threadsPerThreadgroup_ = size; } MTLSize threadsPerThreadgroup = MTLSizeMake(threadsPerThreadgroup_, 1, 1); @@ -233,14 +225,14 @@ Tensor repeat_interleave_mps(const Tensor& repeat_, c10::optional outpu if (repeat.scalar_type() == kLong && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) { // #103810551: `repeat_interleave_common` uses cumsum to calculate the final shape of output, // which currently doesn't support int64_t as input. Casting internally the indices to int32_t. - TORCH_WARN_ONCE("MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3"); + TORCH_WARN_ONCE( + "MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3"); repeat = repeat.to(kInt); } AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() { - output = repeat_interleave_common>( - repeat, output_size); + output = repeat_interleave_common>(repeat, output_size); }); return output; } -} // namespace at::native +} // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/RnnOps.mm b/aten/src/ATen/native/mps/operations/RnnOps.mm index fd5ee95d2bfa27..fbc91fd005b334 100644 --- a/aten/src/ATen/native/mps/operations/RnnOps.mm +++ b/aten/src/ATen/native/mps/operations/RnnOps.mm @@ -4,9 +4,9 @@ #include #include #include -#include #include #include +#include #include #include #import @@ -15,12 +15,12 @@ namespace at::native { std::vector getTensorShape(MPSGraphTensor* mpsTensor) { - std::vector output_dimensions = {}; - auto dims = mpsTensor.shape; - for (int i = 0; i<[dims count];i++){ - output_dimensions.push_back([dims[i] intValue]); - } - return output_dimensions; + std::vector output_dimensions = {}; + auto dims = mpsTensor.shape; + for (int i = 0; i < [dims count]; i++) { + output_dimensions.push_back([dims[i] intValue]); + } + return output_dimensions; } /** @@ -29,752 +29,792 @@ * stateTensor, cellStateTensor, recurrentWeight, inputWeight, biasTensor */ static std::tuple - getMPSTensorsFromPytorchTensors(MPSGraph* mpsGraph, - MPSGraphTensor* stateTensor, MPSGraphTensor* cellStateTensor, - NSMutableArray *recurrentKernelWeightsList, - NSMutableArray *kernelWeightsList, - NSMutableArray *kernelBiasList, - NSMutableArray *recurrentBiasList, - bool has_biases, bool bidirectional, size_t layer_no) { - MPSGraphTensor* biasTensor_ = nil; - MPSGraphTensor* stateTensor_ = nil, *cellStateTensor_ = nil; - MPSGraphTensor* recurrentWeight_ = nil, *inputWeight_ = nil; - - if (bidirectional) { - stateTensor_ = [mpsGraph sliceTensor:stateTensor - dimension:0 - start:layer_no * 2 - length:2 - name:nil]; - // [2, N, H] -> [N, 2, H] - stateTensor_ = [mpsGraph transposeTensor:stateTensor_ dimension: 0 withDimension: 1 name:nil]; - // [N, 2, H] -> [N, 2 * H] - stateTensor_ = [mpsGraph flatten2DTensor:stateTensor_ axis:1 name:nil]; - cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor - dimension:0 - start:layer_no * 2 - length:2 - name:nil]; - cellStateTensor_ = [mpsGraph transposeTensor:cellStateTensor_ dimension: 0 withDimension: 1 name:nil]; - cellStateTensor_ = [mpsGraph flatten2DTensor:cellStateTensor_ axis:1 name:nil]; - - recurrentWeight_ = [mpsGraph - concatTensor: [mpsGraph expandDimsOfTensor: recurrentKernelWeightsList[layer_no * 2] axis: 0 name: nil] - withTensor: [mpsGraph expandDimsOfTensor: recurrentKernelWeightsList[layer_no * 2 + 1] axis: 0 name: nil] - dimension: 0 - name: nil - ]; - inputWeight_ = [mpsGraph - concatTensor: kernelWeightsList[layer_no * 2] - withTensor: kernelWeightsList[layer_no * 2 + 1] - dimension: 0 - name: nil - ]; - if (has_biases) { - auto biasTensorFwd_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no * 2] - secondaryTensor:recurrentBiasList[layer_no * 2] - name:nil]; - auto biasTensorBack_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no * 2 + 1] - secondaryTensor:recurrentBiasList[layer_no * 2 + 1] - name:nil]; - - biasTensor_ = [mpsGraph concatTensor:biasTensorFwd_ withTensor:biasTensorBack_ dimension:0 name:nil]; - } - } else { - stateTensor_ = [mpsGraph sliceTensor:stateTensor - dimension:0 - start:layer_no - length:1 - name:nil]; - cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor - dimension:0 - start:layer_no - length:1 - name:nil]; - recurrentWeight_ = recurrentKernelWeightsList[layer_no]; - inputWeight_ = kernelWeightsList[layer_no]; - if (has_biases) { - biasTensor_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no] - secondaryTensor:recurrentBiasList[layer_no] - name:nil]; - } - } - return std::make_tuple(stateTensor_, cellStateTensor_, recurrentWeight_, inputWeight_, biasTensor_); -} - -std::tuple _lstm_mps(const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) { - using namespace mps; +getMPSTensorsFromPytorchTensors(MPSGraph* mpsGraph, + MPSGraphTensor* stateTensor, + MPSGraphTensor* cellStateTensor, + NSMutableArray* recurrentKernelWeightsList, + NSMutableArray* kernelWeightsList, + NSMutableArray* kernelBiasList, + NSMutableArray* recurrentBiasList, + bool has_biases, + bool bidirectional, + size_t layer_no) { + MPSGraphTensor* biasTensor_ = nil; + MPSGraphTensor *stateTensor_ = nil, *cellStateTensor_ = nil; + MPSGraphTensor *recurrentWeight_ = nil, *inputWeight_ = nil; + + if (bidirectional) { + stateTensor_ = [mpsGraph sliceTensor:stateTensor dimension:0 start:layer_no * 2 length:2 name:nil]; + // [2, N, H] -> [N, 2, H] + stateTensor_ = [mpsGraph transposeTensor:stateTensor_ dimension:0 withDimension:1 name:nil]; + // [N, 2, H] -> [N, 2 * H] + stateTensor_ = [mpsGraph flatten2DTensor:stateTensor_ axis:1 name:nil]; + cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor dimension:0 start:layer_no * 2 length:2 name:nil]; + cellStateTensor_ = [mpsGraph transposeTensor:cellStateTensor_ dimension:0 withDimension:1 name:nil]; + cellStateTensor_ = [mpsGraph flatten2DTensor:cellStateTensor_ axis:1 name:nil]; + + recurrentWeight_ = [mpsGraph + concatTensor:[mpsGraph expandDimsOfTensor:recurrentKernelWeightsList[layer_no * 2] axis:0 name:nil] + withTensor:[mpsGraph expandDimsOfTensor:recurrentKernelWeightsList[layer_no * 2 + 1] axis:0 name:nil] + dimension:0 + name:nil]; + inputWeight_ = [mpsGraph concatTensor:kernelWeightsList[layer_no * 2] + withTensor:kernelWeightsList[layer_no * 2 + 1] + dimension:0 + name:nil]; + if (has_biases) { + auto biasTensorFwd_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no * 2] + secondaryTensor:recurrentBiasList[layer_no * 2] + name:nil]; + auto biasTensorBack_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no * 2 + 1] + secondaryTensor:recurrentBiasList[layer_no * 2 + 1] + name:nil]; - //Projections are not currently supported, raise an error if needed - bool has_projections = (hx[0].size(2) != hx[1].size(2)); - if(has_projections) { - AT_ERROR("LSTM with projections is not currently supported with MPS."); + biasTensor_ = [mpsGraph concatTensor:biasTensorFwd_ withTensor:biasTensorBack_ dimension:0 name:nil]; } - - std::vector kernel_weights; - std::vector recurrent_kernel_weights; - std::vector biases; - std::vector recurrent_biases; - - const int64_t total_layers = num_layers * (bidirectional ? 2 : 1); - - for (const auto i : c10::irange(total_layers)) { - const int stride = (has_biases ? 4 : 2); - kernel_weights.push_back(params[i*stride]); - recurrent_kernel_weights.push_back(params[i*stride+1]); - - if (has_biases) { - biases.push_back(params[i*stride+2]); - recurrent_biases.push_back(params[i*stride+3]); - } + } else { + stateTensor_ = [mpsGraph sliceTensor:stateTensor dimension:0 start:layer_no length:1 name:nil]; + cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor dimension:0 start:layer_no length:1 name:nil]; + recurrentWeight_ = recurrentKernelWeightsList[layer_no]; + inputWeight_ = kernelWeightsList[layer_no]; + if (has_biases) { + biasTensor_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no] + secondaryTensor:recurrentBiasList[layer_no] + name:nil]; } + } + return std::make_tuple(stateTensor_, cellStateTensor_, recurrentWeight_, inputWeight_, biasTensor_); +} - struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - std::vector inputTensors_; - std::vector outputTensors_; - NSMutableArray *kernelWeightsList_ = nil; - NSMutableArray *recurrentKernelWeightsList_ = nil; - NSMutableArray *biasList_ = nil; - NSMutableArray *recurrentBiasList_ = nil; - }; - - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - string key = "lstm_" + getTensorsStringKey({input, hx[0], hx[1]}) + getMPSTypeString(input) + "_num_layers_" + std::to_string(num_layers) + "_bidirectional_" + std::to_string(bidirectional) + "_has_biases_" + std::to_string(has_biases) + "_dropout_" + std::to_string(dropout_p) + "_batch_first_" + std::to_string(batch_first); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - NSMutableArray *kernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()]; - NSMutableArray *recurrentKernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()]; - NSMutableArray *kernelBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()]; - NSMutableArray *recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()]; - NSMutableArray *layersOutputsList = [[NSMutableArray alloc] initWithCapacity:num_layers]; - - for (const auto i : c10::irange(total_layers)) { - [kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(kernel_weights[i]))]; - [recurrentKernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(recurrent_kernel_weights[i]))]; - if(has_biases) { - [kernelBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(biases[i]))]; - [recurrentBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(recurrent_biases[i]))]; - } +std::tuple _lstm_mps(const Tensor& input, + TensorList hx, + TensorList params, + bool has_biases, + int64_t num_layers, + double dropout_p, + bool train, + bool bidirectional, + bool batch_first) { + using namespace mps; + + // Projections are not currently supported, raise an error if needed + bool has_projections = (hx[0].size(2) != hx[1].size(2)); + if (has_projections) { + AT_ERROR("LSTM with projections is not currently supported with MPS."); + } + + std::vector kernel_weights; + std::vector recurrent_kernel_weights; + std::vector biases; + std::vector recurrent_biases; + + const int64_t total_layers = num_layers * (bidirectional ? 2 : 1); + + for (const auto i : c10::irange(total_layers)) { + const int stride = (has_biases ? 4 : 2); + kernel_weights.push_back(params[i * stride]); + recurrent_kernel_weights.push_back(params[i * stride + 1]); + + if (has_biases) { + biases.push_back(params[i * stride + 2]); + recurrent_biases.push_back(params[i * stride + 3]); + } + } + + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + std::vector inputTensors_; + std::vector outputTensors_; + NSMutableArray* kernelWeightsList_ = nil; + NSMutableArray* recurrentKernelWeightsList_ = nil; + NSMutableArray* biasList_ = nil; + NSMutableArray* recurrentBiasList_ = nil; + }; + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + MPSStream* stream = getCurrentMPSStream(); + + @autoreleasepool { + string key = "lstm_" + getTensorsStringKey({input, hx[0], hx[1]}) + getMPSTypeString(input) + "_num_layers_" + + std::to_string(num_layers) + "_bidirectional_" + std::to_string(bidirectional) + "_has_biases_" + + std::to_string(has_biases) + "_dropout_" + std::to_string(dropout_p) + "_batch_first_" + + std::to_string(batch_first); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + NSMutableArray* kernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()]; + NSMutableArray* recurrentKernelWeightsList = + [[NSMutableArray alloc] initWithCapacity:params.size()]; + NSMutableArray* kernelBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()]; + NSMutableArray* recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()]; + NSMutableArray* layersOutputsList = [[NSMutableArray alloc] initWithCapacity:num_layers]; + + for (const auto i : c10::irange(total_layers)) { + [kernelWeightsList + addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(kernel_weights[i]))]; + [recurrentKernelWeightsList + addObject:mpsGraphRankedPlaceHolder( + mpsGraph, getMPSDataType(input), getMPSShape(recurrent_kernel_weights[i]))]; + if (has_biases) { + [kernelBiasList + addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(biases[i]))]; + [recurrentBiasList addObject:mpsGraphRankedPlaceHolder( + mpsGraph, getMPSDataType(input), getMPSShape(recurrent_biases[i]))]; } + } - MPSGraphLSTMDescriptor * opDesc = [MPSGraphLSTMDescriptor descriptor]; - opDesc.training = true; - opDesc.bidirectional = bidirectional; - opDesc.produceCell = true; - - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(input)); - MPSGraphTensor* stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(hx[0])); - MPSGraphTensor* cellStateTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(hx[1])); - std::vector inputTensors = {inputTensor, stateTensor, cellStateTensor,}; + MPSGraphLSTMDescriptor* opDesc = [MPSGraphLSTMDescriptor descriptor]; + opDesc.training = true; + opDesc.bidirectional = bidirectional; + opDesc.produceCell = true; + + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(input)); + MPSGraphTensor* stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(hx[0])); + MPSGraphTensor* cellStateTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(hx[1])); + std::vector inputTensors = { + inputTensor, + stateTensor, + cellStateTensor, + }; + + if (batch_first) { + inputTensor = [mpsGraph transposeTensor:inputTensor dimension:0 withDimension:1 name:nil]; + } - if (batch_first) { - inputTensor = [mpsGraph transposeTensor:inputTensor - dimension:0 - withDimension:1 + MPSGraphTensor* inputTensor_ = inputTensor; + NSArray* outputs = nil; + NSMutableArray* outputStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; + NSMutableArray* outputCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; + NSMutableArray* outputZStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; + NSMutableArray* outputCellStateFwdArray = + [[NSMutableArray alloc] initWithCapacity:num_layers]; + for (int i = 0; i < num_layers; i++) { + auto tensorsData = getMPSTensorsFromPytorchTensors(mpsGraph, + stateTensor, + cellStateTensor, + recurrentKernelWeightsList, + kernelWeightsList, + kernelBiasList, + recurrentBiasList, + has_biases, + bidirectional, + i); + MPSGraphTensor *stateTensor_ = std::get<0>(tensorsData), *cellStateTensor_ = std::get<1>(tensorsData); + MPSGraphTensor *recurrentWeight_ = std::get<2>(tensorsData), *inputWeight_ = std::get<3>(tensorsData); + MPSGraphTensor* biasTensor_ = std::get<4>(tensorsData); + + outputs = [mpsGraph LSTMWithSourceTensor:inputTensor_ + recurrentWeight:recurrentWeight_ + inputWeight:inputWeight_ + bias:biasTensor_ + initState:stateTensor_ + initCell:cellStateTensor_ + descriptor:opDesc name:nil]; + + inputTensor_ = [outputs objectAtIndex:0]; + // no need to keep the final layer output copy as it is + // returned anyway and not used in backprop + if (i != num_layers - 1) { + [layersOutputsList addObject:[mpsGraph expandDimsOfTensor:inputTensor_ axis:0 name:nil]]; + } + if (dropout_p > 0.0 && train && (i != num_layers - 1)) { + inputTensor_ = [mpsGraph dropoutTensor:inputTensor_ rate:dropout_p name:nil]; } - MPSGraphTensor* inputTensor_ = inputTensor; - NSArray* outputs = nil; - NSMutableArray* outputStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; - NSMutableArray* outputCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; - NSMutableArray* outputZStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; - NSMutableArray* outputCellStateFwdArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; - for (int i = 0; i < num_layers; i++) { - auto tensorsData = getMPSTensorsFromPytorchTensors(mpsGraph, stateTensor, cellStateTensor, - recurrentKernelWeightsList, kernelWeightsList, - kernelBiasList, recurrentBiasList, has_biases, - bidirectional, i); - MPSGraphTensor* stateTensor_ = std::get<0>(tensorsData), *cellStateTensor_ = std::get<1>(tensorsData); - MPSGraphTensor* recurrentWeight_ = std::get<2>(tensorsData), *inputWeight_ = std::get<3>(tensorsData); - MPSGraphTensor* biasTensor_ = std::get<4>(tensorsData); - - - outputs = [mpsGraph LSTMWithSourceTensor:inputTensor_ - recurrentWeight:recurrentWeight_ - inputWeight:inputWeight_ - bias:biasTensor_ - initState:stateTensor_ - initCell:cellStateTensor_ - descriptor:opDesc + if (bidirectional) { + // [1, N, 2 * H] + auto stateLastT = [mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:-1 length:1 name:nil]; + auto stateFirstT = [mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:0 length:1 name:nil]; + // [1, N, H] ([1, N, 0:H]) + auto stateForward = [mpsGraph sliceTensor:stateLastT + dimension:-1 + start:0 + length:hx[0].sizes()[2] name:nil]; + // [1, N, H] ([1, N, H:2H]) + auto stateBack = [mpsGraph sliceTensor:stateFirstT + dimension:-1 + start:hx[0].sizes()[2] + length:hx[0].sizes()[2] + name:nil]; + [outputStateArray addObject:stateForward]; + [outputStateArray addObject:stateBack]; - inputTensor_ = [outputs objectAtIndex:0]; - // no need to keep the final layer output copy as it is - // returned anyway and not used in backprop - if (i != num_layers - 1) { - [layersOutputsList addObject:[mpsGraph expandDimsOfTensor:inputTensor_ - axis:0 - name:nil]]; - } - if (dropout_p>0.0 && train && (i!=num_layers-1)) { - inputTensor_ = [mpsGraph dropoutTensor:inputTensor_ - rate:dropout_p + auto cellStateLastT = [mpsGraph sliceTensor:[outputs objectAtIndex:1] + dimension:0 + start:-1 + length:1 + name:nil]; + auto cellStateFirstT = [mpsGraph sliceTensor:[outputs objectAtIndex:1] + dimension:0 + start:0 + length:1 name:nil]; - - } - - if (bidirectional) { - // [1, N, 2 * H] - auto stateLastT = [mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:-1 length:1 name:nil]; - auto stateFirstT = [mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:0 length:1 name:nil]; - // [1, N, H] ([1, N, 0:H]) - auto stateForward = [mpsGraph sliceTensor:stateLastT dimension: -1 start:0 length:hx[0].sizes()[2] name:nil]; - // [1, N, H] ([1, N, H:2H]) - auto stateBack = [mpsGraph sliceTensor:stateFirstT dimension: -1 start:hx[0].sizes()[2] length:hx[0].sizes()[2] name:nil]; - [outputStateArray addObject:stateForward]; - [outputStateArray addObject:stateBack]; - - auto cellStateLastT = [mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:-1 length:1 name:nil]; - auto cellStateFirstT = [mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:0 length:1 name:nil]; - auto cellStateForward = [mpsGraph sliceTensor:cellStateLastT dimension: -1 start:0 length:hx[1].sizes()[2] name:nil]; - auto cellStateBack = [mpsGraph sliceTensor:cellStateFirstT dimension: -1 start:hx[1].sizes()[2] length:hx[1].sizes()[2] name:nil]; - [outputCellStateArray addObject:cellStateForward]; - [outputCellStateArray addObject:cellStateBack]; - } else { - [outputStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:-1 length:1 name:nil]]; - [outputCellStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:-1 length:1 name:nil]]; - } - [outputCellStateFwdArray addObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:1] - axis:0 - name:nil]]; - [outputZStateArray addObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:2] - axis:0 - name:nil]]; - } - - MPSGraphTensor* outputTensor = inputTensor_; - if (batch_first) { - outputTensor = [mpsGraph transposeTensor:outputTensor - dimension:0 - withDimension:1 + auto cellStateForward = [mpsGraph sliceTensor:cellStateLastT + dimension:-1 + start:0 + length:hx[1].sizes()[2] + name:nil]; + auto cellStateBack = [mpsGraph sliceTensor:cellStateFirstT + dimension:-1 + start:hx[1].sizes()[2] + length:hx[1].sizes()[2] name:nil]; + [outputCellStateArray addObject:cellStateForward]; + [outputCellStateArray addObject:cellStateBack]; + } else { + [outputStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:0] + dimension:0 + start:-1 + length:1 + name:nil]]; + [outputCellStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:1] + dimension:0 + start:-1 + length:1 + name:nil]]; } - MPSGraphTensor* outputStates = [mpsGraph concatTensors:outputStateArray - dimension:0 - name:nil]; - MPSGraphTensor* outputCellStates = [mpsGraph concatTensors:outputCellStateArray - dimension:0 - name:nil]; - MPSGraphTensor* outputZStates = [mpsGraph concatTensors:outputZStateArray - dimension:0 - name:nil]; - MPSGraphTensor* outputCellStatesFwd = [mpsGraph concatTensors:outputCellStateFwdArray - dimension:0 - name:nil]; - MPSGraphTensor* layersOutputs = (num_layers > 1) - ? [mpsGraph concatTensors:layersOutputsList dimension:0 name:nil] - : nil; - - std::vector outputTensors = {outputTensor, outputStates, outputCellStates, outputZStates, outputCellStatesFwd, layersOutputs}; - newCachedGraph->inputTensors_ = inputTensors; - newCachedGraph->outputTensors_ = outputTensors; - newCachedGraph->kernelWeightsList_ = kernelWeightsList; - newCachedGraph->recurrentKernelWeightsList_ = recurrentKernelWeightsList; - newCachedGraph->biasList_ = kernelBiasList; - newCachedGraph->recurrentBiasList_ = recurrentBiasList; + [outputCellStateFwdArray addObject:[mpsGraph expandDimsOfTensor:[outputs objectAtIndex:1] axis:0 name:nil]]; + [outputZStateArray addObject:[mpsGraph expandDimsOfTensor:[outputs objectAtIndex:2] axis:0 name:nil]]; } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - NSMutableArray *kernelWeightsList = cachedGraph->kernelWeightsList_; - NSMutableArray *recurrentKernelWeightsList = cachedGraph->recurrentKernelWeightsList_; - NSMutableArray *biasList = cachedGraph->biasList_; - NSMutableArray *recurrentBiasList = cachedGraph->recurrentBiasList_; - - NSMutableDictionary *feeds = [[[NSMutableDictionary alloc] init] autorelease]; - for (const auto i : c10::irange(total_layers)) { - Placeholder kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]); - Placeholder recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]); - [feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()]; - [feeds setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()]; - if (has_biases) { - Placeholder bias = Placeholder([biasList objectAtIndex:i], biases[i]); - Placeholder recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]); - [feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()]; - [feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()]; + MPSGraphTensor* outputTensor = inputTensor_; + if (batch_first) { + outputTensor = [mpsGraph transposeTensor:outputTensor dimension:0 withDimension:1 name:nil]; } - } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensors_[0], input); - Placeholder selfState = Placeholder(cachedGraph->inputTensors_[1], hx[0]); - Placeholder selfCellState = Placeholder(cachedGraph->inputTensors_[2], hx[1]); - [feeds setObject:selfPlaceholder.getMPSGraphTensorData() forKey:selfPlaceholder.getMPSGraphTensor()]; - [feeds setObject:selfState.getMPSGraphTensorData() forKey:selfState.getMPSGraphTensor()]; - [feeds setObject:selfCellState.getMPSGraphTensorData() forKey:selfCellState.getMPSGraphTensor()]; - - - auto dims = getTensorShape(cachedGraph->outputTensors_[0]); - Tensor output = at::empty(IntArrayRef(dims), input.options()); - Tensor hy = at::empty_like(hx[0], input.options()); - Tensor cy = at::empty_like(hx[1], input.options()); - Tensor zState = at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[3])), input.options()); - Tensor cellStateFwd = at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[4])), input.options()); - Tensor layerOutputs = (num_layers > 1) - ? at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[5])), input.options()) - : at::empty({ 1 }, input.options()); // not used if num_layers == 1 - - Placeholder outputPlaceholder0 = Placeholder(cachedGraph->outputTensors_[0], output); - Placeholder outputPlaceholder1 = Placeholder(cachedGraph->outputTensors_[1], hy); - Placeholder outputPlaceholder2 = Placeholder(cachedGraph->outputTensors_[2], cy); - Placeholder outputPlaceholder3 = Placeholder(cachedGraph->outputTensors_[3], zState); - Placeholder outputPlaceholder4 = Placeholder(cachedGraph->outputTensors_[4], cellStateFwd); - - NSMutableDictionary* results = [@{ - outputPlaceholder0.getMPSGraphTensor() : outputPlaceholder0.getMPSGraphTensorData(), - outputPlaceholder1.getMPSGraphTensor() : outputPlaceholder1.getMPSGraphTensorData(), - outputPlaceholder2.getMPSGraphTensor() : outputPlaceholder2.getMPSGraphTensorData(), - outputPlaceholder3.getMPSGraphTensor() : outputPlaceholder3.getMPSGraphTensorData(), - outputPlaceholder4.getMPSGraphTensor() : outputPlaceholder4.getMPSGraphTensorData(), - } mutableCopy]; - - if (num_layers > 1) { - Placeholder outputPlaceholder5 = Placeholder(cachedGraph->outputTensors_[5], layerOutputs); - [results setObject:outputPlaceholder5.getMPSGraphTensorData() forKey: outputPlaceholder5.getMPSGraphTensor()]; - } + MPSGraphTensor* outputStates = [mpsGraph concatTensors:outputStateArray dimension:0 name:nil]; + MPSGraphTensor* outputCellStates = [mpsGraph concatTensors:outputCellStateArray dimension:0 name:nil]; + MPSGraphTensor* outputZStates = [mpsGraph concatTensors:outputZStateArray dimension:0 name:nil]; + MPSGraphTensor* outputCellStatesFwd = [mpsGraph concatTensors:outputCellStateFwdArray dimension:0 name:nil]; + MPSGraphTensor* layersOutputs = + (num_layers > 1) ? [mpsGraph concatTensors:layersOutputsList dimension:0 name:nil] : nil; + + std::vector outputTensors = { + outputTensor, outputStates, outputCellStates, outputZStates, outputCellStatesFwd, layersOutputs}; + newCachedGraph->inputTensors_ = inputTensors; + newCachedGraph->outputTensors_ = outputTensors; + newCachedGraph->kernelWeightsList_ = kernelWeightsList; + newCachedGraph->recurrentKernelWeightsList_ = recurrentKernelWeightsList; + newCachedGraph->biasList_ = kernelBiasList; + newCachedGraph->recurrentBiasList_ = recurrentBiasList; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + + NSMutableArray* kernelWeightsList = cachedGraph->kernelWeightsList_; + NSMutableArray* recurrentKernelWeightsList = cachedGraph->recurrentKernelWeightsList_; + NSMutableArray* biasList = cachedGraph->biasList_; + NSMutableArray* recurrentBiasList = cachedGraph->recurrentBiasList_; - runMPSGraph(stream, cachedGraph->graph(), feeds, results); - return std::make_tuple(output, hy, cy, zState, cellStateFwd, layerOutputs); + NSMutableDictionary* feeds = [[[NSMutableDictionary alloc] init] autorelease]; + for (const auto i : c10::irange(total_layers)) { + Placeholder kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]); + Placeholder recurrentKernelWeight = + Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]); + [feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()]; + [feeds setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()]; + if (has_biases) { + Placeholder bias = Placeholder([biasList objectAtIndex:i], biases[i]); + Placeholder recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]); + [feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()]; + [feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()]; + } + } + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensors_[0], input); + Placeholder selfState = Placeholder(cachedGraph->inputTensors_[1], hx[0]); + Placeholder selfCellState = Placeholder(cachedGraph->inputTensors_[2], hx[1]); + [feeds setObject:selfPlaceholder.getMPSGraphTensorData() forKey:selfPlaceholder.getMPSGraphTensor()]; + [feeds setObject:selfState.getMPSGraphTensorData() forKey:selfState.getMPSGraphTensor()]; + [feeds setObject:selfCellState.getMPSGraphTensorData() forKey:selfCellState.getMPSGraphTensor()]; + + auto dims = getTensorShape(cachedGraph->outputTensors_[0]); + Tensor output = at::empty(IntArrayRef(dims), input.options()); + Tensor hy = at::empty_like(hx[0], input.options()); + Tensor cy = at::empty_like(hx[1], input.options()); + Tensor zState = at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[3])), input.options()); + Tensor cellStateFwd = at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[4])), input.options()); + Tensor layerOutputs = (num_layers > 1) + ? at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[5])), input.options()) + : at::empty({1}, input.options()); // not used if num_layers == 1 + + Placeholder outputPlaceholder0 = Placeholder(cachedGraph->outputTensors_[0], output); + Placeholder outputPlaceholder1 = Placeholder(cachedGraph->outputTensors_[1], hy); + Placeholder outputPlaceholder2 = Placeholder(cachedGraph->outputTensors_[2], cy); + Placeholder outputPlaceholder3 = Placeholder(cachedGraph->outputTensors_[3], zState); + Placeholder outputPlaceholder4 = Placeholder(cachedGraph->outputTensors_[4], cellStateFwd); + + NSMutableDictionary* results = [@{ + outputPlaceholder0.getMPSGraphTensor() : outputPlaceholder0.getMPSGraphTensorData(), + outputPlaceholder1.getMPSGraphTensor() : outputPlaceholder1.getMPSGraphTensorData(), + outputPlaceholder2.getMPSGraphTensor() : outputPlaceholder2.getMPSGraphTensorData(), + outputPlaceholder3.getMPSGraphTensor() : outputPlaceholder3.getMPSGraphTensorData(), + outputPlaceholder4.getMPSGraphTensor() : outputPlaceholder4.getMPSGraphTensorData(), + } mutableCopy]; + + if (num_layers > 1) { + Placeholder outputPlaceholder5 = Placeholder(cachedGraph->outputTensors_[5], layerOutputs); + [results setObject:outputPlaceholder5.getMPSGraphTensorData() forKey:outputPlaceholder5.getMPSGraphTensor()]; } + + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + return std::make_tuple(output, hy, cy, zState, cellStateFwd, layerOutputs); + } } -std::tuple, std::vector> lstm_mps_backward(const Tensor& grad_y, const c10::optional& grad_hy_opt, const c10::optional& grad_cy_opt, const Tensor& z_state, const Tensor& cell_state_fwd, const Tensor& input, const Tensor& layersOutputs, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) { - using namespace mps; - const Tensor& grad_hy_r = c10::value_or_else(grad_hy_opt, [] {return Tensor();}); - const Tensor& grad_cy_r = c10::value_or_else(grad_cy_opt, [] {return Tensor();}); - auto grad_hy = grad_hy_r.defined() ? grad_hy_r : at::zeros_like(hx[0], input.options()); - auto grad_cy = grad_cy_r.defined() ? grad_cy_r : at::zeros_like(hx[1], input.options()); +std::tuple, std::vector> lstm_mps_backward(const Tensor& grad_y, + const c10::optional& grad_hy_opt, + const c10::optional& grad_cy_opt, + const Tensor& z_state, + const Tensor& cell_state_fwd, + const Tensor& input, + const Tensor& layersOutputs, + TensorList hx, + TensorList params, + bool has_biases, + int64_t num_layers, + double dropout_p, + bool train, + bool bidirectional, + bool batch_first) { + using namespace mps; + const Tensor& grad_hy_r = c10::value_or_else(grad_hy_opt, [] { return Tensor(); }); + const Tensor& grad_cy_r = c10::value_or_else(grad_cy_opt, [] { return Tensor(); }); + auto grad_hy = grad_hy_r.defined() ? grad_hy_r : at::zeros_like(hx[0], input.options()); + auto grad_cy = grad_cy_r.defined() ? grad_cy_r : at::zeros_like(hx[1], input.options()); + + std::vector kernel_weights; + std::vector recurrent_kernel_weights; + std::vector biases; + std::vector recurrent_biases; + + const int64_t total_layers = num_layers * (bidirectional ? 2 : 1); + + for (const auto i : c10::irange(total_layers)) { + const int stride = (has_biases ? 4 : 2); + kernel_weights.push_back(params[i * stride]); + recurrent_kernel_weights.push_back(params[i * stride + 1]); + if (has_biases) { + biases.push_back(params[i * stride + 2]); + recurrent_biases.push_back(params[i * stride + 3]); + } + } + + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + std::vector inputTensors_; + NSMutableArray* kernelWeightsList_ = nil; + NSMutableArray* recurrentKernelWeightsList_ = nil; + NSMutableArray* biasList_ = nil; + NSMutableArray* recurrentBiasList_ = nil; + NSMutableArray* gradRecWeights_ = nil; + NSMutableArray* gradWeights_ = nil; + NSMutableArray* gradBias_ = nil; + MPSGraphTensor* gradOutput_ = nil; + MPSGraphTensor* gradState_ = nil; + MPSGraphTensor* gradCellState_ = nil; + }; + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + // Get stream + MPSStream* stream = getCurrentMPSStream(); + @autoreleasepool { + string key = "lstm_backward_" + getTensorsStringKey({input, z_state, cell_state_fwd, grad_y, grad_cy, grad_hy}) + + getMPSTypeString(input) + "_num_layers_" + std::to_string(num_layers) + "_bidirectional_" + + std::to_string(bidirectional) + "_has_biases_" + std::to_string(has_biases) + "_batch_first_" + + std::to_string(batch_first); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + NSMutableArray* kernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()]; + NSMutableArray* recurrentKernelWeightsList = + [[NSMutableArray alloc] initWithCapacity:params.size()]; + NSMutableArray* kernelBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()]; + NSMutableArray* recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()]; + + for (const auto i : c10::irange(total_layers)) { + [kernelWeightsList + addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(kernel_weights[i]))]; + [recurrentKernelWeightsList + addObject:mpsGraphRankedPlaceHolder( + mpsGraph, getMPSDataType(input), getMPSShape(recurrent_kernel_weights[i]))]; + if (has_biases) { + [kernelBiasList + addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(biases[i]))]; + [recurrentBiasList addObject:mpsGraphRankedPlaceHolder( + mpsGraph, getMPSDataType(input), getMPSShape(recurrent_biases[i]))]; + } + } + + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(input)); + MPSGraphTensor* stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(hx[0])); + MPSGraphTensor* cellStateTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(hx[1])); + MPSGraphTensor* zStateTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(z_state)); + MPSGraphTensor* gradientTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_y), getMPSShape(grad_y)); + MPSGraphTensor* gradientCyTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_cy), getMPSShape(grad_cy)); + MPSGraphTensor* gradientHyTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_hy), getMPSShape(grad_hy)); + MPSGraphTensor* cellStateFwdTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(cell_state_fwd), getMPSShape(cell_state_fwd)); + MPSGraphTensor* layersOutputsTensor = + mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(layersOutputs), getMPSShape(layersOutputs)); + + std::vector inputs = {inputTensor, + stateTensor, + cellStateTensor, + gradientTensor, + zStateTensor, + cellStateFwdTensor, + gradientHyTensor, + gradientCyTensor, + layersOutputsTensor}; + + if (batch_first) { + inputTensor = [mpsGraph transposeTensor:inputTensor dimension:0 withDimension:1 name:nil]; + + gradientTensor = [mpsGraph transposeTensor:gradientTensor dimension:0 withDimension:1 name:nil]; + } - std::vector kernel_weights; - std::vector recurrent_kernel_weights; - std::vector biases; - std::vector recurrent_biases; + newCachedGraph->recurrentKernelWeightsList_ = recurrentKernelWeightsList; + newCachedGraph->kernelWeightsList_ = kernelWeightsList; + newCachedGraph->biasList_ = kernelBiasList; + newCachedGraph->recurrentBiasList_ = recurrentBiasList; + newCachedGraph->inputTensors_ = inputs; - const int64_t total_layers = num_layers * (bidirectional ? 2 : 1); + MPSGraphLSTMDescriptor* opDesc = [MPSGraphLSTMDescriptor descriptor]; + opDesc.training = true; // train; + opDesc.bidirectional = bidirectional; + opDesc.produceCell = true; - for (const auto i : c10::irange(total_layers)) { - const int stride = (has_biases ? 4 : 2); - kernel_weights.push_back(params[i*stride]); - recurrent_kernel_weights.push_back(params[i*stride+1]); - if(has_biases) { - biases.push_back(params[i*stride + 2]); - recurrent_biases.push_back(params[i*stride + 3]); - } - } + MPSGraphTensor* gradientTensor_ = gradientTensor; + + NSArray* outputs = nil; + + NSMutableArray* gradRecWeightsArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; + NSMutableArray* gradWeightsArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; + NSMutableArray* gradBiasArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; + NSMutableArray* gradStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; + NSMutableArray* gradCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; - struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - std::vector inputTensors_; - NSMutableArray *kernelWeightsList_ = nil; - NSMutableArray *recurrentKernelWeightsList_ = nil; - NSMutableArray *biasList_ = nil; - NSMutableArray *recurrentBiasList_ = nil; - NSMutableArray *gradRecWeights_ = nil; - NSMutableArray *gradWeights_ = nil; - NSMutableArray *gradBias_ = nil; - MPSGraphTensor* gradOutput_ = nil; - MPSGraphTensor* gradState_ = nil; - MPSGraphTensor* gradCellState_ = nil; - }; - - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - - // Get stream - MPSStream* stream = getCurrentMPSStream(); - @autoreleasepool { - - string key = "lstm_backward_" + getTensorsStringKey({input, z_state, cell_state_fwd, grad_y, grad_cy, grad_hy})+ getMPSTypeString(input) + "_num_layers_" + std::to_string(num_layers) + "_bidirectional_" + std::to_string(bidirectional) + "_has_biases_" + std::to_string(has_biases) + "_batch_first_" + std::to_string(batch_first); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if (!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - NSMutableArray *kernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()]; - NSMutableArray *recurrentKernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()]; - NSMutableArray *kernelBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()]; - NSMutableArray *recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()]; - - for (const auto i : c10::irange(total_layers)) { - [kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(kernel_weights[i]))]; - [recurrentKernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(recurrent_kernel_weights[i]))]; - if (has_biases) { - [kernelBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(biases[i]))]; - [recurrentBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(recurrent_biases[i]))]; - } - } - - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(input)); - MPSGraphTensor* stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(hx[0])); - MPSGraphTensor* cellStateTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(hx[1])); - MPSGraphTensor* zStateTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), getMPSShape(z_state)); - MPSGraphTensor* gradientTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_y), getMPSShape(grad_y)); - MPSGraphTensor* gradientCyTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_cy), getMPSShape(grad_cy)); - MPSGraphTensor* gradientHyTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_hy), getMPSShape(grad_hy)); - MPSGraphTensor* cellStateFwdTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(cell_state_fwd), getMPSShape(cell_state_fwd)); - MPSGraphTensor* layersOutputsTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(layersOutputs), getMPSShape(layersOutputs)); - - std::vector inputs = {inputTensor, stateTensor, cellStateTensor, gradientTensor, zStateTensor, cellStateFwdTensor, gradientHyTensor, gradientCyTensor, layersOutputsTensor}; - - if (batch_first) { - inputTensor = [mpsGraph transposeTensor: inputTensor - dimension: 0 - withDimension: 1 - name: nil]; - - gradientTensor = [mpsGraph transposeTensor: gradientTensor - dimension: 0 - withDimension: 1 - name: nil]; - } - - newCachedGraph->recurrentKernelWeightsList_ = recurrentKernelWeightsList; - newCachedGraph->kernelWeightsList_ = kernelWeightsList; - newCachedGraph->biasList_ = kernelBiasList; - newCachedGraph->recurrentBiasList_ = recurrentBiasList; - newCachedGraph->inputTensors_ = inputs; - - MPSGraphLSTMDescriptor * opDesc = [MPSGraphLSTMDescriptor descriptor]; - opDesc.training = true; //train; - opDesc.bidirectional = bidirectional; - opDesc.produceCell = true; - - MPSGraphTensor* gradientTensor_ = gradientTensor; - - NSArray* outputs = nil; - - NSMutableArray* gradRecWeightsArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; - NSMutableArray* gradWeightsArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; - NSMutableArray* gradBiasArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; - NSMutableArray* gradStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; - NSMutableArray* gradCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers]; - - auto hidden_size = hx[0].sizes()[2]; - - for (int i = num_layers - 1; i >= 0; i--) { - MPSGraphTensor* zState = [mpsGraph sliceTensor:zStateTensor - dimension:0 - start:i - length:1 - name:nil]; - zState = [mpsGraph squeezeTensor:zState - axis:0 + auto hidden_size = hx[0].sizes()[2]; + + for (int i = num_layers - 1; i >= 0; i--) { + MPSGraphTensor* zState = [mpsGraph sliceTensor:zStateTensor dimension:0 start:i length:1 name:nil]; + zState = [mpsGraph squeezeTensor:zState axis:0 name:nil]; + MPSGraphTensor* cellStateFwd = [mpsGraph sliceTensor:cellStateFwdTensor + dimension:0 + start:i + length:1 + name:nil]; + cellStateFwd = [mpsGraph squeezeTensor:cellStateFwd axis:0 name:nil]; + auto tensorsData = getMPSTensorsFromPytorchTensors(mpsGraph, + stateTensor, + cellStateTensor, + recurrentKernelWeightsList, + kernelWeightsList, + kernelBiasList, + recurrentBiasList, + has_biases, + bidirectional, + i); + MPSGraphTensor *stateTensor_ = std::get<0>(tensorsData), *cellStateTensor_ = std::get<1>(tensorsData); + MPSGraphTensor *recurrentWeight_ = std::get<2>(tensorsData), *inputWeight_ = std::get<3>(tensorsData); + MPSGraphTensor* biasTensor_ = std::get<4>(tensorsData); + + MPSGraphTensor *gradientHyTensor_ = nil, *gradientCyTensor_ = nil; + if (bidirectional) { + gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor dimension:0 start:i * 2 length:2 name:nil]; + // [2, N, H] -> [N, 2, H] + gradientHyTensor_ = [mpsGraph transposeTensor:gradientHyTensor_ dimension:0 withDimension:1 name:nil]; + // [N, 2, H] -> [N, 2 * H] + gradientHyTensor_ = [mpsGraph flatten2DTensor:gradientHyTensor_ axis:1 name:nil]; + + gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor dimension:0 start:i * 2 length:2 name:nil]; + gradientCyTensor_ = [mpsGraph transposeTensor:gradientCyTensor_ dimension:0 withDimension:1 name:nil]; + gradientCyTensor_ = [mpsGraph flatten2DTensor:gradientCyTensor_ axis:1 name:nil]; + } else { + gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor dimension:0 start:i length:1 name:nil]; + + gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor dimension:0 start:i length:1 name:nil]; + } + + MPSGraphTensor* iterationInputTensor_ = nil; + if (i == 0) { + iterationInputTensor_ = inputTensor; + } else { + iterationInputTensor_ = [mpsGraph sliceTensor:layersOutputsTensor + dimension:0 + // the last element in layersOutputsTensor + // contains **inputs** for the **last** layer + // and so on + start:i - num_layers + length:1 + name:nil]; + iterationInputTensor_ = [mpsGraph squeezeTensor:iterationInputTensor_ axis:0 name:nil]; + } + + outputs = [mpsGraph LSTMGradientsWithSourceTensor:iterationInputTensor_ + recurrentWeight:recurrentWeight_ + sourceGradient:gradientTensor_ + zState:zState + cellOutputFwd:cellStateFwd + stateGradient:gradientHyTensor_ + cellGradient:gradientCyTensor_ + inputWeight:inputWeight_ + bias:biasTensor_ + initState:stateTensor_ + initCell:cellStateTensor_ + mask:nil + peephole:nil + descriptor:opDesc + name:nil]; + + gradientTensor_ = [outputs objectAtIndex:0]; + if (bidirectional) { + int outputIter = 1; + auto gradRecWeightsBidirectional = [outputs objectAtIndex:outputIter++]; + auto gradRecWeightFwd = [mpsGraph sliceTensor:gradRecWeightsBidirectional + dimension:0 + start:0 + length:1 + name:nil]; + gradRecWeightFwd = [mpsGraph squeezeTensor:gradRecWeightFwd axis:0 name:nil]; + auto gradRecWeightBack = [mpsGraph sliceTensor:gradRecWeightsBidirectional + dimension:0 + start:1 + length:1 + name:nil]; + gradRecWeightBack = [mpsGraph squeezeTensor:gradRecWeightBack axis:0 name:nil]; + + // inverse order + [gradRecWeightsArray insertObject:gradRecWeightBack atIndex:0]; + [gradRecWeightsArray insertObject:gradRecWeightFwd atIndex:0]; + + auto gradWeightsBidirectional = [outputs objectAtIndex:outputIter++]; + auto gradWeightFwd = [mpsGraph sliceTensor:gradWeightsBidirectional + dimension:0 + start:0 + length:hidden_size * 4 + name:nil]; + auto gradWeightBack = [mpsGraph sliceTensor:gradWeightsBidirectional + dimension:0 + start:hidden_size * 4 + length:hidden_size * 4 + name:nil]; + + [gradWeightsArray insertObject:gradWeightBack atIndex:0]; + [gradWeightsArray insertObject:gradWeightFwd atIndex:0]; + + if (has_biases) { + // has shape [1, 1, 8H] vs [8H] as should be + // so, squeeze these two first dimensions + auto gradBiasBidirectional = [outputs objectAtIndex:outputIter++]; + gradBiasBidirectional = [mpsGraph squeezeTensor:gradBiasBidirectional axes:@[ @0, @1 ] name:nil]; + auto gradBiasFwd = [mpsGraph sliceTensor:gradBiasBidirectional + dimension:0 + start:0 + length:hidden_size * 4 name:nil]; - MPSGraphTensor* cellStateFwd = [mpsGraph sliceTensor:cellStateFwdTensor - dimension:0 - start:i - length:1 - name:nil]; - cellStateFwd = [mpsGraph squeezeTensor:cellStateFwd - axis:0 + auto gradBiasBack = [mpsGraph sliceTensor:gradBiasBidirectional + dimension:0 + start:hidden_size * 4 + length:hidden_size * 4 + name:nil]; + + [gradBiasArray insertObject:gradBiasBack atIndex:0]; + [gradBiasArray insertObject:gradBiasFwd atIndex:0]; + } + + auto gradStateBidirectional = [outputs objectAtIndex:outputIter++]; + auto gradStateFwd = [mpsGraph sliceTensor:gradStateBidirectional + dimension:1 + start:0 + length:hidden_size + name:nil]; + auto gradStateBack = [mpsGraph sliceTensor:gradStateBidirectional + dimension:1 + start:hidden_size + length:hidden_size name:nil]; - auto tensorsData = getMPSTensorsFromPytorchTensors(mpsGraph, stateTensor, cellStateTensor, - recurrentKernelWeightsList, kernelWeightsList, - kernelBiasList, recurrentBiasList, has_biases, - bidirectional, i); - MPSGraphTensor* stateTensor_ = std::get<0>(tensorsData), *cellStateTensor_ = std::get<1>(tensorsData); - MPSGraphTensor* recurrentWeight_ = std::get<2>(tensorsData), *inputWeight_ = std::get<3>(tensorsData); - MPSGraphTensor* biasTensor_ = std::get<4>(tensorsData); - - MPSGraphTensor* gradientHyTensor_ = nil, *gradientCyTensor_ = nil; - if (bidirectional) { - gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor - dimension:0 - start:i * 2 - length:2 - name:nil]; - // [2, N, H] -> [N, 2, H] - gradientHyTensor_ = [mpsGraph transposeTensor:gradientHyTensor_ dimension: 0 withDimension: 1 name:nil]; - // [N, 2, H] -> [N, 2 * H] - gradientHyTensor_ = [mpsGraph flatten2DTensor:gradientHyTensor_ axis:1 name:nil]; - - - gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor - dimension:0 - start:i * 2 - length:2 - name:nil]; - gradientCyTensor_ = [mpsGraph transposeTensor:gradientCyTensor_ dimension: 0 withDimension: 1 name:nil]; - gradientCyTensor_ = [mpsGraph flatten2DTensor:gradientCyTensor_ axis:1 name:nil]; - } else { - gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor - dimension:0 - start:i - length:1 - name:nil]; - - gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor - dimension:0 - start:i - length:1 - name:nil]; - } - - MPSGraphTensor* iterationInputTensor_ = nil; - if (i == 0) { - iterationInputTensor_ = inputTensor; - } else { - iterationInputTensor_ = [mpsGraph sliceTensor:layersOutputsTensor - dimension: 0 - // the last element in layersOutputsTensor - // contains **inputs** for the **last** layer - // and so on - start: i - num_layers - length: 1 - name: nil]; - iterationInputTensor_ = [mpsGraph squeezeTensor:iterationInputTensor_ - axis:0 - name: nil]; - } - - outputs = [mpsGraph LSTMGradientsWithSourceTensor: iterationInputTensor_ - recurrentWeight: recurrentWeight_ - sourceGradient: gradientTensor_ - zState: zState - cellOutputFwd: cellStateFwd - stateGradient: gradientHyTensor_ - cellGradient: gradientCyTensor_ - inputWeight: inputWeight_ - bias: biasTensor_ - initState: stateTensor_ - initCell: cellStateTensor_ - mask: nil - peephole: nil - descriptor: opDesc - name: nil]; - - gradientTensor_ = [outputs objectAtIndex:0]; - if (bidirectional) { - int outputIter = 1; - auto gradRecWeightsBidirectional = [outputs objectAtIndex:outputIter++]; - auto gradRecWeightFwd = [mpsGraph sliceTensor:gradRecWeightsBidirectional - dimension: 0 - start: 0 - length: 1 - name: nil]; - gradRecWeightFwd = [mpsGraph squeezeTensor:gradRecWeightFwd axis:0 name: nil]; - auto gradRecWeightBack = [mpsGraph sliceTensor:gradRecWeightsBidirectional - dimension: 0 - start: 1 - length: 1 - name: nil]; - gradRecWeightBack = [mpsGraph squeezeTensor:gradRecWeightBack axis:0 name: nil]; - - // inverse order - [gradRecWeightsArray insertObject:gradRecWeightBack atIndex:0]; - [gradRecWeightsArray insertObject:gradRecWeightFwd atIndex:0]; - - auto gradWeightsBidirectional = [outputs objectAtIndex:outputIter++]; - auto gradWeightFwd = [mpsGraph sliceTensor:gradWeightsBidirectional - dimension: 0 - start: 0 - length: hidden_size * 4 - name: nil]; - auto gradWeightBack = [mpsGraph sliceTensor:gradWeightsBidirectional - dimension: 0 - start: hidden_size * 4 - length: hidden_size * 4 - name: nil]; - - [gradWeightsArray insertObject:gradWeightBack atIndex:0]; - [gradWeightsArray insertObject:gradWeightFwd atIndex:0]; - - if (has_biases) { - // has shape [1, 1, 8H] vs [8H] as should be - // so, squeeze these two first dimensions - auto gradBiasBidirectional = [outputs objectAtIndex:outputIter++]; - gradBiasBidirectional = [mpsGraph squeezeTensor: gradBiasBidirectional - axes: @[@0, @1] - name: nil]; - auto gradBiasFwd = [mpsGraph sliceTensor:gradBiasBidirectional - dimension: 0 - start: 0 - length: hidden_size * 4 - name: nil]; - auto gradBiasBack = [mpsGraph sliceTensor:gradBiasBidirectional - dimension: 0 - start: hidden_size * 4 - length: hidden_size * 4 - name: nil]; - - [gradBiasArray insertObject: gradBiasBack atIndex:0]; - [gradBiasArray insertObject: gradBiasFwd atIndex:0]; - } - - auto gradStateBidirectional = [outputs objectAtIndex:outputIter++]; - auto gradStateFwd = [mpsGraph sliceTensor:gradStateBidirectional - dimension: 1 - start: 0 - length: hidden_size - name: nil]; - auto gradStateBack = [mpsGraph sliceTensor:gradStateBidirectional - dimension: 1 - start: hidden_size - length: hidden_size - name: nil]; - - [gradStateArray insertObject: [mpsGraph expandDimsOfTensor:gradStateBack axis:0 name:nil] atIndex:0]; - [gradStateArray insertObject: [mpsGraph expandDimsOfTensor:gradStateFwd axis:0 name:nil] atIndex:0]; - - auto gradCellStateBidirectional = [outputs objectAtIndex:outputIter++]; - auto gradCellStateFwd = [mpsGraph sliceTensor:gradCellStateBidirectional - dimension: 1 - start: 0 - length: hidden_size - name: nil]; - auto gradCellStateBack = [mpsGraph sliceTensor:gradCellStateBidirectional - dimension: 1 - start: hidden_size - length: hidden_size - name: nil]; - - [gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:gradCellStateBack axis:0 name:nil] atIndex:0]; - [gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:gradCellStateFwd axis:0 name:nil] atIndex:0]; - } else { - int outputIter = 1; - [gradRecWeightsArray insertObject:[outputs objectAtIndex:outputIter++] atIndex:0]; - [gradWeightsArray insertObject:[outputs objectAtIndex:outputIter++] atIndex:0]; - if (has_biases) { - [gradBiasArray insertObject: [outputs objectAtIndex:outputIter++] atIndex:0]; - } - [gradStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:outputIter++] axis:0 name:nil] atIndex:0]; - [gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:outputIter++] axis:0 name:nil] atIndex:0]; - } - } - if (batch_first) { - MPSGraphTensor* gradientTensorTransposed = [mpsGraph transposeTensor:gradientTensor_ - dimension: 0 - withDimension: 1 - name:nil]; - newCachedGraph->gradOutput_ = gradientTensorTransposed; - } else { - newCachedGraph->gradOutput_ = gradientTensor_; - } - - newCachedGraph->gradRecWeights_ = gradRecWeightsArray; - newCachedGraph->gradWeights_ = gradWeightsArray; - newCachedGraph->gradBias_ = gradBiasArray; - newCachedGraph->gradState_ = [mpsGraph concatTensors:gradStateArray dimension: 0 name: nil]; - newCachedGraph->gradCellState_ = [mpsGraph concatTensors:gradCellStateArray dimension: 0 name: nil]; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensors_[0], input); - Placeholder statePlaceholder = Placeholder(cachedGraph->inputTensors_[1], hx[0]); - Placeholder cellStatePlaceholder = Placeholder(cachedGraph->inputTensors_[2], hx[1]); - Placeholder gradientPlaceholder = Placeholder(cachedGraph->inputTensors_[3], grad_y); - Placeholder zStatePlaceholder = Placeholder(cachedGraph->inputTensors_[4], z_state); - Placeholder cellStateFwdPlaceholder = Placeholder(cachedGraph->inputTensors_[5], cell_state_fwd); - Placeholder gradientHyPlaceholder = Placeholder(cachedGraph->inputTensors_[6], grad_hy); - Placeholder gradientCyPlaceholder = Placeholder(cachedGraph->inputTensors_[7], grad_cy); - Placeholder layersOutputsPlaceholder = Placeholder(cachedGraph->inputTensors_[8], layersOutputs); - - NSMutableDictionary *feeds = [[[NSMutableDictionary alloc] init] autorelease]; - [feeds setObject:gradientPlaceholder.getMPSGraphTensorData() forKey:gradientPlaceholder.getMPSGraphTensor()]; - [feeds setObject:gradientHyPlaceholder.getMPSGraphTensorData() forKey:gradientHyPlaceholder.getMPSGraphTensor()]; - [feeds setObject:gradientCyPlaceholder.getMPSGraphTensorData() forKey:gradientCyPlaceholder.getMPSGraphTensor()]; - [feeds setObject:inputPlaceholder.getMPSGraphTensorData() forKey:inputPlaceholder.getMPSGraphTensor()]; - [feeds setObject:statePlaceholder.getMPSGraphTensorData() forKey: statePlaceholder.getMPSGraphTensor()]; - [feeds setObject:cellStatePlaceholder.getMPSGraphTensorData() forKey:cellStatePlaceholder.getMPSGraphTensor()]; - [feeds setObject:zStatePlaceholder.getMPSGraphTensorData() forKey:zStatePlaceholder.getMPSGraphTensor()]; - [feeds setObject:cellStateFwdPlaceholder.getMPSGraphTensorData() forKey:cellStateFwdPlaceholder.getMPSGraphTensor()]; - [feeds setObject:layersOutputsPlaceholder.getMPSGraphTensorData() forKey:layersOutputsPlaceholder.getMPSGraphTensor()]; - - NSMutableArray *kernelWeightsList = cachedGraph->kernelWeightsList_; - NSMutableArray *recurrentKernelWeightsList = cachedGraph->recurrentKernelWeightsList_; - NSMutableArray *biasList = cachedGraph->biasList_; - NSMutableArray *recurrentBiasList = cachedGraph->recurrentBiasList_; - - for (const auto i : c10::irange(total_layers)) { - Placeholder kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]); - Placeholder recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]); - [feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()]; - [feeds setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()]; - if (has_biases) { - Placeholder bias = Placeholder([biasList objectAtIndex:i], biases[i]); - Placeholder recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]); - [feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()]; - [feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()]; + [gradStateArray insertObject:[mpsGraph expandDimsOfTensor:gradStateBack axis:0 name:nil] atIndex:0]; + [gradStateArray insertObject:[mpsGraph expandDimsOfTensor:gradStateFwd axis:0 name:nil] atIndex:0]; + + auto gradCellStateBidirectional = [outputs objectAtIndex:outputIter++]; + auto gradCellStateFwd = [mpsGraph sliceTensor:gradCellStateBidirectional + dimension:1 + start:0 + length:hidden_size + name:nil]; + auto gradCellStateBack = [mpsGraph sliceTensor:gradCellStateBidirectional + dimension:1 + start:hidden_size + length:hidden_size + name:nil]; + + [gradCellStateArray insertObject:[mpsGraph expandDimsOfTensor:gradCellStateBack axis:0 name:nil] + atIndex:0]; + [gradCellStateArray insertObject:[mpsGraph expandDimsOfTensor:gradCellStateFwd axis:0 name:nil] + atIndex:0]; + } else { + int outputIter = 1; + [gradRecWeightsArray insertObject:[outputs objectAtIndex:outputIter++] atIndex:0]; + [gradWeightsArray insertObject:[outputs objectAtIndex:outputIter++] atIndex:0]; + if (has_biases) { + [gradBiasArray insertObject:[outputs objectAtIndex:outputIter++] atIndex:0]; + } + [gradStateArray insertObject:[mpsGraph expandDimsOfTensor:[outputs objectAtIndex:outputIter++] + axis:0 + name:nil] + atIndex:0]; + [gradCellStateArray insertObject:[mpsGraph expandDimsOfTensor:[outputs objectAtIndex:outputIter++] + axis:0 + name:nil] + atIndex:0]; } + } + if (batch_first) { + MPSGraphTensor* gradientTensorTransposed = [mpsGraph transposeTensor:gradientTensor_ + dimension:0 + withDimension:1 + name:nil]; + newCachedGraph->gradOutput_ = gradientTensorTransposed; + } else { + newCachedGraph->gradOutput_ = gradientTensor_; + } + + newCachedGraph->gradRecWeights_ = gradRecWeightsArray; + newCachedGraph->gradWeights_ = gradWeightsArray; + newCachedGraph->gradBias_ = gradBiasArray; + newCachedGraph->gradState_ = [mpsGraph concatTensors:gradStateArray dimension:0 name:nil]; + newCachedGraph->gradCellState_ = [mpsGraph concatTensors:gradCellStateArray dimension:0 name:nil]; } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } - Tensor output_out = at::empty_like(input); - Tensor grad_state_out = at::empty_like(hx[0]); - Tensor grad_cell_state_out = at::empty_like(hx[1]); + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensors_[0], input); + Placeholder statePlaceholder = Placeholder(cachedGraph->inputTensors_[1], hx[0]); + Placeholder cellStatePlaceholder = Placeholder(cachedGraph->inputTensors_[2], hx[1]); + Placeholder gradientPlaceholder = Placeholder(cachedGraph->inputTensors_[3], grad_y); + Placeholder zStatePlaceholder = Placeholder(cachedGraph->inputTensors_[4], z_state); + Placeholder cellStateFwdPlaceholder = Placeholder(cachedGraph->inputTensors_[5], cell_state_fwd); + Placeholder gradientHyPlaceholder = Placeholder(cachedGraph->inputTensors_[6], grad_hy); + Placeholder gradientCyPlaceholder = Placeholder(cachedGraph->inputTensors_[7], grad_cy); + Placeholder layersOutputsPlaceholder = Placeholder(cachedGraph->inputTensors_[8], layersOutputs); + + NSMutableDictionary* feeds = [[[NSMutableDictionary alloc] init] autorelease]; + [feeds setObject:gradientPlaceholder.getMPSGraphTensorData() forKey:gradientPlaceholder.getMPSGraphTensor()]; + [feeds setObject:gradientHyPlaceholder.getMPSGraphTensorData() forKey:gradientHyPlaceholder.getMPSGraphTensor()]; + [feeds setObject:gradientCyPlaceholder.getMPSGraphTensorData() forKey:gradientCyPlaceholder.getMPSGraphTensor()]; + [feeds setObject:inputPlaceholder.getMPSGraphTensorData() forKey:inputPlaceholder.getMPSGraphTensor()]; + [feeds setObject:statePlaceholder.getMPSGraphTensorData() forKey:statePlaceholder.getMPSGraphTensor()]; + [feeds setObject:cellStatePlaceholder.getMPSGraphTensorData() forKey:cellStatePlaceholder.getMPSGraphTensor()]; + [feeds setObject:zStatePlaceholder.getMPSGraphTensorData() forKey:zStatePlaceholder.getMPSGraphTensor()]; + [feeds setObject:cellStateFwdPlaceholder.getMPSGraphTensorData() + forKey:cellStateFwdPlaceholder.getMPSGraphTensor()]; + [feeds setObject:layersOutputsPlaceholder.getMPSGraphTensorData() + forKey:layersOutputsPlaceholder.getMPSGraphTensor()]; + + NSMutableArray* kernelWeightsList = cachedGraph->kernelWeightsList_; + NSMutableArray* recurrentKernelWeightsList = cachedGraph->recurrentKernelWeightsList_; + NSMutableArray* biasList = cachedGraph->biasList_; + NSMutableArray* recurrentBiasList = cachedGraph->recurrentBiasList_; + for (const auto i : c10::irange(total_layers)) { + Placeholder kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]); + Placeholder recurrentKernelWeight = + Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]); + [feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()]; + [feeds setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()]; + if (has_biases) { + Placeholder bias = Placeholder([biasList objectAtIndex:i], biases[i]); + Placeholder recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]); + [feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()]; + [feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()]; + } + } - std::vector grad_hx = {grad_state_out, grad_cell_state_out}; + Tensor output_out = at::empty_like(input); + Tensor grad_state_out = at::empty_like(hx[0]); + Tensor grad_cell_state_out = at::empty_like(hx[1]); - NSMutableDictionary *results = [[[NSMutableDictionary alloc] init] autorelease]; - NSMutableArray *gradRecWeightsArray = cachedGraph->gradRecWeights_; - NSMutableArray *gradWeightsArray = cachedGraph->gradWeights_; - NSMutableArray *gradBiasArray = cachedGraph->gradBias_; - MPSGraphTensor* gradOutput = cachedGraph->gradOutput_; - MPSGraphTensor* gradState = cachedGraph->gradState_; - MPSGraphTensor* gradCellState = cachedGraph->gradCellState_; + std::vector grad_hx = {grad_state_out, grad_cell_state_out}; - Placeholder gradStatePlaceholder = Placeholder(gradState, grad_state_out); - Placeholder gradCellStatePlaceholder = Placeholder(gradCellState, grad_cell_state_out); - Placeholder outputPlaceholder = Placeholder(gradOutput, output_out); - [results setObject:gradStatePlaceholder.getMPSGraphTensorData() forKey:gradStatePlaceholder.getMPSGraphTensor()]; - [results setObject:gradCellStatePlaceholder.getMPSGraphTensorData() forKey:gradCellStatePlaceholder.getMPSGraphTensor()]; - [results setObject:outputPlaceholder.getMPSGraphTensorData() forKey:outputPlaceholder.getMPSGraphTensor()]; + NSMutableDictionary* results = + [[[NSMutableDictionary alloc] init] autorelease]; + NSMutableArray* gradRecWeightsArray = cachedGraph->gradRecWeights_; + NSMutableArray* gradWeightsArray = cachedGraph->gradWeights_; + NSMutableArray* gradBiasArray = cachedGraph->gradBias_; + MPSGraphTensor* gradOutput = cachedGraph->gradOutput_; + MPSGraphTensor* gradState = cachedGraph->gradState_; + MPSGraphTensor* gradCellState = cachedGraph->gradCellState_; - Placeholder gradRecWeightsPlaceholder, gradWeightsPlaceholder, gradBiasPlaceholder; + Placeholder gradStatePlaceholder = Placeholder(gradState, grad_state_out); + Placeholder gradCellStatePlaceholder = Placeholder(gradCellState, grad_cell_state_out); + Placeholder outputPlaceholder = Placeholder(gradOutput, output_out); + [results setObject:gradStatePlaceholder.getMPSGraphTensorData() forKey:gradStatePlaceholder.getMPSGraphTensor()]; + [results setObject:gradCellStatePlaceholder.getMPSGraphTensorData() + forKey:gradCellStatePlaceholder.getMPSGraphTensor()]; + [results setObject:outputPlaceholder.getMPSGraphTensorData() forKey:outputPlaceholder.getMPSGraphTensor()]; - std::vector weights; - for (const auto i : c10::irange(total_layers)) { - Tensor grad_rec_weights = at::empty_like(recurrent_kernel_weights[i]); - Tensor grad_weights = at::empty_like(kernel_weights[i]); + Placeholder gradRecWeightsPlaceholder, gradWeightsPlaceholder, gradBiasPlaceholder; - weights.push_back(grad_weights); - weights.push_back(grad_rec_weights); + std::vector weights; + for (const auto i : c10::irange(total_layers)) { + Tensor grad_rec_weights = at::empty_like(recurrent_kernel_weights[i]); + Tensor grad_weights = at::empty_like(kernel_weights[i]); - gradRecWeightsPlaceholder = Placeholder([gradRecWeightsArray objectAtIndex: i], grad_rec_weights); - gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex: i], grad_weights); + weights.push_back(grad_weights); + weights.push_back(grad_rec_weights); - [results setObject:gradRecWeightsPlaceholder.getMPSGraphTensorData() forKey:gradRecWeightsPlaceholder.getMPSGraphTensor()]; - [results setObject:gradWeightsPlaceholder.getMPSGraphTensorData() forKey:gradWeightsPlaceholder.getMPSGraphTensor()]; + gradRecWeightsPlaceholder = Placeholder([gradRecWeightsArray objectAtIndex:i], grad_rec_weights); + gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex:i], grad_weights); - if (has_biases) { - Tensor grad_bias = at::empty((kernel_weights[i].size(0)), kernel_weights[i].options()); + [results setObject:gradRecWeightsPlaceholder.getMPSGraphTensorData() + forKey:gradRecWeightsPlaceholder.getMPSGraphTensor()]; + [results setObject:gradWeightsPlaceholder.getMPSGraphTensorData() + forKey:gradWeightsPlaceholder.getMPSGraphTensor()]; - // In PyTorch LSTM API there are two biases. The second bias is included for CuDNN compatibility. - // In this implementation these two biases are added together and used further. - // Therefore, they have equal gradient, and it is pushed - // twice for each of two bias vectors. - weights.push_back(grad_bias); - weights.push_back(grad_bias); + if (has_biases) { + Tensor grad_bias = at::empty((kernel_weights[i].size(0)), kernel_weights[i].options()); - gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex: i], grad_bias); - [results setObject:gradBiasPlaceholder.getMPSGraphTensorData() forKey:gradBiasPlaceholder.getMPSGraphTensor()]; - } - } + // In PyTorch LSTM API there are two biases. The second bias is included for CuDNN compatibility. + // In this implementation these two biases are added together and used further. + // Therefore, they have equal gradient, and it is pushed + // twice for each of two bias vectors. + weights.push_back(grad_bias); + weights.push_back(grad_bias); - runMPSGraph(stream, cachedGraph->graph(), feeds, results); + gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex:i], grad_bias); + [results setObject:gradBiasPlaceholder.getMPSGraphTensorData() forKey:gradBiasPlaceholder.getMPSGraphTensor()]; + } + } - return std::tuple, std::vector> (output_out, grad_hx, weights); + runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } + return std::tuple, std::vector>(output_out, grad_hx, weights); + } } -} //namespace at::native +} // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/Scalar.mm b/aten/src/ATen/native/mps/operations/Scalar.mm index 73e099d14765a9..3d35343b0ac2a2 100644 --- a/aten/src/ATen/native/mps/operations/Scalar.mm +++ b/aten/src/ATen/native/mps/operations/Scalar.mm @@ -1,13 +1,13 @@ // Copyright © 2022 Apple Inc. #include +#include #include #include -#include #include -#include #include +#include #include #ifdef __OBJC__ @@ -21,17 +21,20 @@ Scalar _local_scalar_dense_mps(const Tensor& self) { Scalar r; - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( - at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_mps", [&] { - Tensor output = at::empty_like(self, kCPU); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, + at::ScalarType::Bool, + at::ScalarType::BFloat16, + self.scalar_type(), + "_local_scalar_dense_mps", + [&] { + Tensor output = at::empty_like(self, kCPU); - Tensor cpu_output = mps::mps_copy_(output, self, false); - scalar_t value = *cpu_output.data_ptr(); - r = Scalar(value); - }); + Tensor cpu_output = mps::mps_copy_(output, self, false); + scalar_t value = *cpu_output.data_ptr(); + r = Scalar(value); + }); return r; } - } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/ScatterGather.mm b/aten/src/ATen/native/mps/operations/ScatterGather.mm index a4ec90514c7c7b..378dee31453173 100644 --- a/aten/src/ATen/native/mps/operations/ScatterGather.mm +++ b/aten/src/ATen/native/mps/operations/ScatterGather.mm @@ -5,12 +5,7 @@ namespace at::native { TORCH_IMPL_FUNC(gather_out_mps) -(const Tensor & self_arg, - int64_t dim, - const Tensor & index, - bool sparse_grad, - const Tensor & output) -{ +(const Tensor& self_arg, int64_t dim, const Tensor& index, bool sparse_grad, const Tensor& output) { using namespace mps; if (self_arg.numel() == 0 || index.numel() == 0) { @@ -20,14 +15,11 @@ dim = at::maybe_wrap_dim(dim, self.dim()); TORCH_CHECK(!sparse_grad, "sparse_grad not supported in MPS yet") - TORCH_CHECK(self.scalar_type() == output.scalar_type(), - "gather(): self and output must have the same scalar type"); - TORCH_CHECK(dim >= 0 && dim < self.dim(), - "gather(): Indexing dim ", dim, " is out of bounds of tensor"); - - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + TORCH_CHECK(self.scalar_type() == output.scalar_type(), "gather(): self and output must have the same scalar type"); + TORCH_CHECK(dim >= 0 && dim < self.dim(), "gather(): Indexing dim ", dim, " is out of bounds of tensor"); + + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* indexTensor_ = nil; MPSGraphTensor* outputTensor_ = nil; @@ -36,7 +28,6 @@ MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - MPSShape* input_shape = getMPSShape(self); MPSShape* index_shape = getMPSShape(index); uint32_t num_input_dims = [input_shape count]; @@ -47,24 +38,25 @@ bool needSlice = false; for (const auto i : c10::irange(num_input_dims)) { - TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis") - if(i != dim && [index_shape[i] intValue] < [input_shape[i] intValue]) + TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue], + "Index dim must not exceed input dim except at gathering axis") + if (i != dim && [index_shape[i] intValue] < [input_shape[i] intValue]) needSlice = true; } auto input_type = getMPSDataType(self); auto output_type = getMPSDataType(output); - if (input_type == MPSDataTypeUInt8 || ((input_type == MPSDataTypeBool && !is_macos_13_or_newer()))) { + if (input_type == MPSDataTypeUInt8 || ((input_type == MPSDataTypeBool && !is_macos_13_or_newer()))) { input_type = MPSDataTypeInt8; } - if (output_type == MPSDataTypeUInt8 || ((output_type == MPSDataTypeBool && !is_macos_13_or_newer()))) { + if (output_type == MPSDataTypeUInt8 || ((output_type == MPSDataTypeBool && !is_macos_13_or_newer()))) { output_type = MPSDataTypeInt8; } string key = "gather_out_mps" + getTensorsStringKey({self, index, output}) + ":" + std::to_string(dim); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -76,10 +68,10 @@ MPSGraphTensor* getInput = inputTensor; // Slice into the input tensor IF NEEDED - if(needSlice) { - NSMutableArray *starts = [NSMutableArray arrayWithCapacity:num_input_dims]; - NSMutableArray *ends = [NSMutableArray arrayWithCapacity:num_input_dims]; - NSMutableArray *strides = [NSMutableArray arrayWithCapacity:num_input_dims]; + if (needSlice) { + NSMutableArray* starts = [NSMutableArray arrayWithCapacity:num_input_dims]; + NSMutableArray* ends = [NSMutableArray arrayWithCapacity:num_input_dims]; + NSMutableArray* strides = [NSMutableArray arrayWithCapacity:num_input_dims]; for (const auto i : c10::irange(num_input_dims)) { // All strides are 1 @@ -89,23 +81,19 @@ ends[i] = (i != dim) ? index_shape[i] : input_shape[i]; } - getInput = [mpsGraph sliceTensor:inputTensor - starts:starts - ends:ends - strides:strides - name:nil]; + getInput = [mpsGraph sliceTensor:inputTensor starts:starts ends:ends strides:strides name:nil]; } MPSGraphTensor* castIndexTensor = [mpsGraph castTensor:indexTensor toType:MPSDataTypeInt32 - name:(NSString * _Nonnull)nil]; + name:(NSString* _Nonnull)nil]; #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wobjc-method-access" - MPSGraphTensor* outputTensor = [mpsGraph gatherAlongAxis: (NSInteger) dim - withUpdatesTensor: getInput - indicesTensor: castIndexTensor - name: nil]; + MPSGraphTensor* outputTensor = [mpsGraph gatherAlongAxis:(NSInteger)dim + withUpdatesTensor:getInput + indicesTensor:castIndexTensor + name:nil]; #pragma clang diagnostic pop newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->indexTensor_ = indexTensor; @@ -113,7 +101,7 @@ } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape, true, input_type); @@ -124,23 +112,20 @@ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), indexPlaceholder.getMPSGraphTensor() : indexPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); } } -void scatter_mps_general -(const Tensor& self_arg, - int64_t dim, - const Tensor& index, - const Tensor& src, - const Tensor& output, - string func_name, - const c10::string_view reduce) -{ +void scatter_mps_general(const Tensor& self_arg, + int64_t dim, + const Tensor& index, + const Tensor& src, + const Tensor& output, + string func_name, + const c10::string_view reduce) { using namespace mps; if (self_arg.numel() == 0 || index.numel() == 0 || src.numel() == 0) { @@ -151,12 +136,10 @@ TORCH_CHECK(self.scalar_type() == output.scalar_type() && output.scalar_type() == src.scalar_type(), "scatter(): self, src and output must have the same scalar type"); - TORCH_CHECK(dim >= 0 && dim < self.dim(), - "scatter(): Indexing dim ", dim, " is out of bounds of tensor"); + TORCH_CHECK(dim >= 0 && dim < self.dim(), "scatter(): Indexing dim ", dim, " is out of bounds of tensor"); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* indexTensor_ = nil; MPSGraphTensor* srcTensor_ = nil; @@ -166,7 +149,6 @@ MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - MPSShape* input_shape = getMPSShape(self); MPSShape* index_shape = getMPSShape(index); MPSShape* src_shape = getMPSShape(src); @@ -174,7 +156,8 @@ uint32_t num_index_dims = [index_shape count]; uint32_t num_src_dims = [src_shape count]; - TORCH_CHECK(num_input_dims == num_index_dims && num_index_dims == num_src_dims, "Input, index and src must have same rank") + TORCH_CHECK(num_input_dims == num_index_dims && num_index_dims == num_src_dims, + "Input, index and src must have same rank") // Do we need to slice into the src tensor? bool needSlice = false; @@ -182,11 +165,13 @@ bool needsCast = false; for (const auto i : c10::irange(num_input_dims)) { - TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis") - TORCH_CHECK([index_shape[i] intValue] <= [src_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis") - if([index_shape[i] intValue] < [src_shape[i] intValue]) + TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue], + "Index dim must not exceed input dim except at gathering axis") + TORCH_CHECK([index_shape[i] intValue] <= [src_shape[i] intValue], + "Index dim must not exceed input dim except at gathering axis") + if ([index_shape[i] intValue] < [src_shape[i] intValue]) needSlice = true; - if(i != dim && [index_shape[i] intValue] < [input_shape[i] intValue]) + if (i != dim && [index_shape[i] intValue] < [input_shape[i] intValue]) inputNeedSlice = true; } TORCH_CHECK(reduce != "mean", "Scatter reduce mean mode not yet supported in MPS") @@ -197,11 +182,12 @@ needsCast = true; } - string key = func_name + getTensorsStringKey({self, index, src, output}) + ":" + std::to_string(dim) + ":" + std::string(reduce); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + string key = func_name + getTensorsStringKey({self, index, src, output}) + ":" + std::to_string(dim) + ":" + + std::string(reduce); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -209,7 +195,7 @@ MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index); - MPSGraphTensor* srcTensor = mpsGraphRankedPlaceHolder(mpsGraph, src); + MPSGraphTensor* srcTensor = mpsGraphRankedPlaceHolder(mpsGraph, src); MPSGraphTensor* outputTensor = nil; MPSGraphTensor* castSrcTensor = srcTensor; @@ -229,9 +215,9 @@ // Slice into the src or input tensors IF NEEDED if (needSlice || inputNeedSlice) { - NSMutableArray *starts = [NSMutableArray arrayWithCapacity:num_input_dims]; - NSMutableArray *strides = [NSMutableArray arrayWithCapacity:num_input_dims]; - NSMutableArray *ends_src = [NSMutableArray arrayWithCapacity:num_input_dims]; + NSMutableArray* starts = [NSMutableArray arrayWithCapacity:num_input_dims]; + NSMutableArray* strides = [NSMutableArray arrayWithCapacity:num_input_dims]; + NSMutableArray* ends_src = [NSMutableArray arrayWithCapacity:num_input_dims]; for (const auto i : c10::irange(num_input_dims)) { strides[i] = @1; @@ -240,44 +226,41 @@ scatterInputShape[i] = (i != dim) ? index_shape[i] : input_shape[i]; } if (needSlice) { - slicedSrc = [mpsGraph sliceTensor:castSrcTensor - starts:starts - ends:ends_src - strides:strides - name:nil]; + slicedSrc = [mpsGraph sliceTensor:castSrcTensor starts:starts ends:ends_src strides:strides name:nil]; } if (inputNeedSlice) { slicedInput = [mpsGraph sliceTensor:castInputTensor - starts:starts - ends:scatterInputShape - strides:strides - name:nil]; + starts:starts + ends:scatterInputShape + strides:strides + name:nil]; } } MPSGraphScatterMode scatter_mode = MPSGraphScatterModeSet; - if(reduce == "sum" || reduce == "add") + if (reduce == "sum" || reduce == "add") scatter_mode = MPSGraphScatterModeAdd; - else if(reduce == "prod" || reduce == "multiply") + else if (reduce == "prod" || reduce == "multiply") scatter_mode = MPSGraphScatterModeMul; - else if(reduce == "amax") + else if (reduce == "amax") scatter_mode = MPSGraphScatterModeMax; - else if(reduce == "amin") + else if (reduce == "amin") scatter_mode = MPSGraphScatterModeMin; - // Scatter this into the input with set mode + // Scatter this into the input with set mode #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wobjc-method-access" - MPSGraphTensor* scatterTensor = [mpsGraph scatterAlongAxis: (NSInteger) dim - withDataTensor: slicedInput - updatesTensor: slicedSrc - indicesTensor: castIndexTensor - mode: scatter_mode - name: nil]; + MPSGraphTensor* scatterTensor = [mpsGraph scatterAlongAxis:(NSInteger)dim + withDataTensor:slicedInput + updatesTensor:slicedSrc + indicesTensor:castIndexTensor + mode:scatter_mode + name:nil]; #pragma clang diagnostic pop - if(inputNeedSlice) { + if (inputNeedSlice) { // Make an array of scatter indices tensors - NSMutableArray* indicesTensors = [NSMutableArray arrayWithCapacity:num_input_dims]; + NSMutableArray* indicesTensors = + [NSMutableArray arrayWithCapacity:num_input_dims]; // 1. Concatenate the coord tensors // 2. Flatten the values @@ -289,18 +272,18 @@ shape_data[i] = {[scatterInputShape[i] intValue]}; } - MPSGraphTensor* scatterInputShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:shape_data.data() length:num_input_dims * sizeof(int)] - shape:@[[NSNumber numberWithUnsignedInt:num_input_dims]] - dataType:MPSDataTypeInt32]; + MPSGraphTensor* scatterInputShapeTensor = + [mpsGraph constantWithData:[NSData dataWithBytes:shape_data.data() length:num_input_dims * sizeof(int)] + shape:@[ [NSNumber numberWithUnsignedInt:num_input_dims] ] + dataType:MPSDataTypeInt32]; for (const auto i : c10::irange(num_input_dims)) { - MPSGraphTensor* axisTensor = [mpsGraph constantWithScalar:i - dataType:MPSDataTypeInt32]; - MPSGraphTensor* scatter_currentIndexTensor = [mpsGraph coordinateAlongAxisTensor: axisTensor - withShapeTensor: scatterInputShapeTensor - name: nil]; + MPSGraphTensor* axisTensor = [mpsGraph constantWithScalar:i dataType:MPSDataTypeInt32]; + MPSGraphTensor* scatter_currentIndexTensor = [mpsGraph coordinateAlongAxisTensor:axisTensor + withShapeTensor:scatterInputShapeTensor + name:nil]; scatter_currentIndexTensor = [mpsGraph reshapeTensor:scatter_currentIndexTensor - withShape:@[@-1, @1] + withShape:@[ @-1, @1 ] name:nil]; indicesTensors[i] = scatter_currentIndexTensor; } @@ -309,9 +292,7 @@ dimension:(NSInteger)1 name:nil]; - MPSGraphTensor* flatValuesTensor = [mpsGraph reshapeTensor:scatterTensor - withShape:@[@-1] - name:nil]; + MPSGraphTensor* flatValuesTensor = [mpsGraph reshapeTensor:scatterTensor withShape:@[ @-1 ] name:nil]; outputTensor = [mpsGraph scatterNDWithDataTensor:castInputTensor updatesTensor:flatValuesTensor @@ -325,11 +306,12 @@ newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->srcTensor_ = srcTensor; newCachedGraph->indexTensor_ = indexTensor; - newCachedGraph->outputTensor_ = needsCast ? castMPSTensor(mpsGraph, outputTensor, output.scalar_type()) : outputTensor; + newCachedGraph->outputTensor_ = + needsCast ? castMPSTensor(mpsGraph, outputTensor, output.scalar_type()) : outputTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape); @@ -342,41 +324,24 @@ srcPlaceholder.getMPSGraphTensor() : srcPlaceholder.getMPSGraphTensorData(), indexPlaceholder.getMPSGraphTensor() : indexPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); } } TORCH_IMPL_FUNC(scatter_src_out_mps) -(const Tensor& self, - int64_t dim, - const Tensor& index, - const Tensor& src, - const Tensor& output) { - +(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const Tensor& output) { scatter_mps_general(self, dim, index, src, output, "scatter_src_out_mps", "set"); - } TORCH_IMPL_FUNC(scatter_value_out_mps) -(const Tensor& self, - int64_t dim, - const Tensor& index, - const Scalar& value, - const Tensor& output) { - - Tensor src = at::native::empty_mps(index.sizes(), - self.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - self.suggest_memory_format()); +(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value, const Tensor& output) { + Tensor src = at::native::empty_mps( + index.sizes(), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, self.suggest_memory_format()); src.fill_(value); scatter_mps_general(self, dim, index, const_cast(src), output, "scatter_value_out_mps", "set"); - } TORCH_IMPL_FUNC(scatter_reduce_out_mps) @@ -386,9 +351,7 @@ const Tensor& src, const c10::string_view reduce, const Tensor& output) { - scatter_mps_general(self, dim, index, src, output, "scatter_reduce_out_mps", reduce); - } TORCH_IMPL_FUNC(scatter_value_reduce_out_mps) @@ -398,25 +361,14 @@ const Scalar& value, const c10::string_view reduce, const Tensor& output) { - - Tensor src = at::native::empty_mps(index.sizes(), - self.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - self.suggest_memory_format()); + Tensor src = at::native::empty_mps( + index.sizes(), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, self.suggest_memory_format()); src.fill_(value); scatter_mps_general(self, dim, index, const_cast(src), output, "scatter_value_reduce_out_mps", reduce); - } TORCH_IMPL_FUNC(scatter_add_mps_out) -(const Tensor& self, - int64_t dim, - const Tensor& index, - const Tensor& src, - const Tensor& output) { - +(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const Tensor& output) { scatter_mps_general(self, dim, index, src, output, "scatter_add_mps_out", "add"); } diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm index 0b359fa8f6cf2a..9a3d37338e28e1 100644 --- a/aten/src/ATen/native/mps/operations/Shape.mm +++ b/aten/src/ATen/native/mps/operations/Shape.mm @@ -2,10 +2,10 @@ #include #include -#include #include -#include +#include #include +#include namespace at::native { @@ -27,21 +27,12 @@ // topk TORCH_IMPL_FUNC(topk_out_mps) - (const Tensor& self, - int64_t k, - int64_t dim_, - bool largest, - bool sorted, - const Tensor& values, - const Tensor& indices) -{ +(const Tensor& self, int64_t k, int64_t dim_, bool largest, bool sorted, const Tensor& values, const Tensor& indices) { using namespace mps; int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true); - TORCH_CHECK( - k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1), - "selected index k out of range"); + TORCH_CHECK(k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1), "selected index k out of range"); - if (!is_macos_13_or_newer() && (k>16)) { + if (!is_macos_13_or_newer() && (k > 16)) { TORCH_WARN_ONCE("torch.topk support for k>16 by MPS on MacOS 13+, please upgrade"); Tensor cpu_indices = indices.clone().to("cpu"); Tensor cpu_values = values.clone().to("cpu"); @@ -52,31 +43,29 @@ } if (self.dim() == 0 && self.numel() == 1) { - values.copy_(self); - indices.zero_(); - return; + values.copy_(self); + indices.zero_(); + return; } // Handle empty tensors - if (self.numel() == 0) - { - values.copy_(self); - indices.copy_(values.toType(at::ScalarType::Long)); - return; + if (self.numel() == 0) { + values.copy_(self); + indices.copy_(values.toType(at::ScalarType::Long)); + return; } // Handle k == 0 case. Needed because MPSGraph does not support k == 0. - if (k == 0) - { - const auto out_shape = getTopK0Shape(self.sizes(), dim); - values.resize_(out_shape); - indices.copy_(values.toType(at::ScalarType::Long)); - return; + if (k == 0) { + const auto out_shape = getTopK0Shape(self.sizes(), dim); + values.resize_(out_shape); + indices.copy_(values.toType(at::ScalarType::Long)); + return; } MPSStream* stream = getCurrentMPSStream(); struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *selfTensor = nil, *valuesTensor = nil, *indicesTensor = nil; + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *selfTensor = nil, *valuesTensor = nil, *indicesTensor = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -85,154 +74,126 @@ // Input as placeholders MPSShape* input_shape = getMPSShape(self); NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = string("topk:") + [ns_shape_key UTF8String] + ":" + - getMPSTypeString(self) + - ":k" + to_string(k) + ":dim" + to_string(dim_) + - ":largest" + to_string(largest); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + to_string(k) + + ":dim" + to_string(dim_) + ":largest" + to_string(largest); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape); - - if (is_macos_13_or_newer()) { - MPSGraphTensor* castInputTensor = newCachedGraph->selfTensor; - MPSDataType dataType = getMPSDataType(self); - // #issue 104398441 sortWithTensor and argsortWithTensor - if (dataType != MPSDataTypeInt32 && - dataType != MPSDataTypeFloat32 && - dataType != MPSDataTypeFloat16) { - dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32; - castInputTensor = [mpsGraph castTensor:newCachedGraph->selfTensor - toType:dataType - name:@"castInputTensor"]; - } - MPSGraphTensor * sortedTensor = [mpsGraph sortWithTensor:castInputTensor - axis:(NSUInteger)dim - descending:largest + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape); + + if (is_macos_13_or_newer()) { + MPSGraphTensor* castInputTensor = newCachedGraph->selfTensor; + MPSDataType dataType = getMPSDataType(self); + // #issue 104398441 sortWithTensor and argsortWithTensor + if (dataType != MPSDataTypeInt32 && dataType != MPSDataTypeFloat32 && dataType != MPSDataTypeFloat16) { + dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32; + castInputTensor = [mpsGraph castTensor:newCachedGraph->selfTensor + toType:dataType + name:@"castInputTensor"]; + } + MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor + axis:(NSUInteger)dim + descending:largest + name:nil]; + sortedTensor = [mpsGraph sliceTensor:sortedTensor + dimension:(NSUInteger)dim + start:((NSUInteger)0)length:k + name:nil]; + MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor + axis:(NSInteger)dim + descending:largest + name:@"argmax_out"]; + argSortedTensor = [mpsGraph sliceTensor:argSortedTensor + dimension:dim + start:((NSUInteger)0)length:k + name:nil]; + newCachedGraph->valuesTensor = sortedTensor; + newCachedGraph->indicesTensor = argSortedTensor; + + } else { + if ((dim_ != -1 && dim_ != self.dim() - 1) && (!largest)) { + // transpose and negate + MPSGraphTensor* transposedInput = [mpsGraph transposeTensor:newCachedGraph->selfTensor + dimension:(NSUInteger)self.dim() - 1 + withDimension:(NSUInteger)dim_ + name:nil]; + MPSGraphTensor* identity = [mpsGraph identityWithTensor:transposedInput name:nil]; + MPSGraphTensor* negatedTransposedInput = [mpsGraph negativeWithTensor:identity name:nil]; + NSArray* outputMPSGraphTensors = [mpsGraph topKWithSourceTensor:negatedTransposedInput + k:((NSUInteger)k)name:nil]; + MPSGraphTensor* valuesNegatedTransposed = outputMPSGraphTensors[0]; + MPSGraphTensor* indicesTransposed = outputMPSGraphTensors[1]; + MPSGraphTensor* valuesNegated = [mpsGraph transposeTensor:valuesNegatedTransposed + dimension:(NSUInteger)self.dim() - 1 + withDimension:(NSUInteger)dim_ + name:nil]; + newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated name:nil]; + newCachedGraph->indicesTensor = [mpsGraph transposeTensor:indicesTransposed + dimension:(NSUInteger)self.dim() - 1 + withDimension:(NSUInteger)dim_ + name:nil]; + } else if (dim_ != -1 && dim_ != self.dim() - 1) { + MPSGraphTensor* transposedInput = [mpsGraph transposeTensor:newCachedGraph->selfTensor + dimension:(NSUInteger)self.dim() - 1 + withDimension:(NSUInteger)dim_ + name:nil]; + MPSGraphTensor* identity = [mpsGraph identityWithTensor:transposedInput name:nil]; + NSArray* outputMPSGraphTensors = [mpsGraph topKWithSourceTensor:identity + k:((NSUInteger)k)name:nil]; + MPSGraphTensor* valuesTransposed = outputMPSGraphTensors[0]; + MPSGraphTensor* indicesTransposed = outputMPSGraphTensors[1]; + newCachedGraph->valuesTensor = [mpsGraph transposeTensor:valuesTransposed + dimension:(NSUInteger)self.dim() - 1 + withDimension:(NSUInteger)dim_ name:nil]; - sortedTensor = [mpsGraph sliceTensor:sortedTensor - dimension:(NSUInteger)dim - start:((NSUInteger) 0) - length:k - name:nil]; - MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor - axis:(NSInteger)dim - descending:largest - name:@"argmax_out"]; - argSortedTensor = [mpsGraph sliceTensor:argSortedTensor - dimension:dim - start:((NSUInteger) 0) - length:k - name:nil]; - newCachedGraph->valuesTensor = sortedTensor; - newCachedGraph->indicesTensor = argSortedTensor; - + newCachedGraph->indicesTensor = [mpsGraph transposeTensor:indicesTransposed + dimension:(NSUInteger)self.dim() - 1 + withDimension:(NSUInteger)dim_ + name:nil]; + } else if (!largest) { + // only negate + MPSGraphTensor* negatedInput = [mpsGraph negativeWithTensor:newCachedGraph->selfTensor name:nil]; + NSArray* outputMPSGraphTensors = [mpsGraph topKWithSourceTensor:negatedInput + k:((NSUInteger)k)name:nil]; + MPSGraphTensor* valuesNegated = outputMPSGraphTensors[0]; + newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated name:nil]; + newCachedGraph->indicesTensor = outputMPSGraphTensors[1]; } else { - if ((dim_ != -1 && dim_ != self.dim() - 1) && (!largest)) { - // transpose and negate - MPSGraphTensor *transposedInput = [mpsGraph transposeTensor: newCachedGraph->selfTensor - dimension: (NSUInteger)self.dim()-1 - withDimension: (NSUInteger)dim_ - name: nil]; - MPSGraphTensor * identity = [mpsGraph identityWithTensor: transposedInput - name: nil]; - MPSGraphTensor * negatedTransposedInput = [mpsGraph negativeWithTensor:identity - name: nil]; - NSArray * outputMPSGraphTensors = [mpsGraph - topKWithSourceTensor:negatedTransposedInput - k:((NSUInteger) k) - name:nil]; - MPSGraphTensor *valuesNegatedTransposed = outputMPSGraphTensors[0]; - MPSGraphTensor *indicesTransposed = outputMPSGraphTensors[1]; - MPSGraphTensor *valuesNegated = [mpsGraph transposeTensor: valuesNegatedTransposed - dimension: (NSUInteger)self.dim()-1 - withDimension: (NSUInteger)dim_ - name: nil]; - newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated - name: nil]; - newCachedGraph->indicesTensor = [mpsGraph transposeTensor: indicesTransposed - dimension: (NSUInteger)self.dim()-1 - withDimension: (NSUInteger)dim_ - name: nil]; - } else if (dim_ != -1 && dim_ != self.dim() - 1) { - MPSGraphTensor *transposedInput = [mpsGraph transposeTensor: newCachedGraph->selfTensor - dimension: (NSUInteger)self.dim()-1 - withDimension: (NSUInteger)dim_ - name: nil]; - MPSGraphTensor * identity = [mpsGraph identityWithTensor: transposedInput - name: nil]; - NSArray * outputMPSGraphTensors = [mpsGraph - topKWithSourceTensor:identity - k:((NSUInteger) k) - name:nil]; - MPSGraphTensor *valuesTransposed = outputMPSGraphTensors[0]; - MPSGraphTensor *indicesTransposed = outputMPSGraphTensors[1]; - newCachedGraph->valuesTensor = [mpsGraph transposeTensor:valuesTransposed - dimension: (NSUInteger)self.dim()-1 - withDimension: (NSUInteger)dim_ - name: nil]; - newCachedGraph->indicesTensor = [mpsGraph transposeTensor: indicesTransposed - dimension: (NSUInteger)self.dim()-1 - withDimension: (NSUInteger)dim_ - name: nil]; - } else if (!largest) { - // only negate - MPSGraphTensor *negatedInput = [mpsGraph negativeWithTensor:newCachedGraph->selfTensor - name: nil]; - NSArray * outputMPSGraphTensors = [mpsGraph - topKWithSourceTensor:negatedInput - k:((NSUInteger) k) - name:nil]; - MPSGraphTensor *valuesNegated = outputMPSGraphTensors[0]; - newCachedGraph->valuesTensor = [mpsGraph negativeWithTensor:valuesNegated - name: nil]; - newCachedGraph->indicesTensor = outputMPSGraphTensors[1]; - } else { - NSArray * outputMPSGraphTensors = [mpsGraph - topKWithSourceTensor:newCachedGraph->selfTensor - k:((NSUInteger) k) - name:nil]; - newCachedGraph->valuesTensor = outputMPSGraphTensors[0]; - newCachedGraph->indicesTensor = outputMPSGraphTensors[1]; - } + NSArray* outputMPSGraphTensors = + [mpsGraph topKWithSourceTensor:newCachedGraph->selfTensor k:((NSUInteger)k)name:nil]; + newCachedGraph->valuesTensor = outputMPSGraphTensors[0]; + newCachedGraph->indicesTensor = outputMPSGraphTensors[1]; } + } } return newCachedGraph; })); } - Placeholder inputPlaceholder = Placeholder(cachedGraph->selfTensor, self); + Placeholder inputPlaceholder = Placeholder(cachedGraph->selfTensor, self); // Outputs as placeholders Placeholder valuesPlaceholder = Placeholder(cachedGraph->valuesTensor, values); Placeholder indicesPlaceholder = Placeholder(cachedGraph->indicesTensor, indices); // Create dictionary of inputs and outputs - NSDictionary* feeds = nil; - feeds = @{ - inputPlaceholder.getMPSGraphTensor() : - inputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = nil; + feeds = @{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()}; NSDictionary* results = @{ - valuesPlaceholder.getMPSGraphTensor() : - valuesPlaceholder.getMPSGraphTensorData(), - indicesPlaceholder.getMPSGraphTensor() : - indicesPlaceholder.getMPSGraphTensorData() + valuesPlaceholder.getMPSGraphTensor() : valuesPlaceholder.getMPSGraphTensorData(), + indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData() }; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } } -void check_shape_except_dim(const Tensor &first, const Tensor &second, - int dimension, int index) -{ +void check_shape_except_dim(const Tensor& first, const Tensor& second, int dimension, int index) { int first_dims = first.dim(); int second_dims = second.dim(); - TORCH_CHECK(first_dims == second_dims, - "Tensors must have same number of dimensions: got ", first_dims, - " and ", second_dims); + TORCH_CHECK( + first_dims == second_dims, "Tensors must have same number of dimensions: got ", first_dims, " and ", second_dims); for (int dim = 0; dim < first_dims; dim++) { if (dim == dimension) { continue; @@ -240,23 +201,27 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, int64_t first_dim_size = at::native::size(first, dim); int64_t second_dim_size = at::native::size(second, dim); TORCH_CHECK(first_dim_size == second_dim_size, - "Sizes of tensors must match except in dimension ", dim, ". Got ", - static_cast(first_dim_size), " and ", - static_cast(second_dim_size), " (The offending index is ", - index, ")"); + "Sizes of tensors must match except in dimension ", + dim, + ". Got ", + static_cast(first_dim_size), + " and ", + static_cast(second_dim_size), + " (The offending index is ", + index, + ")"); } } TORCH_IMPL_FUNC(cat_out_mps) - (const ITensorListRef& inputs, - int64_t dimension, - int64_t valid, - bool all_contiguous, - bool all_same_dtype, - bool all_same_sizes_and_stride, - MemoryFormat memory_format, - const Tensor& out) { - +(const ITensorListRef& inputs, + int64_t dimension, + int64_t valid, + bool all_contiguous, + bool all_same_dtype, + bool all_same_sizes_and_stride, + MemoryFormat memory_format, + const Tensor& out) { using namespace mps; if (out.numel() == 0) { @@ -270,14 +235,16 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, TORCH_CHECK(t.dim() > 0, "zero-dimensional tensor (at position ", idx, ") cannot be concatenated"); auto lap = at::get_overlap_status(out, t); TORCH_CHECK(lap != at::MemOverlapStatus::Partial && lap != at::MemOverlapStatus::Full, - "torch.cat(): unsupported operation: the input tensors cannot refer to any " - "of the output memory locations. Found overlap in input tensor ", idx); + "torch.cat(): unsupported operation: the input tensors cannot refer to any " + "of the output memory locations. Found overlap in input tensor ", + idx); idx++; } // Check for type promotion TORCH_CHECK(canCast(out_dtype, out.scalar_type()), - "torch.cat(): input types can't be cast to the desired output type ", out.scalar_type()); - TORCH_CHECK(inputs.size() > 0,"torch.cat(): invalid number of inputs ", inputs.size()); + "torch.cat(): input types can't be cast to the desired output type ", + out.scalar_type()); + TORCH_CHECK(inputs.size() > 0, "torch.cat(): invalid number of inputs ", inputs.size()); dimension = legacy_cat_wrap_dim(dimension, materialized_inputs); TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension); @@ -288,9 +255,7 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, // this behavior for backwards compatibility, but only for this specific size // (i.e. other empty sizes are not skipped). // FIXME: warn if this is the case - auto should_skip = [](const Tensor& t) { - return t.dim() == 1 && at::native::size(t, 0) == 0; - }; + auto should_skip = [](const Tensor& t) { return t.dim() == 1 && at::native::size(t, 0) == 0; }; at::assert_no_internal_overlap(out); Tensor notSkippedTensor; @@ -317,11 +282,15 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, for (const Tensor& t : inputs) { TORCH_CHECK(t.device() == notSkippedTensor.device(), "torch.cat(): all input tensors must be on the same device. Received ", - t.device(), " and ", notSkippedTensor.device()); + t.device(), + " and ", + notSkippedTensor.device()); } TORCH_CHECK(out.device() == notSkippedTensor.device(), "torch.cat(): all input tensors and out must be on the same device, but inputs are on ", - notSkippedTensor.device(), " and out is on ", out.device()); + notSkippedTensor.device(), + " and out is on ", + out.device()); // TODO: For better performance by eliminating input tensor gathering and post transpose, // TODO: it is better to keep the out tensor's memory format. @@ -354,23 +323,23 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, } struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} std::vector inputTensors_; MPSGraphTensor* outputTensor_ = nil; }; - MPSGraphCache *cache_ = MPSGraphCache::getInstance(); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - string key = "cat_out_mps:" + to_string(dimension) + getTensorsStringKey(input_tensors, /*short_dtype*/true) + ":" + - (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW"); + string key = "cat_out_mps:" + to_string(dimension) + getTensorsStringKey(input_tensors, /*short_dtype*/ true) + + ":" + (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW"); CachedGraph* cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { - MPSGraph *mpsGraph = make_mps_graph(); + MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); auto len_tensor_array = inputs.size() - skipped_tensor_indices.size(); @@ -383,7 +352,8 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, if (tensor.scalar_type() == kBool) { scalar_type = MPSDataTypeInt8; } - newCachedGraph->inputTensors_[idx] = mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, MemoryFormat::Contiguous)); + newCachedGraph->inputTensors_[idx] = + mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, MemoryFormat::Contiguous)); if (tensor.scalar_type() != out_dtype) { castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx] toType:getMPSDataType(out_dtype) @@ -393,15 +363,12 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, } } - auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data() - count:len_tensor_array]; + auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data() count:len_tensor_array]; MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray dimension:dimension // Maybe convert this from int64_t -> int32 name:nil]; if (getMPSDataType(out_dtype) == MPSDataTypeBool) { - outputTensor = [mpsGraph castTensor:outputTensor - toType:MPSDataTypeBool - name:@"outputTensor"]; + outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"outputTensor"]; } newCachedGraph->outputTensor_ = outputTensor; } @@ -418,9 +385,11 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, if (tensor.scalar_type() == kBool) { scalar_type = MPSDataTypeInt8; } - inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor, + inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], + tensor, getMPSShape(tensor, MemoryFormat::Contiguous), - /*gatherTensorData*/true, scalar_type); + /*gatherTensorData*/ true, + scalar_type); t_idx++; } i++; @@ -430,16 +399,15 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, if (!is_macos_13_or_newer() && out.scalar_type() == kBool) { outputDataType = MPSDataTypeInt8; } - Placeholder outputPlaceholder = Placeholder( - cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType); + Placeholder outputPlaceholder = + Placeholder(cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType); - NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; + NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; for (auto& inputPlaceholder : inputPlaceholders) { feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); } - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); } diff --git a/aten/src/ATen/native/mps/operations/SoftMax.mm b/aten/src/ATen/native/mps/operations/SoftMax.mm index 2ee70e3d0c910b..d5a0e19ee63dfd 100644 --- a/aten/src/ATen/native/mps/operations/SoftMax.mm +++ b/aten/src/ATen/native/mps/operations/SoftMax.mm @@ -16,30 +16,26 @@ namespace at::native { void get_shapes(MPSShape* input_shape_readonly, - NSMutableArray* &input_shape, - int num_input_dims, c10::MemoryFormat memory_format) { + NSMutableArray*& input_shape, + int num_input_dims, + c10::MemoryFormat memory_format) { // Modify the shape - if(memory_format == at::MemoryFormat::Contiguous) { - for(int i = 0; i < num_input_dims; i++) + if (memory_format == at::MemoryFormat::Contiguous) { + for (int i = 0; i < num_input_dims; i++) input_shape[i] = input_shape_readonly[i]; - } - else { // ChannelsLast + } else { // ChannelsLast auto num_channels = input_shape_readonly[1]; input_shape[0] = input_shape_readonly[0]; - for(int i = 1; i < num_input_dims-1; i++) - input_shape[i] = input_shape_readonly[i+1]; - input_shape[num_input_dims-1] = num_channels; + for (int i = 1; i < num_input_dims - 1; i++) + input_shape[i] = input_shape_readonly[i + 1]; + input_shape[num_input_dims - 1] = num_channels; } } // Note - Currently only supported for 4D image tensors TORCH_IMPL_FUNC(softmax_mps_out) -(const Tensor& input_, - const int64_t dim, - const bool half_to_float, - const Tensor& output) { - +(const Tensor& input_, const int64_t dim, const bool half_to_float, const Tensor& output) { TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on MPS"); if (input_.numel() == 0) { @@ -49,25 +45,22 @@ void get_shapes(MPSShape* input_shape_readonly, Tensor input; if (input_.dim() == 0) { input = input_.view(1); - } - else + } else input = input_; int64_t dim_ = maybe_wrap_dim(dim, input.dim()); - TORCH_CHECK( - dim_ >= 0 && dim_ < input.dim(), - "Softmax:dim must be non-negative and less than input dimensions"); + TORCH_CHECK(dim_ >= 0 && dim_ < input.dim(), "Softmax:dim must be non-negative and less than input dimensions"); const auto memory_format = input.suggest_memory_format(); - // TORCH_CHECK(input.suggest_memory_format() == output.suggest_memory_format(), "Input and output memory format should match") + // TORCH_CHECK(input.suggest_memory_format() == output.suggest_memory_format(), "Input and output memory format should + // match") using namespace mps; MPSStream* stream = getCurrentMPSStream(); // Derive from MPSCachedGraph - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* outputTensor_ = nil; }; @@ -75,20 +68,20 @@ void get_shapes(MPSShape* input_shape_readonly, MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - string mem_format_key = get_mem_format_string(memory_format); MPSShape* input_shape_readonly = mps::getMPSShape(input); int num_input_dims = [input_shape_readonly count]; // Check - Channels last implies 4d - TORCH_CHECK(memory_format != at::MemoryFormat::ChannelsLast || num_input_dims == 4, "ChannelsLast implies 4d tensor") + TORCH_CHECK(memory_format != at::MemoryFormat::ChannelsLast || num_input_dims == 4, + "ChannelsLast implies 4d tensor") // Input shape changes based on memory format NSMutableArray* input_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; get_shapes(input_shape_readonly, input_shape, num_input_dims, memory_format); // Change dim - if(memory_format == at::MemoryFormat::ChannelsLast && dim_ > 0) { - switch(dim_) { + if (memory_format == at::MemoryFormat::ChannelsLast && dim_ > 0) { + switch (dim_) { case 1: dim_ = 3; break; @@ -105,13 +98,13 @@ void get_shapes(MPSShape* input_shape_readonly, NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = "softmax_mps_out:" + mem_format_key + ":" + getMPSTypeString(input) + ":" - + [ns_shape_key UTF8String] + ":" + std::to_string(dim_); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + string key = "softmax_mps_out:" + mem_format_key + ":" + getMPSTypeString(input) + ":" + [ns_shape_key UTF8String] + + ":" + std::to_string(dim_); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -120,28 +113,20 @@ void get_shapes(MPSShape* input_shape_readonly, MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape); // passing selector of softMaxWithTensor on the mpsGraph object - MPSGraphTensor* outputTensor = [mpsGraph softMaxWithTensor:inputTensor - axis:(NSInteger)dim_ - name:nil]; + MPSGraphTensor* outputTensor = [mpsGraph softMaxWithTensor:inputTensor axis:(NSInteger)dim_ name:nil]; // Output needs to be contiguous format - if(memory_format == at::MemoryFormat::ChannelsLast) { + if (memory_format == at::MemoryFormat::ChannelsLast) { auto N = input_shape[0]; auto H = input_shape[1]; auto W = input_shape[2]; auto C = input_shape[3]; outputTensor = [mpsGraph reshapeTensor:outputTensor - withShape:@[N, ([NSNumber numberWithInt:[H intValue]* [W intValue]]), C] + withShape:@[ N, ([NSNumber numberWithInt:[H intValue] * [W intValue]]), C ] name:nil]; - outputTensor = [mpsGraph transposeTensor:outputTensor - dimension:1 - withDimension:2 - name:nil]; - outputTensor = [mpsGraph reshapeTensor:outputTensor - withShape:@[N, C, H, W] - name:nil]; - + outputTensor = [mpsGraph transposeTensor:outputTensor dimension:1 withDimension:2 name:nil]; + outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:@[ N, C, H, W ] name:nil]; } newCachedGraph->inputTensor_ = inputTensor; @@ -149,32 +134,24 @@ void get_shapes(MPSShape* input_shape_readonly, } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input, input_shape); // This must be the Contiguous shape Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - NSDictionary* feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()}; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } - } TORCH_IMPL_FUNC(softmax_backward_mps_out) -(const Tensor& grad_, - const Tensor& output_, - int64_t dim, - ScalarType input_dtype, - const Tensor& grad_input) { - +(const Tensor& grad_, const Tensor& output_, int64_t dim, ScalarType input_dtype, const Tensor& grad_input) { if (output_.numel() == 0) { return; } @@ -182,29 +159,24 @@ void get_shapes(MPSShape* input_shape_readonly, Tensor grad; if (grad_.dim() == 0) { grad = grad_.view(1); - } - else + } else grad = grad_; Tensor output; if (output_.dim() == 0) { output = output_.view(1); - } - else + } else output = output_; int64_t dim_ = maybe_wrap_dim(dim, grad.dim()); - TORCH_CHECK( - dim_ >= 0 && dim_ < grad.dim(), - "Grad:dim must be non-negative and less than input dimensions"); + TORCH_CHECK(dim_ >= 0 && dim_ < grad.dim(), "Grad:dim must be non-negative and less than input dimensions"); using namespace mps; MPSStream* stream = getCurrentMPSStream(); // Derive from MPSCachedGraph - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* softmaxTensor_ = nil; MPSGraphTensor* gradOutputTensor_ = nil; MPSGraphTensor* gradInputTensor_ = nil; @@ -213,17 +185,16 @@ void get_shapes(MPSShape* input_shape_readonly, MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - MPSShape* grad_shape = mps::getMPSShape(grad); NSString* ns_shape_key = [[grad_shape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = "softmax_backward_mps_out:" + getMPSTypeString(output) + ":" - + [ns_shape_key UTF8String] + ":" + std::to_string(dim_); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + string key = "softmax_backward_mps_out:" + getMPSTypeString(output) + ":" + [ns_shape_key UTF8String] + ":" + + std::to_string(dim_); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -235,9 +206,7 @@ void get_shapes(MPSShape* input_shape_readonly, MPSGraphTensor* mulTensor = [mpsGraph multiplicationWithPrimaryTensor:softmaxTensor secondaryTensor:gradOutputTensor name:nil]; - MPSGraphTensor* mulSumTensor = [mpsGraph reductionSumWithTensor:mulTensor - axis:(NSInteger)dim_ - name:nil]; + MPSGraphTensor* mulSumTensor = [mpsGraph reductionSumWithTensor:mulTensor axis:(NSInteger)dim_ name:nil]; MPSGraphTensor* gradSubTensor = [mpsGraph subtractionWithPrimaryTensor:gradOutputTensor secondaryTensor:mulSumTensor name:nil]; @@ -251,7 +220,7 @@ void get_shapes(MPSShape* input_shape_readonly, } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder softmaxPlaceholder = Placeholder(cachedGraph->softmaxTensor_, output, grad_shape); @@ -262,12 +231,10 @@ void get_shapes(MPSShape* input_shape_readonly, softmaxPlaceholder.getMPSGraphTensor() : softmaxPlaceholder.getMPSGraphTensorData(), gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } } diff --git a/aten/src/ATen/native/mps/operations/Sort.mm b/aten/src/ATen/native/mps/operations/Sort.mm index 402208c92085dd..f4332d85618cbc 100644 --- a/aten/src/ATen/native/mps/operations/Sort.mm +++ b/aten/src/ATen/native/mps/operations/Sort.mm @@ -2,10 +2,10 @@ #include #include -#include #include -#include +#include #include +#include namespace at::native { @@ -42,60 +42,57 @@ MPSStream* stream = getCurrentMPSStream(); struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *selfTensor = nil, *valuesTensor = nil, *indicesTensor = nil; + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *selfTensor = nil, *valuesTensor = nil, *indicesTensor = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { // Input as placeholders MPSShape* input_shape = getMPSShape(self); NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + - ":dim" + to_string(dim) + ":descending" + to_string(descending); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" + to_string(dim) + + ":descending" + to_string(descending); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape); + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape); - MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus); - MPSGraphTensor * sortedTensor = [mpsGraph sortWithTensor:castInputTensor - axis:(NSInteger)dim - descending:(BOOL)descending - name:@"sort_out"]; - if ([sortedTensor dataType] != getMPSDataType(values)) { - sortedTensor = castMPSTensor(mpsGraph, sortedTensor, values.scalar_type()); - } - MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor - axis:(NSInteger)dim - descending:(BOOL)descending - name:@"argsort_out"]; - if ([argSortedTensor dataType] != getMPSDataType(indices)) { - argSortedTensor = castMPSTensor(mpsGraph, argSortedTensor, indices.scalar_type()); - } - newCachedGraph->valuesTensor = sortedTensor; - newCachedGraph->indicesTensor = argSortedTensor; + MPSGraphTensor* castInputTensor = + castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus); + MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor + axis:(NSInteger)dim + descending:(BOOL)descending + name:@"sort_out"]; + if ([sortedTensor dataType] != getMPSDataType(values)) { + sortedTensor = castMPSTensor(mpsGraph, sortedTensor, values.scalar_type()); + } + MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor + axis:(NSInteger)dim + descending:(BOOL)descending + name:@"argsort_out"]; + if ([argSortedTensor dataType] != getMPSDataType(indices)) { + argSortedTensor = castMPSTensor(mpsGraph, argSortedTensor, indices.scalar_type()); + } + newCachedGraph->valuesTensor = sortedTensor; + newCachedGraph->indicesTensor = argSortedTensor; } return newCachedGraph; })); } - Placeholder inputPlaceholder = Placeholder(cachedGraph->selfTensor, self); + Placeholder inputPlaceholder = Placeholder(cachedGraph->selfTensor, self); // Outputs as placeholders Placeholder valuesPlaceholder = Placeholder(cachedGraph->valuesTensor, values); Placeholder indicesPlaceholder = Placeholder(cachedGraph->indicesTensor, indices); // Create dictionary of inputs and outputs - NSDictionary* feeds = nil; - feeds = @{ inputPlaceholder.getMPSGraphTensor() : - inputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = nil; + feeds = @{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()}; NSDictionary* results = @{ - valuesPlaceholder.getMPSGraphTensor() : - valuesPlaceholder.getMPSGraphTensorData(), - indicesPlaceholder.getMPSGraphTensor() : - indicesPlaceholder.getMPSGraphTensorData() + valuesPlaceholder.getMPSGraphTensor() : valuesPlaceholder.getMPSGraphTensorData(), + indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData() }; runMPSGraph(stream, cachedGraph->graph(), feeds, results); diff --git a/aten/src/ATen/native/mps/operations/SummaryOps.mm b/aten/src/ATen/native/mps/operations/SummaryOps.mm index 2e1b27741c8597..5a2e7c86c703b1 100644 --- a/aten/src/ATen/native/mps/operations/SummaryOps.mm +++ b/aten/src/ATen/native/mps/operations/SummaryOps.mm @@ -4,14 +4,11 @@ namespace at::native { -Tensor& bincount_mps_impl(const Tensor& self, - const Tensor& weights, - Tensor& output) { +Tensor& bincount_mps_impl(const Tensor& self, const Tensor& weights, Tensor& output) { using namespace mps; - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* weightsTensor_ = nil; MPSGraphTensor* scatterDataTensor_ = nil; @@ -24,42 +21,37 @@ @autoreleasepool { string key = "bincount_mps_impl" + getTensorsStringKey({self, weights}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { // Initialize graph MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor *scatterDataTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(output.scalar_type())); + MPSGraphTensor* scatterDataTensor = + mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(output.scalar_type())); - MPSGraphTensor *updatesTensor = nil; + MPSGraphTensor* updatesTensor = nil; if (has_weights) { updatesTensor = mpsGraphRankedPlaceHolder(mpsGraph, weights); - } - else { - updatesTensor = [mpsGraph constantWithScalar:1.0f - shape:getMPSShape(self) - dataType:getMPSDataType(output)]; + } else { + updatesTensor = [mpsGraph constantWithScalar:1.0f shape:getMPSShape(self) dataType:getMPSDataType(output)]; } - MPSGraphTensor *castedInputTensor = inputTensor; + MPSGraphTensor* castedInputTensor = inputTensor; if (self.scalar_type() == kByte) { - castedInputTensor = [mpsGraph castTensor:inputTensor - toType:MPSDataTypeInt32 - name:@"castInputTensor"]; + castedInputTensor = [mpsGraph castTensor:inputTensor toType:MPSDataTypeInt32 name:@"castInputTensor"]; } - MPSGraphTensor *outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor + MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor updatesTensor:updatesTensor indicesTensor:castedInputTensor - axis:0 - mode:MPSGraphScatterModeAdd - name:nil]; + axis:0 + mode:MPSGraphScatterModeAdd + name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; @@ -70,7 +62,7 @@ } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } // Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation @@ -80,17 +72,16 @@ Placeholder weightsPlaceholder = Placeholder(); // Create dictionary of inputs/feeds and outputs/results - NSMutableDictionary* feeds =[NSMutableDictionary dictionary]; + NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); feeds[scatterPlaceholder.getMPSGraphTensor()] = scatterPlaceholder.getMPSGraphTensorData(); - if(has_weights) { + if (has_weights) { weightsPlaceholder = Placeholder(cachedGraph->weightsTensor_, weights); feeds[weightsPlaceholder.getMPSGraphTensor()] = weightsPlaceholder.getMPSGraphTensorData(); } - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; // Run the graph runMPSGraph(stream, cachedGraph->graph(), feeds, results); @@ -108,43 +99,32 @@ Tensor _bincount_mps(const Tensor& self, const c10::optional& weights_op TORCH_CHECK(minlength >= 0, "minlength should be >= 0"); if (self.dim() == 1 && self.numel() == 0) { - return at::zeros( - {minlength}, - kLong, - c10::nullopt /* layout */, - kMPS, - c10::nullopt /* pin_memory */); + return at::zeros({minlength}, kLong, c10::nullopt /* layout */, kMPS, c10::nullopt /* pin_memory */); } - TORCH_CHECK(self.dim() == 1 && self.min().item() >= 0, "bincount only supports 1-d non-negative integral inputs."); + TORCH_CHECK(self.dim() == 1 && self.min().item() >= 0, + "bincount only supports 1-d non-negative integral inputs."); bool has_weights = weights.defined(); - TORCH_CHECK(!(has_weights && (weights.dim() != 1 || weights.size(0) != self.size(0))), "weights should be 1-d and have the same length as input"); + TORCH_CHECK(!(has_weights && (weights.dim() != 1 || weights.size(0) != self.size(0))), + "weights should be 1-d and have the same length as input"); const int64_t nbins = std::max(self.max().item() + 1L, minlength); Tensor output; Tensor weights_ = weights; if (has_weights) { - if(weights.scalar_type() != ScalarType::Float && - weights.scalar_type() != ScalarType::Int && - weights.scalar_type() != ScalarType::Half) { - // Scatter doesn't work for int8/int16 dtypes - weights_ = weights.to(kInt); + if (weights.scalar_type() != ScalarType::Float && weights.scalar_type() != ScalarType::Int && + weights.scalar_type() != ScalarType::Half) { + // Scatter doesn't work for int8/int16 dtypes + weights_ = weights.to(kInt); } - output = at::zeros( - {nbins}, - optTypeMetaToScalarType(weights_.options().dtype_opt()), - weights_.options().layout_opt(), - weights_.options().device_opt(), - weights_.options().pinned_memory_opt()); - } - else { - output = at::zeros( - {nbins}, - kLong, - c10::nullopt /* layout */, - kMPS, - c10::nullopt /* pin_memory */); + output = at::zeros({nbins}, + optTypeMetaToScalarType(weights_.options().dtype_opt()), + weights_.options().layout_opt(), + weights_.options().device_opt(), + weights_.options().pinned_memory_opt()); + } else { + output = at::zeros({nbins}, kLong, c10::nullopt /* layout */, kMPS, c10::nullopt /* pin_memory */); } return bincount_mps_impl(self, weights_, output); diff --git a/aten/src/ATen/native/mps/operations/TensorCompare.mm b/aten/src/ATen/native/mps/operations/TensorCompare.mm index 4f8def1cbb7775..e08b7145c0237b 100644 --- a/aten/src/ATen/native/mps/operations/TensorCompare.mm +++ b/aten/src/ATen/native/mps/operations/TensorCompare.mm @@ -1,303 +1,281 @@ // Copyright © 2022 Apple Inc. -#include -#include #include +#include +#include namespace at::native { namespace mps { -struct CachedGraph : public MPSCachedGraph -{ - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor = nil, *outputTensor = nil; - MPSGraphTensor *minTensor = nil, *maxTensor = nil; +struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *inputTensor = nil, *outputTensor = nil; + MPSGraphTensor *minTensor = nil, *maxTensor = nil; }; -void clamp_mps_graph(CachedGraph* cachedGraph, const Tensor& input_tensor) -{ - MPSGraph *mpsGraph = cachedGraph->graph(); - - cachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor); - - if (cachedGraph->minTensor && cachedGraph->maxTensor) { - cachedGraph->outputTensor = [mpsGraph clampWithTensor:cachedGraph->inputTensor - minValueTensor:cachedGraph->minTensor - maxValueTensor:cachedGraph->maxTensor - name:nil]; - } else if (cachedGraph->maxTensor) { - cachedGraph->outputTensor = [mpsGraph minimumWithPrimaryTensor:cachedGraph->inputTensor - secondaryTensor:cachedGraph->maxTensor - name:nil]; - } else if (cachedGraph->minTensor) { - cachedGraph->outputTensor = [mpsGraph maximumWithPrimaryTensor:cachedGraph->inputTensor - secondaryTensor:cachedGraph->minTensor - name:nil]; - } +void clamp_mps_graph(CachedGraph* cachedGraph, const Tensor& input_tensor) { + MPSGraph* mpsGraph = cachedGraph->graph(); + + cachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor); + + if (cachedGraph->minTensor && cachedGraph->maxTensor) { + cachedGraph->outputTensor = [mpsGraph clampWithTensor:cachedGraph->inputTensor + minValueTensor:cachedGraph->minTensor + maxValueTensor:cachedGraph->maxTensor + name:nil]; + } else if (cachedGraph->maxTensor) { + cachedGraph->outputTensor = [mpsGraph minimumWithPrimaryTensor:cachedGraph->inputTensor + secondaryTensor:cachedGraph->maxTensor + name:nil]; + } else if (cachedGraph->minTensor) { + cachedGraph->outputTensor = [mpsGraph maximumWithPrimaryTensor:cachedGraph->inputTensor + secondaryTensor:cachedGraph->minTensor + name:nil]; + } } -void check_min_max_dims(const OptionalTensorRef clamp_opt, - const Tensor& input_t, - string op_name) { - - if(!clamp_opt->is_same_size(input_t)) { - - auto num_clamp_dims = clamp_opt->dim(); - auto num_input_dims = input_t.dim(); - - auto clamp_shape = clamp_opt->sizes(); - auto input_shape = input_t.sizes(); +void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor& input_t, string op_name) { + if (!clamp_opt->is_same_size(input_t)) { + auto num_clamp_dims = clamp_opt->dim(); + auto num_input_dims = input_t.dim(); - TORCH_CHECK(num_clamp_dims <= num_input_dims, op_name + ": clamp tensor number of dims must not be greater than that of input tensor") + auto clamp_shape = clamp_opt->sizes(); + auto input_shape = input_t.sizes(); - for(int i = 0; i < num_clamp_dims; i++) - // One of the indices is allowed to be 1; will be handled by broadcast - TORCH_CHECK(clamp_shape[num_clamp_dims-1-i] == input_shape[num_input_dims-1-i] || - clamp_shape[num_clamp_dims-1-i] == 1 || - input_shape[num_input_dims-1-i] == 1, - op_name + ": clamp tensor trailing shape must match input tensor") + TORCH_CHECK(num_clamp_dims <= num_input_dims, + op_name + ": clamp tensor number of dims must not be greater than that of input tensor") - } + for (int i = 0; i < num_clamp_dims; i++) + // One of the indices is allowed to be 1; will be handled by broadcast + TORCH_CHECK(clamp_shape[num_clamp_dims - 1 - i] == input_shape[num_input_dims - 1 - i] || + clamp_shape[num_clamp_dims - 1 - i] == 1 || input_shape[num_input_dims - 1 - i] == 1, + op_name + ": clamp tensor trailing shape must match input tensor") + } } -void fill_new_shape(int64_t num_input_dims, - int64_t num_clamp_dims, - int64_t *new_shape, - IntArrayRef clamp_shape) { - - // Extend the shape with ones to the left - int clamp_idx = 0; - for(int i = 0; i < num_input_dims; i++) { - if(i < num_input_dims - num_clamp_dims) - new_shape[i] = 1; - else { - new_shape[i] = clamp_shape[clamp_idx]; - clamp_idx++; - } +void fill_new_shape(int64_t num_input_dims, int64_t num_clamp_dims, int64_t* new_shape, IntArrayRef clamp_shape) { + // Extend the shape with ones to the left + int clamp_idx = 0; + for (int i = 0; i < num_input_dims; i++) { + if (i < num_input_dims - num_clamp_dims) + new_shape[i] = 1; + else { + new_shape[i] = clamp_shape[clamp_idx]; + clamp_idx++; } + } } void clamp_tensor_out_mps(const Tensor& input_t, const OptionalTensorRef min_opt, const OptionalTensorRef max_opt, const Tensor& output_t, - string op_name) -{ - const bool has_min = (min_opt.has_value() && min_opt->defined()); - const bool has_max = (max_opt.has_value() && max_opt->defined()); + string op_name) { + const bool has_min = (min_opt.has_value() && min_opt->defined()); + const bool has_max = (max_opt.has_value() && max_opt->defined()); - TORCH_CHECK(has_min || has_max, op_name + ": either min, max or both tensors must be defined") - if (has_min) - check_min_max_dims(min_opt, input_t, op_name); + TORCH_CHECK(has_min || has_max, op_name + ": either min, max or both tensors must be defined") + if (has_min) + check_min_max_dims(min_opt, input_t, op_name); - if (has_max) - check_min_max_dims(max_opt, input_t, op_name); + if (has_max) + check_min_max_dims(max_opt, input_t, op_name); - if (output_t.numel() == 0) - return; + if (output_t.numel() == 0) + return; - IntArrayRef new_min_shape; - IntArrayRef new_max_shape; + IntArrayRef new_min_shape; + IntArrayRef new_max_shape; - auto num_min_dims = min_opt->dim(); - auto num_max_dims = max_opt->dim(); - auto num_input_dims = input_t.dim(); + auto num_min_dims = min_opt->dim(); + auto num_max_dims = max_opt->dim(); + auto num_input_dims = input_t.dim(); - std::vector new_min_arr(num_input_dims); - std::vector new_max_arr(num_input_dims); + std::vector new_min_arr(num_input_dims); + std::vector new_max_arr(num_input_dims); - if(has_min && num_min_dims < num_input_dims) { - fill_new_shape(num_input_dims, num_min_dims, new_min_arr.data(), min_opt->sizes()); - new_min_shape = IntArrayRef(new_min_arr); - } + if (has_min && num_min_dims < num_input_dims) { + fill_new_shape(num_input_dims, num_min_dims, new_min_arr.data(), min_opt->sizes()); + new_min_shape = IntArrayRef(new_min_arr); + } - if(has_max && num_max_dims < num_input_dims) { - fill_new_shape(num_input_dims, num_max_dims, new_max_arr.data(), max_opt->sizes()); - new_max_shape = IntArrayRef(new_max_arr); - } + if (has_max && num_max_dims < num_input_dims) { + fill_new_shape(num_input_dims, num_max_dims, new_max_arr.data(), max_opt->sizes()); + new_max_shape = IntArrayRef(new_max_arr); + } - Tensor min_opt_tensor; - Tensor max_opt_tensor; + Tensor min_opt_tensor; + Tensor max_opt_tensor; - if(has_min) { - min_opt_tensor = (num_min_dims < num_input_dims) ? (*min_opt).view(new_min_shape) : *min_opt; - } - if(has_max) { - max_opt_tensor = (num_max_dims < num_input_dims) ? (*max_opt).view(new_max_shape) : *max_opt; - } + if (has_min) { + min_opt_tensor = (num_min_dims < num_input_dims) ? (*min_opt).view(new_min_shape) : *min_opt; + } + if (has_max) { + max_opt_tensor = (num_max_dims < num_input_dims) ? (*max_opt).view(new_max_shape) : *max_opt; + } - @autoreleasepool { - // the optional min/max refs could affect how we build the cached graph - - auto tensor_key = has_min ? (has_max ? getTensorsStringKey({input_t, min_opt_tensor, max_opt_tensor}) - : getTensorsStringKey({input_t, min_opt_tensor})) - : (has_max ? getTensorsStringKey({input_t, max_opt_tensor}) - : getTensorsStringKey({input_t})); - - string key = op_name + (has_min ? "_min" : "") + (has_max ? "_max" : "") - + "_tensor" + tensor_key; - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if (!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - if (has_min) - newCachedGraph->minTensor = mpsGraphRankedPlaceHolder(mpsGraph, min_opt_tensor); - if (has_max) - newCachedGraph->maxTensor = mpsGraphRankedPlaceHolder(mpsGraph, max_opt_tensor); - - clamp_mps_graph(newCachedGraph, input_t); - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } + @autoreleasepool { + // the optional min/max refs could affect how we build the cached graph - auto inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_t); - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t); + auto tensor_key = has_min + ? (has_max ? getTensorsStringKey({input_t, min_opt_tensor, max_opt_tensor}) + : getTensorsStringKey({input_t, min_opt_tensor})) + : (has_max ? getTensorsStringKey({input_t, max_opt_tensor}) : getTensorsStringKey({input_t})); - NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; - feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); - if (has_min) { - auto minPlaceholder = Placeholder(cachedGraph->minTensor, min_opt_tensor); - feeds[minPlaceholder.getMPSGraphTensor()] = minPlaceholder.getMPSGraphTensorData(); - } - if (has_max) { - auto maxPlaceholder = Placeholder(cachedGraph->maxTensor, max_opt_tensor); - feeds[maxPlaceholder.getMPSGraphTensor()] = maxPlaceholder.getMPSGraphTensorData(); + string key = op_name + (has_min ? "_min" : "") + (has_max ? "_max" : "") + "_tensor" + tensor_key; + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + if (has_min) + newCachedGraph->minTensor = mpsGraphRankedPlaceHolder(mpsGraph, min_opt_tensor); + if (has_max) + newCachedGraph->maxTensor = mpsGraphRankedPlaceHolder(mpsGraph, max_opt_tensor); + + clamp_mps_graph(newCachedGraph, input_t); } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + auto inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_t); + auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t); - runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; + feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); + if (has_min) { + auto minPlaceholder = Placeholder(cachedGraph->minTensor, min_opt_tensor); + feeds[minPlaceholder.getMPSGraphTensor()] = minPlaceholder.getMPSGraphTensorData(); } + if (has_max) { + auto maxPlaceholder = Placeholder(cachedGraph->maxTensor, max_opt_tensor); + feeds[maxPlaceholder.getMPSGraphTensor()] = maxPlaceholder.getMPSGraphTensorData(); + } + + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; + + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + } } void clamp_scalar_out_mps(const Tensor& input_t, - const OptionalScalarRef min_opt, - const OptionalScalarRef max_opt, - const Tensor& output_t, - string op_name) -{ - using scalar_t = double; - - const bool has_min = (min_opt.has_value()); - const bool has_max = (max_opt.has_value()); - TORCH_CHECK(has_min || has_max, op_name + ": either min, max or both scalars must be defined") - - scalar_t min_scalar = std::numeric_limits::infinity(); - scalar_t max_scalar = -std::numeric_limits::infinity(); - - if (has_min) - min_scalar = min_opt.get().to(); - if (has_max) - max_scalar = max_opt.get().to(); - - if (output_t.numel() == 0) - return ; - - @autoreleasepool { - // the optional min/max refs could affect how we build the cached graph - string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") + (has_max ? ("_max:" + to_string(max_scalar)) : "") - + "_scalar:" + getTensorsStringKey({input_t}); - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if (!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - if (has_min) - newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar - shape:(mps::getMPSShape(input_t)) - dataType:(mps::getMPSScalarType(input_t.scalar_type())) ]; - if (has_max) - newCachedGraph->maxTensor = [mpsGraph constantWithScalar:max_scalar - shape:(mps::getMPSShape(input_t)) - dataType:(mps::getMPSScalarType(input_t.scalar_type())) ]; - - clamp_mps_graph(newCachedGraph, input_t); - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } + const OptionalScalarRef min_opt, + const OptionalScalarRef max_opt, + const Tensor& output_t, + string op_name) { + using scalar_t = double; + + const bool has_min = (min_opt.has_value()); + const bool has_max = (max_opt.has_value()); + TORCH_CHECK(has_min || has_max, op_name + ": either min, max or both scalars must be defined") + + scalar_t min_scalar = std::numeric_limits::infinity(); + scalar_t max_scalar = -std::numeric_limits::infinity(); - auto inputPlaceholder = Placeholder(cachedGraph->inputTensor , input_t); - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t); + if (has_min) + min_scalar = min_opt.get().to(); + if (has_max) + max_scalar = max_opt.get().to(); + + if (output_t.numel() == 0) + return; + + @autoreleasepool { + // the optional min/max refs could affect how we build the cached graph + string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") + + (has_max ? ("_max:" + to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t}); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - NSDictionary *feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - }; - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); - runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + if (has_min) + newCachedGraph->minTensor = [mpsGraph + constantWithScalar:min_scalar + shape:(mps::getMPSShape(input_t))dataType:(mps::getMPSScalarType(input_t.scalar_type()))]; + if (has_max) + newCachedGraph->maxTensor = [mpsGraph + constantWithScalar:max_scalar + shape:(mps::getMPSShape(input_t))dataType:(mps::getMPSScalarType(input_t.scalar_type()))]; + + clamp_mps_graph(newCachedGraph, input_t); + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); } + + auto inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_t); + auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t); + + NSDictionary* feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; + + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + } } } // namespace mps // APIs exposed to at::native scope TORCH_IMPL_FUNC(clamp_Tensor_out_mps) -(const Tensor& input_t, const OptionalTensorRef min, const OptionalTensorRef max, const Tensor& output_t) -{ - mps::clamp_tensor_out_mps(input_t, min, max, output_t, __func__); +(const Tensor& input_t, const OptionalTensorRef min, const OptionalTensorRef max, const Tensor& output_t) { + mps::clamp_tensor_out_mps(input_t, min, max, output_t, __func__); } TORCH_IMPL_FUNC(clamp_out_mps) -(const Tensor& input_t, const OptionalScalarRef min, const OptionalScalarRef max, const Tensor& output_t) -{ - mps::clamp_scalar_out_mps(input_t, min, max, const_cast(output_t), "clamp_out_mps"); +(const Tensor& input_t, const OptionalScalarRef min, const OptionalScalarRef max, const Tensor& output_t) { + mps::clamp_scalar_out_mps(input_t, min, max, const_cast(output_t), "clamp_out_mps"); } TORCH_IMPL_FUNC(clamp_min_Tensor_out_mps) -(const Tensor& input_t, const Tensor& min, const Tensor& output_t) -{ - mps::clamp_tensor_out_mps(input_t, min, at::OptionalTensorRef(), output_t, __func__); +(const Tensor& input_t, const Tensor& min, const Tensor& output_t) { + mps::clamp_tensor_out_mps(input_t, min, at::OptionalTensorRef(), output_t, __func__); } TORCH_IMPL_FUNC(clamp_min_out_mps) -(const Tensor& input_t, const Scalar& min, const Tensor& output_t) -{ - mps::clamp_scalar_out_mps(input_t, min, at::OptionalScalarRef(), output_t, __func__); +(const Tensor& input_t, const Scalar& min, const Tensor& output_t) { + mps::clamp_scalar_out_mps(input_t, min, at::OptionalScalarRef(), output_t, __func__); } TORCH_IMPL_FUNC(clamp_max_Tensor_out_mps) -(const Tensor& input_t, const Tensor& max, const Tensor& output_t) -{ - mps::clamp_tensor_out_mps(input_t, at::OptionalTensorRef(), max, output_t, __func__); +(const Tensor& input_t, const Tensor& max, const Tensor& output_t) { + mps::clamp_tensor_out_mps(input_t, at::OptionalTensorRef(), max, output_t, __func__); } TORCH_IMPL_FUNC(clamp_max_out_mps) -(const Tensor& input_t, const Scalar& max, const Tensor& output_t) -{ - mps::clamp_scalar_out_mps(input_t, at::OptionalScalarRef(), max, output_t, __func__); +(const Tensor& input_t, const Scalar& max, const Tensor& output_t) { + mps::clamp_scalar_out_mps(input_t, at::OptionalScalarRef(), max, output_t, __func__); } -Tensor& where_self_out_mps(const Tensor& condition, - const Tensor& self, - const Tensor& other, - Tensor& out) { +Tensor& where_self_out_mps(const Tensor& condition, const Tensor& self, const Tensor& other, Tensor& out) { TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype()); if (condition.scalar_type() == ScalarType::Byte) { - TORCH_WARN_ONCE("where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead."); + TORCH_WARN_ONCE( + "where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead."); } else { - TORCH_CHECK(condition.scalar_type() == ScalarType::Bool, "where expected condition to be a boolean tensor, but got a tensor with dtype ", condition.scalar_type()); + TORCH_CHECK(condition.scalar_type() == ScalarType::Bool, + "where expected condition to be a boolean tensor, but got a tensor with dtype ", + condition.scalar_type()); } Tensor cond_bool = condition.scalar_type() == ScalarType::Byte ? condition.to(ScalarType::Bool) : condition; @@ -305,13 +283,12 @@ void clamp_scalar_out_mps(const Tensor& input_t, MPSStream* stream = getCurrentMPSStream(); // Empty output - if(out.numel() == 0) + if (out.numel() == 0) return out; // Derive from MPSCachedGraph - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* conditionTensor_ = nil; MPSGraphTensor* selfTensor_ = nil; MPSGraphTensor* otherTensor_ = nil; @@ -326,57 +303,56 @@ void clamp_scalar_out_mps(const Tensor& input_t, // Workaround for `selectWithPredicateTensor` on macOS Monterey where bool data type may cause a hang // The issue is fixed in macOS Ventura (13.0) if (!is_macos_13_or_newer()) { - if (condition.scalar_type() == kBool) { + if (condition.scalar_type() == kBool) { conditionDataType = MPSDataTypeInt8; - } - if (self.scalar_type() == kBool) { + } + if (self.scalar_type() == kBool) { selfDataType = MPSDataTypeInt8; - } - if (other.scalar_type() == kBool) { + } + if (other.scalar_type() == kBool) { otherDataType = MPSDataTypeInt8; - } + } } @autoreleasepool { - string key = "where_self_out_mps:" + getTensorsStringKey({cond_bool, self, other}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* conditionTensor = mpsGraphRankedPlaceHolder(mpsGraph, conditionDataType, getMPSShape(cond_bool)); - MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, selfDataType, getMPSShape(self)); - MPSGraphTensor* otherTensor = mpsGraphRankedPlaceHolder(mpsGraph, otherDataType, getMPSShape(other)); + MPSGraphTensor* conditionTensor = + mpsGraphRankedPlaceHolder(mpsGraph, conditionDataType, getMPSShape(cond_bool)); + MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, selfDataType, getMPSShape(self)); + MPSGraphTensor* otherTensor = mpsGraphRankedPlaceHolder(mpsGraph, otherDataType, getMPSShape(other)); - MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:conditionTensor - truePredicateTensor:selfTensor - falsePredicateTensor:otherTensor - name:nil]; + MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:conditionTensor + truePredicateTensor:selfTensor + falsePredicateTensor:otherTensor + name:nil]; - newCachedGraph->conditionTensor_ = conditionTensor; - newCachedGraph->selfTensor_ = selfTensor; - newCachedGraph->otherTensor_ = otherTensor; - newCachedGraph->outputTensor_ = outputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); + newCachedGraph->conditionTensor_ = conditionTensor; + newCachedGraph->selfTensor_ = selfTensor; + newCachedGraph->otherTensor_ = otherTensor; + newCachedGraph->outputTensor_ = outputTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder conditionPlaceholder = Placeholder( cachedGraph->conditionTensor_, cond_bool, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, conditionDataType); - Placeholder selfPlaceholder = Placeholder( - cachedGraph->selfTensor_, self, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, selfDataType); - Placeholder otherPlaceholder = Placeholder( - cachedGraph->otherTensor_, other, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, otherDataType); + Placeholder selfPlaceholder = + Placeholder(cachedGraph->selfTensor_, self, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, selfDataType); + Placeholder otherPlaceholder = + Placeholder(cachedGraph->otherTensor_, other, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, otherDataType); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out); NSDictionary* feeds = @{ @@ -384,21 +360,16 @@ void clamp_scalar_out_mps(const Tensor& input_t, selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData() }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } return out; } -Tensor where_mps(const Tensor& condition, - const Tensor& self, - const Tensor& other) { - +Tensor where_mps(const Tensor& condition, const Tensor& self, const Tensor& other) { auto max_dim = std::max(condition.dim(), std::max(self.dim(), other.dim())); // How many leading dimensions do we broadcast across for each Tensor? @@ -409,8 +380,7 @@ Tensor where_mps(const Tensor& condition, std::vector out_arr(max_dim); // Broadcasted output shape - for(int i = 0; i < max_dim; i++) { - + for (int i = 0; i < max_dim; i++) { // Use up the leading broadcast dimensions for each Tensor, then continue from the start of the "actual" shape int64_t cond_idx = i < cond_num_implicit_ones ? 1 : (condition.size(i - cond_num_implicit_ones)); int64_t self_idx = i < self_num_implicit_ones ? 1 : (self.size(i - self_num_implicit_ones)); @@ -418,21 +388,28 @@ Tensor where_mps(const Tensor& condition, auto max_idx = std::max({cond_idx, self_idx, other_idx}); - TORCH_CHECK(cond_idx == max_idx || cond_idx == 1 || (cond_idx == 0 && max_idx == 1), i, "'th index ", cond_idx, " of condition tensor does not match the other tensors") - TORCH_CHECK(self_idx == max_idx || self_idx == 1 || (self_idx == 0 && max_idx == 1), i, "'th index ", self_idx, " of x tensor does not match the other tensors") - TORCH_CHECK(other_idx == max_idx || other_idx == 1 || (other_idx == 0 && max_idx == 1), i, "'th index ", other_idx, " of x tensor does not match the other tensors") + TORCH_CHECK(cond_idx == max_idx || cond_idx == 1 || (cond_idx == 0 && max_idx == 1), + i, + "'th index ", + cond_idx, + " of condition tensor does not match the other tensors") + TORCH_CHECK(self_idx == max_idx || self_idx == 1 || (self_idx == 0 && max_idx == 1), + i, + "'th index ", + self_idx, + " of x tensor does not match the other tensors") + TORCH_CHECK(other_idx == max_idx || other_idx == 1 || (other_idx == 0 && max_idx == 1), + i, + "'th index ", + other_idx, + " of x tensor does not match the other tensors") out_arr[i] = (cond_idx == 0 || self_idx == 0 || other_idx == 0) ? 0 : max_idx; } - Tensor ret = empty_mps(IntArrayRef(out_arr), - self.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - self.suggest_memory_format()); + Tensor ret = empty_mps( + IntArrayRef(out_arr), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, self.suggest_memory_format()); return where_self_out_mps(condition, self, other, ret); - } Tensor& nan_to_num_out_mps(const Tensor& self, @@ -440,8 +417,11 @@ Tensor where_mps(const Tensor& condition, c10::optional pos_inf, c10::optional neg_inf, Tensor& result) { - TORCH_CHECK(self.scalar_type() == result.scalar_type(), "nan_to_num: dtype of out: ", - result.scalar_type(), " should be same as input: ", self.scalar_type()); + TORCH_CHECK(self.scalar_type() == result.scalar_type(), + "nan_to_num: dtype of out: ", + result.scalar_type(), + " should be same as input: ", + self.scalar_type()); if (result.numel() == 0) { return result; } @@ -452,7 +432,7 @@ Tensor where_mps(const Tensor& condition, } using namespace mps; struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* selfTensor = nil; MPSGraphTensor* outputTensor = nil; MPSGraphTensor* nanReplacementTensor = nil; @@ -467,25 +447,27 @@ Tensor where_mps(const Tensor& condition, CachedGraph* cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - newCachedGraph->nanReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[@1]); - newCachedGraph->posInfReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[@1]); - newCachedGraph->negInfReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[@1]); - - MPSGraphTensor* nanFreeTensor = [mpsGraph selectWithPredicateTensor: [mpsGraph isNaNWithTensor: newCachedGraph->selfTensor name:nil] - truePredicateTensor: newCachedGraph->nanReplacementTensor - falsePredicateTensor: newCachedGraph->selfTensor - name: nil]; - MPSGraphTensor* subZeroTensor = [mpsGraph lessThanWithPrimaryTensor: nanFreeTensor - secondaryTensor: [mpsGraph constantWithScalar: 0.0 dataType: self_dtype] - name: nil]; - MPSGraphTensor* isInfTensor = [mpsGraph isInfiniteWithTensor: nanFreeTensor name:nil]; + newCachedGraph->nanReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[ @1 ]); + newCachedGraph->posInfReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[ @1 ]); + newCachedGraph->negInfReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[ @1 ]); + + MPSGraphTensor* nanFreeTensor = + [mpsGraph selectWithPredicateTensor:[mpsGraph isNaNWithTensor:newCachedGraph->selfTensor name:nil] + truePredicateTensor:newCachedGraph->nanReplacementTensor + falsePredicateTensor:newCachedGraph->selfTensor + name:nil]; + MPSGraphTensor* subZeroTensor = [mpsGraph lessThanWithPrimaryTensor:nanFreeTensor + secondaryTensor:[mpsGraph constantWithScalar:0.0 + dataType:self_dtype] + name:nil]; + MPSGraphTensor* isInfTensor = [mpsGraph isInfiniteWithTensor:nanFreeTensor name:nil]; // workaround for Monterey; On Ventura the output of lessThan() is always Boolean if (subZeroTensor.dataType != MPSDataTypeBool) { subZeroTensor = castMPSTensor(mpsGraph, subZeroTensor, kBool); @@ -493,34 +475,33 @@ Tensor where_mps(const Tensor& condition, if (isInfTensor.dataType != MPSDataTypeBool) { isInfTensor = castMPSTensor(mpsGraph, isInfTensor, kBool); } - MPSGraphTensor* isNegInfTensor = [mpsGraph logicalANDWithPrimaryTensor: subZeroTensor - secondaryTensor: isInfTensor - name: nil]; - MPSGraphTensor* negInfFreeTensor = [mpsGraph selectWithPredicateTensor: isNegInfTensor - truePredicateTensor: newCachedGraph->negInfReplacementTensor - falsePredicateTensor: nanFreeTensor - name: nil]; - newCachedGraph->outputTensor = [mpsGraph selectWithPredicateTensor: [mpsGraph isInfiniteWithTensor: negInfFreeTensor name:nil] - truePredicateTensor: newCachedGraph->posInfReplacementTensor - falsePredicateTensor: negInfFreeTensor - name: nil]; + MPSGraphTensor* isNegInfTensor = [mpsGraph logicalANDWithPrimaryTensor:subZeroTensor + secondaryTensor:isInfTensor + name:nil]; + MPSGraphTensor* negInfFreeTensor = [mpsGraph selectWithPredicateTensor:isNegInfTensor + truePredicateTensor:newCachedGraph->negInfReplacementTensor + falsePredicateTensor:nanFreeTensor + name:nil]; + newCachedGraph->outputTensor = + [mpsGraph selectWithPredicateTensor:[mpsGraph isInfiniteWithTensor:negInfFreeTensor name:nil] + truePredicateTensor:newCachedGraph->posInfReplacementTensor + falsePredicateTensor:negInfFreeTensor + name:nil]; } return newCachedGraph; }); } MPSScalar nanReplacementScalar, posInfReplacementScalar, negInfReplacementScalar; AT_DISPATCH_FLOATING_TYPES_AND(kHalf, self.scalar_type(), "nan_to_num_mps", [&]() { - scalar_t nan_replacement = static_cast(nan.value_or(0.)); - scalar_t pos_inf_replacement = pos_inf.has_value() ? - static_cast(pos_inf.value()) : - std::numeric_limits::max(); - scalar_t neg_inf_replacement = neg_inf.has_value() ? - static_cast(neg_inf.value()) : - std::numeric_limits::lowest(); - - nanReplacementScalar = getMPSScalar(nan_replacement, self.scalar_type()); - posInfReplacementScalar = getMPSScalar(pos_inf_replacement, self.scalar_type()); - negInfReplacementScalar = getMPSScalar(neg_inf_replacement, self.scalar_type()); + scalar_t nan_replacement = static_cast(nan.value_or(0.)); + scalar_t pos_inf_replacement = + pos_inf.has_value() ? static_cast(pos_inf.value()) : std::numeric_limits::max(); + scalar_t neg_inf_replacement = + neg_inf.has_value() ? static_cast(neg_inf.value()) : std::numeric_limits::lowest(); + + nanReplacementScalar = getMPSScalar(nan_replacement, self.scalar_type()); + posInfReplacementScalar = getMPSScalar(pos_inf_replacement, self.scalar_type()); + negInfReplacementScalar = getMPSScalar(neg_inf_replacement, self.scalar_type()); }); MPSStream* stream = getCurrentMPSStream(); @@ -528,14 +509,13 @@ Tensor where_mps(const Tensor& condition, Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, result); NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - cachedGraph->nanReplacementTensor : getMPSGraphTensorFromScalar(stream, nanReplacementScalar), + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), + cachedGraph->nanReplacementTensor : getMPSGraphTensorFromScalar(stream, nanReplacementScalar), cachedGraph->posInfReplacementTensor : getMPSGraphTensorFromScalar(stream, posInfReplacementScalar), cachedGraph->negInfReplacementTensor : getMPSGraphTensorFromScalar(stream, negInfReplacementScalar), }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } return result; diff --git a/aten/src/ATen/native/mps/operations/TriangularOps.mm b/aten/src/ATen/native/mps/operations/TriangularOps.mm index a4b0db98b0fc7f..921ba8dec74fa8 100644 --- a/aten/src/ATen/native/mps/operations/TriangularOps.mm +++ b/aten/src/ATen/native/mps/operations/TriangularOps.mm @@ -14,10 +14,7 @@ namespace at::native { TORCH_IMPL_FUNC(triu_mps_out) -(const Tensor& self, - int64_t k, - const Tensor &output) { - +(const Tensor& self, int64_t k, const Tensor& output) { using namespace mps; if (self.numel() == 0) { @@ -26,22 +23,21 @@ MPSStream* stream = getCurrentMPSStream(); // Derive from MPSCachedGraph - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { string key = "triu_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -50,12 +46,10 @@ MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); MPSGraphTensor* outputTensor = nil; - MPSGraphTensor* minusOneTensor = [mpsGraph constantWithScalar:-1 - dataType:MPSDataTypeInt32]; + MPSGraphTensor* minusOneTensor = [mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt32]; - if(k > 0) { - MPSGraphTensor* diagMinusOneTensor = [mpsGraph constantWithScalar:(k-1) - dataType:MPSDataTypeInt32]; + if (k > 0) { + MPSGraphTensor* diagMinusOneTensor = [mpsGraph constantWithScalar:(k - 1) dataType:MPSDataTypeInt32]; MPSGraphTensor* complementTensor = [mpsGraph bandPartWithTensor:inputTensor numLowerTensor:minusOneTensor numUpperTensor:diagMinusOneTensor @@ -63,10 +57,8 @@ outputTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor secondaryTensor:complementTensor name:nil]; - } - else { - MPSGraphTensor* minusDiagTensor = [mpsGraph constantWithScalar:(-k) - dataType:MPSDataTypeInt32]; + } else { + MPSGraphTensor* minusDiagTensor = [mpsGraph constantWithScalar:(-k) dataType:MPSDataTypeInt32]; outputTensor = [mpsGraph bandPartWithTensor:inputTensor numLowerTensor:minusDiagTensor numUpperTensor:minusOneTensor @@ -78,29 +70,23 @@ } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } - } TORCH_IMPL_FUNC(tril_mps_out) -(const Tensor& self, - int64_t k, - const Tensor &output) { - +(const Tensor& self, int64_t k, const Tensor& output) { using namespace mps; if (self.numel() == 0) { @@ -109,22 +95,21 @@ MPSStream* stream = getCurrentMPSStream(); // Derive from MPSCachedGraph - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { string key = "tril_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -133,20 +118,16 @@ MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); MPSGraphTensor* outputTensor = nil; - MPSGraphTensor* minusOneTensor = [mpsGraph constantWithScalar:-1 - dataType:MPSDataTypeInt32]; + MPSGraphTensor* minusOneTensor = [mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt32]; - if(k >= 0) { - MPSGraphTensor* diagTensor = [mpsGraph constantWithScalar:k - dataType:MPSDataTypeInt32]; + if (k >= 0) { + MPSGraphTensor* diagTensor = [mpsGraph constantWithScalar:k dataType:MPSDataTypeInt32]; outputTensor = [mpsGraph bandPartWithTensor:inputTensor numLowerTensor:minusOneTensor numUpperTensor:diagTensor name:nil]; - } - else { - MPSGraphTensor* negDiagMinusOneTensor = [mpsGraph constantWithScalar:(-k-1) - dataType:MPSDataTypeInt32]; + } else { + MPSGraphTensor* negDiagMinusOneTensor = [mpsGraph constantWithScalar:(-k - 1) dataType:MPSDataTypeInt32]; MPSGraphTensor* complementTensor = [mpsGraph bandPartWithTensor:inputTensor numLowerTensor:negDiagMinusOneTensor numUpperTensor:minusOneTensor @@ -161,22 +142,19 @@ } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } - } } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm index 444aaf2c7ec16d..b5f3976a725624 100644 --- a/aten/src/ATen/native/mps/operations/UnaryOps.mm +++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm @@ -1,7 +1,7 @@ // Copyright © 2022 Apple Inc. -#include #include +#include namespace at::native { namespace mps { @@ -9,14 +9,16 @@ typedef MPSGraphTensor* (^UnaryOpBlock)(MPSGraph*, MPSGraphTensor*); using is_noop_p = std::function; - bool is_empty_tensor(const Tensor& self) { return self.numel() == 0; } -void unary_op(const Tensor& self, const Tensor& output, std::string op_name, UnaryOpBlock unaryBlock, is_noop_p is_noop = is_empty_tensor) -{ - TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte ), +void unary_op(const Tensor& self, + const Tensor& output, + std::string op_name, + UnaryOpBlock unaryBlock, + is_noop_p is_noop = is_empty_tensor) { + TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte), "MPS support unary op with uint8 natively starting from macOS 13.0"); if (!output.is_same_size(self)) { output.resize_(self.sizes()); @@ -30,9 +32,9 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una string key = op_name + getTensorsStringKey({self, output}); auto cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph* () { - MPSUnaryCachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + MPSUnaryCachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new MPSUnaryCachedGraph(mpsGraph); @@ -55,18 +57,15 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, /*mpsShape=*/nullptr, gatherTensorData); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, /*mpsShape=*/nullptr, false); - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* feeds = + @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()}; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); } } -MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) -{ +MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { // Rounding is a no-op for integral types, and also a reasonable workaround // For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library` // See https://github.com/pytorch/pytorch/issues/84995 @@ -75,100 +74,91 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una return inputTensor; } - if(!is_macos_13_or_newer()) { - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 - dataType:inputTensor.dataType]; + if (!is_macos_13_or_newer()) { + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType]; MPSGraphTensor* predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor secondaryTensor:zeroTensor - name:nil]; + name:nil]; return [mpsGraph selectWithPredicateTensor:predicateTensor - truePredicateTensor:[mpsGraph ceilWithTensor :inputTensor name:nil] + truePredicateTensor:[mpsGraph ceilWithTensor:inputTensor name:nil] falsePredicateTensor:[mpsGraph floorWithTensor:inputTensor name:nil] name:nil]; } else { - return [mpsGraph truncateWithTensor:inputTensor - name:nil]; + return [mpsGraph truncateWithTensor:inputTensor name:nil]; } }; MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 - dataType:inputTensor.dataType]; - MPSGraphTensor* addedTensor = [mpsGraph additionWithPrimaryTensor:inputTensor - secondaryTensor:oneTensor - name:nil]; - return [mpsGraph logarithmWithTensor:addedTensor - name:nil]; + MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 dataType:inputTensor.dataType]; + MPSGraphTensor* addedTensor = [mpsGraph additionWithPrimaryTensor:inputTensor secondaryTensor:oneTensor name:nil]; + return [mpsGraph logarithmWithTensor:addedTensor name:nil]; } } // namespace mps -TORCH_IMPL_FUNC(trunc_out_mps) (const Tensor& self, const Tensor& output) { - mps::unary_op(self, output, "trunc_out_mps", - ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) - { return mps::trunc_tensor(mpsGraph, inputTensor); }); +TORCH_IMPL_FUNC(trunc_out_mps)(const Tensor& self, const Tensor& output) { + mps::unary_op(self, output, "trunc_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + return mps::trunc_tensor(mpsGraph, inputTensor); + }); } -TORCH_IMPL_FUNC(signbit_out_mps) (const Tensor& self, const Tensor& output) { - mps::unary_op(self, output, "signbit_out_mps", - ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - MPSGraphTensor* output; - // signbit is not implemented for int64 type. - // workaround for `Function signbitOp_i64 was not found in the library` - if ([inputTensor dataType] == MPSDataTypeInt64) { - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType]; - output = [mpsGraph lessThanWithPrimaryTensor:inputTensor - secondaryTensor:zeroTensor - name:nil]; - } else { - output = [mpsGraph signbitWithTensor: inputTensor name: nil]; - } - return mps::castMPSTensor(mpsGraph, output, ScalarType::Bool); - }); +TORCH_IMPL_FUNC(signbit_out_mps)(const Tensor& self, const Tensor& output) { + mps::unary_op(self, output, "signbit_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + MPSGraphTensor* output; + // signbit is not implemented for int64 type. + // workaround for `Function signbitOp_i64 was not found in the library` + if ([inputTensor dataType] == MPSDataTypeInt64) { + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType]; + output = [mpsGraph lessThanWithPrimaryTensor:inputTensor secondaryTensor:zeroTensor name:nil]; + } else { + output = [mpsGraph signbitWithTensor:inputTensor name:nil]; + } + return mps::castMPSTensor(mpsGraph, output, ScalarType::Bool); + }); } -TORCH_IMPL_FUNC(sign_out_mps) (const Tensor& self, const Tensor& output) { - mps::unary_op(self, output, "sign_out_mps", - ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - // Sign op is not implemented in MPS as of MacOS13.0 beta, so simulate it using clamp - if ([inputTensor dataType] == MPSDataTypeInt64) { - return [mpsGraph clampWithTensor:inputTensor - minValueTensor:[mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt64] - maxValueTensor:[mpsGraph constantWithScalar:1 dataType:MPSDataTypeInt64] - name: nil]; - } - return [mpsGraph signWithTensor: inputTensor name: nil]; - }); +TORCH_IMPL_FUNC(sign_out_mps)(const Tensor& self, const Tensor& output) { + mps::unary_op(self, output, "sign_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + // Sign op is not implemented in MPS as of MacOS13.0 beta, so simulate it using clamp + if ([inputTensor dataType] == MPSDataTypeInt64) { + return [mpsGraph clampWithTensor:inputTensor + minValueTensor:[mpsGraph constantWithScalar:-1 dataType:MPSDataTypeInt64] + maxValueTensor:[mpsGraph constantWithScalar:1 dataType:MPSDataTypeInt64] + name:nil]; + } + return [mpsGraph signWithTensor:inputTensor name:nil]; + }); } -#define CREATE_MPS_STRUCTURED_UNARY_ROUNDING_TORCH_IMPL_FUNC(func_out, func_stub) \ -TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& output) { \ - mps::unary_op(self, output, #func_out, \ - ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) \ - { return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; }, \ - [](const Tensor& t) -> bool { \ - return t.numel() == 0 || isIntegralType(t.scalar_type(), true); \ - }); \ -} +#define CREATE_MPS_STRUCTURED_UNARY_ROUNDING_TORCH_IMPL_FUNC(func_out, func_stub) \ + TORCH_IMPL_FUNC(func_out)(const Tensor& self, const Tensor& output) { \ + mps::unary_op( \ + self, \ + output, \ + #func_out, \ + ^MPSGraphTensor*(MPSGraph * mpsGraph, MPSGraphTensor * inputTensor) { \ + return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; \ + }, \ + [](const Tensor& t) -> bool { return t.numel() == 0 || isIntegralType(t.scalar_type(), true); }); \ + } CREATE_MPS_STRUCTURED_UNARY_ROUNDING_TORCH_IMPL_FUNC(ceil_out_mps, ceil) CREATE_MPS_STRUCTURED_UNARY_ROUNDING_TORCH_IMPL_FUNC(floor_out_mps, floor) CREATE_MPS_STRUCTURED_UNARY_ROUNDING_TORCH_IMPL_FUNC(round_out_mps, round) -#define CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(func_out, func_stub) \ -TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& output) { \ - mps::unary_op(self, output, #func_out, \ - ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) \ - { return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; }); \ -} - -#define CREATE_MPS_UNARY_TORCH_IMPL_FUNC(func_out, func_stub) \ -Tensor& func_out(const Tensor& self, Tensor& output) { \ - mps::unary_op(self, output, #func_out, \ - ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) \ - { return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; }); \ - return output; \ -} +#define CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(func_out, func_stub) \ + TORCH_IMPL_FUNC(func_out)(const Tensor& self, const Tensor& output) { \ + mps::unary_op(self, output, #func_out, ^MPSGraphTensor*(MPSGraph * mpsGraph, MPSGraphTensor * inputTensor) { \ + return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; \ + }); \ + } +#define CREATE_MPS_UNARY_TORCH_IMPL_FUNC(func_out, func_stub) \ + Tensor& func_out(const Tensor& self, Tensor& output) { \ + mps::unary_op(self, output, #func_out, ^MPSGraphTensor*(MPSGraph * mpsGraph, MPSGraphTensor * inputTensor) { \ + return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; \ + }); \ + return output; \ + } CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp_out_mps, exponent) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp2_out_mps, exponentBase2) @@ -195,139 +185,104 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una CREATE_MPS_UNARY_TORCH_IMPL_FUNC(abs_out_mps, absolute) -Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) -{ +Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) { auto bool_self = self.to(ScalarType::Bool); - mps::unary_op(bool_self, output, "logical_not_out_mps", [](MPSGraph* mpsGraph, MPSGraphTensor* inputTensor){ return [mpsGraph notWithTensor:inputTensor name:nil];}); + mps::unary_op(bool_self, output, "logical_not_out_mps", [](MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + return [mpsGraph notWithTensor:inputTensor name:nil]; + }); return output; } -TORCH_IMPL_FUNC(sigmoid_out_mps) (const Tensor& self, const Tensor& output) -{ +TORCH_IMPL_FUNC(sigmoid_out_mps)(const Tensor& self, const Tensor& output) { TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support sigmoid op with int64 input"); - mps::unary_op(self, output, "sigmoid_out_mps", - ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - return [mpsGraph sigmoidWithTensor:inputTensor name:nil]; - }); + mps::unary_op(self, output, "sigmoid_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + return [mpsGraph sigmoidWithTensor:inputTensor name:nil]; + }); } -TORCH_IMPL_FUNC(log1p_out_mps) (const Tensor& self, const Tensor& output) -{ +TORCH_IMPL_FUNC(log1p_out_mps)(const Tensor& self, const Tensor& output) { TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support log1p op with int64 input"); - mps::unary_op(self, output, "log1p_out_mps", - ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - return mps::log1p(mpsGraph, inputTensor); - }); + mps::unary_op(self, output, "log1p_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + return mps::log1p(mpsGraph, inputTensor); + }); } -TORCH_IMPL_FUNC(frac_out_mps) (const Tensor& self, const Tensor& output) { +TORCH_IMPL_FUNC(frac_out_mps)(const Tensor& self, const Tensor& output) { TORCH_CHECK(isFloatingType(self.scalar_type()), "frac_out_mps is only implemented for floating types"); - mps::unary_op(self, output, "frac_out_mps", - ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - auto zeroTensor = [mpsGraph constantWithScalar:0.0 - dataType:inputTensor.dataType]; - auto predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor - secondaryTensor:zeroTensor - name:nil]; - auto truncTensor = [mpsGraph selectWithPredicateTensor:predicateTensor - truePredicateTensor:[mpsGraph ceilWithTensor :inputTensor name:nil] - falsePredicateTensor:[mpsGraph floorWithTensor:inputTensor name:nil] - name:nil]; - return [mpsGraph subtractionWithPrimaryTensor:inputTensor - secondaryTensor:truncTensor - name: nil]; - }); + mps::unary_op(self, output, "frac_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + auto zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType]; + auto predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor secondaryTensor:zeroTensor name:nil]; + auto truncTensor = [mpsGraph selectWithPredicateTensor:predicateTensor + truePredicateTensor:[mpsGraph ceilWithTensor:inputTensor name:nil] + falsePredicateTensor:[mpsGraph floorWithTensor:inputTensor name:nil] + name:nil]; + return [mpsGraph subtractionWithPrimaryTensor:inputTensor secondaryTensor:truncTensor name:nil]; + }); } -TORCH_IMPL_FUNC(expm1_out_mps) (const Tensor& self, const Tensor& output) { - mps::unary_op(self, output, "expm1_out_mps", - ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 - shape:@[@1] - dataType:inputTensor.dataType]; - MPSGraphTensor* ePowTensor = [mpsGraph exponentWithTensor:inputTensor - name:nil]; - return [mpsGraph subtractionWithPrimaryTensor:ePowTensor - secondaryTensor:oneTensor - name: nil]; - }); +TORCH_IMPL_FUNC(expm1_out_mps)(const Tensor& self, const Tensor& output) { + mps::unary_op(self, output, "expm1_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:inputTensor.dataType]; + MPSGraphTensor* ePowTensor = [mpsGraph exponentWithTensor:inputTensor name:nil]; + return [mpsGraph subtractionWithPrimaryTensor:ePowTensor secondaryTensor:oneTensor name:nil]; + }); } void logit_mps_impl(const Tensor& self, c10::optional eps, Tensor& output, const std::string op_name) { std::string key = op_name + ":[" + (eps.has_value() ? std::to_string(eps.value()) : "NULL") + "]"; - mps::unary_op(self, output, key, - ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 - shape:@[@1] - dataType:inputTensor.dataType]; - MPSGraphTensor* logitInputTensor; - - if (eps.has_value()) { - MPSGraphTensor *lowTensor = [mpsGraph constantWithScalar:eps.value() - shape:@[@1] - dataType:inputTensor.dataType]; - MPSGraphTensor *highTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor - secondaryTensor: lowTensor - name: nil]; - logitInputTensor = [mpsGraph clampWithTensor:inputTensor - minValueTensor:lowTensor - maxValueTensor:highTensor - name:nil]; - } else { - logitInputTensor = inputTensor; - } + mps::unary_op(self, output, key, ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:inputTensor.dataType]; + MPSGraphTensor* logitInputTensor; + + if (eps.has_value()) { + MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps.value() shape:@[ @1 ] dataType:inputTensor.dataType]; + MPSGraphTensor* highTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor secondaryTensor:lowTensor name:nil]; + logitInputTensor = [mpsGraph clampWithTensor:inputTensor + minValueTensor:lowTensor + maxValueTensor:highTensor + name:nil]; + } else { + logitInputTensor = inputTensor; + } - MPSGraphTensor *oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor - secondaryTensor: logitInputTensor - name: nil]; - MPSGraphTensor *outputTensor = [mpsGraph divisionWithPrimaryTensor:logitInputTensor - secondaryTensor:oneMinusInputTensor - name:nil]; - return [mpsGraph logarithmWithTensor:outputTensor - name:nil]; - }); + MPSGraphTensor* oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor + secondaryTensor:logitInputTensor + name:nil]; + MPSGraphTensor* outputTensor = [mpsGraph divisionWithPrimaryTensor:logitInputTensor + secondaryTensor:oneMinusInputTensor + name:nil]; + return [mpsGraph logarithmWithTensor:outputTensor name:nil]; + }); } -Tensor& logit_out_mps(const Tensor& self, - c10::optional eps, - Tensor& result) { +Tensor& logit_out_mps(const Tensor& self, c10::optional eps, Tensor& result) { logit_mps_impl(self, eps, result, "logit_out_mps"); return result; } Tensor logit_mps(const Tensor& self, c10::optional eps) { - Tensor result = at::native::empty_mps( - self.sizes(), - ScalarType::Float, - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); + Tensor result = + at::native::empty_mps(self.sizes(), ScalarType::Float, c10::nullopt, kMPS, c10::nullopt, c10::nullopt); logit_mps_impl(self, eps, result, "logit_mps"); return result; } -TORCH_IMPL_FUNC(logit_backward_out_mps) ( - const Tensor& grad_output, - const Tensor& input, - c10::optional eps, - const Tensor& grad_input) - { +TORCH_IMPL_FUNC(logit_backward_out_mps) +(const Tensor& grad_output, const Tensor& input, c10::optional eps, const Tensor& grad_input) { using namespace mps; // Empty output - if(grad_input.numel() == 0) + if (grad_input.numel() == 0) return; double eps_ = eps ? eps.value() : -1.0; - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @@ -335,14 +290,13 @@ Tensor logit_mps(const Tensor& self, c10::optional eps) { MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - std::string key = "logit_backward_out_mps:" + getTensorsStringKey({grad_output, input}) + ":" + - "[" + (eps.has_value() ? std::to_string(eps.value()) : "-1" ) + "]"; - - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + std::string key = "logit_backward_out_mps:" + getTensorsStringKey({grad_output, input}) + ":" + "[" + + (eps.has_value() ? std::to_string(eps.value()) : "-1") + "]"; - CachedGraph *newCachedGraph = nil; + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -351,40 +305,32 @@ Tensor logit_mps(const Tensor& self, c10::optional eps) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); MPSGraphTensor* outputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_input); - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 - shape:@[@1] - dataType:inputTensor.dataType]; - MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 - shape:@[@1] - dataType:inputTensor.dataType]; - MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps_ - shape:@[@1] - dataType:inputTensor.dataType]; - MPSGraphTensor *inputLessThanLowPredicateTensor = [mpsGraph lessThanWithPrimaryTensor: inputTensor - secondaryTensor: lowTensor - name: nil]; - MPSGraphTensor *highTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor - secondaryTensor: lowTensor - name: nil]; - MPSGraphTensor *inputGreaterThanHighPredicateTensor = [mpsGraph greaterThanWithPrimaryTensor: inputTensor - secondaryTensor: highTensor - name: nil]; - MPSGraphTensor* outOfIntervalTensor = [mpsGraph logicalORWithPrimaryTensor: inputLessThanLowPredicateTensor - secondaryTensor: inputGreaterThanHighPredicateTensor - name: nil]; - MPSGraphTensor *oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor: oneTensor - secondaryTensor: inputTensor - name: nil]; + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:inputTensor.dataType]; + MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:inputTensor.dataType]; + MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps_ shape:@[ @1 ] dataType:inputTensor.dataType]; + MPSGraphTensor* inputLessThanLowPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor + secondaryTensor:lowTensor + name:nil]; + MPSGraphTensor* highTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor + secondaryTensor:lowTensor + name:nil]; + MPSGraphTensor* inputGreaterThanHighPredicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor + secondaryTensor:highTensor + name:nil]; + MPSGraphTensor* outOfIntervalTensor = [mpsGraph logicalORWithPrimaryTensor:inputLessThanLowPredicateTensor + secondaryTensor:inputGreaterThanHighPredicateTensor + name:nil]; + MPSGraphTensor* oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor + secondaryTensor:inputTensor + name:nil]; outputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor secondaryTensor:oneMinusInputTensor name:nil]; - outputTensor = [mpsGraph divisionWithPrimaryTensor:gradOutputTensor - secondaryTensor:outputTensor + outputTensor = [mpsGraph divisionWithPrimaryTensor:gradOutputTensor secondaryTensor:outputTensor name:nil]; + outputTensor = [mpsGraph selectWithPredicateTensor:outOfIntervalTensor + truePredicateTensor:zeroTensor + falsePredicateTensor:outputTensor name:nil]; - outputTensor = [mpsGraph selectWithPredicateTensor: outOfIntervalTensor - truePredicateTensor: zeroTensor - falsePredicateTensor: outputTensor - name: nil]; newCachedGraph->gradOutputTensor_ = gradOutputTensor; newCachedGraph->inputTensor_ = inputTensor; @@ -392,7 +338,7 @@ Tensor logit_mps(const Tensor& self, c10::optional eps) { } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); @@ -403,25 +349,25 @@ Tensor logit_mps(const Tensor& self, c10::optional eps) { gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), }; - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary* results = + @{gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } } - - TORCH_IMPL_FUNC(cumsum_out_mps) -(const Tensor& self, - int64_t dim, - c10::optional dtype, - const Tensor& result) { - +(const Tensor& self, int64_t dim, c10::optional dtype, const Tensor& result) { bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); auto nDims = self.dim(); auto wrapped_dim = maybe_wrap_dim(dim, nDims); - TORCH_CHECK(wrapped_dim >=0 && wrapped_dim < std::max(1LL, self.ndimension()), "Expected wrapped dim to be between 0 and ", self.ndimension(), " but got ", wrapped_dim , "(original dim is ", dim, ")"); + TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < std::max(1LL, self.ndimension()), + "Expected wrapped dim to be between 0 and ", + self.ndimension(), + " but got ", + wrapped_dim, + "(original dim is ", + dim, + ")"); if (!is_macos_13_or_newer()) { TORCH_WARN_ONCE("torch.cumsum supported by MPS on MacOS 13+, please upgrade"); auto cpu_result = self.to(at::Device(kCPU)).cumsum(dim, dtype); @@ -430,29 +376,27 @@ Tensor logit_mps(const Tensor& self, c10::optional eps) { } auto input = dtype.has_value() ? self.to(dtype.value()) : self; - // issue #103810551: cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to int32 - // fixed in macOS 13.3 - bool castInputData = (isIntegralType(input.scalar_type()) && - input.scalar_type() != ScalarType::Int && + // issue #103810551: cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to + // int32 fixed in macOS 13.3 + bool castInputData = (isIntegralType(input.scalar_type()) && input.scalar_type() != ScalarType::Int && input.scalar_type() != ScalarType::Long); TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long, "MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3"); - mps::unary_op(input, result, "cumsum_out_mp" + std::to_string(dim), - ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - - if (castInputData) { - inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int); - } - auto rc = [mpsGraph cumulativeSumWithTensor: inputTensor - axis: dim - name: nil]; - if ((mps::getMPSDataType(result) != [rc dataType]) || castInputData) { - return mps::castMPSTensor(mpsGraph, rc, result.scalar_type()); - } - return rc; - }); + mps::unary_op(input, + result, + "cumsum_out_mp" + std::to_string(dim), + ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + if (castInputData) { + inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int); + } + auto rc = [mpsGraph cumulativeSumWithTensor:inputTensor axis:dim name:nil]; + if ((mps::getMPSDataType(result) != [rc dataType]) || castInputData) { + return mps::castMPSTensor(mpsGraph, rc, result.scalar_type()); + } + return rc; + }); } } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/Unique.mm b/aten/src/ATen/native/mps/operations/Unique.mm index eac16a74564ee9..3927cb1eb7e2b8 100644 --- a/aten/src/ATen/native/mps/operations/Unique.mm +++ b/aten/src/ATen/native/mps/operations/Unique.mm @@ -1,15 +1,14 @@ // Copyright © 2022 Apple Inc. -#include -#include #include +#include +#include namespace at::native { namespace mps { -struct UniqueCachedGraph : public MPSCachedGraph -{ - UniqueCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} +struct UniqueCachedGraph : public MPSCachedGraph { + UniqueCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* outputTensor_ = nil; MPSGraphTensor* inverseIndicesTensor_ = nil; @@ -17,230 +16,201 @@ MPSGraphTensor* lengthTensor_ = nil; }; -static std::string getUniqueKey(const ScalarType& dtype, const IntArrayRef& base_shape, - const bool return_inverse, const bool return_counts, - const bool consecutive, c10::optional dimOpt) -{ - return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) + - "]:[" + (dimOpt.has_value() ? to_string(dimOpt.value()) : "None") + "]:[" + to_string(return_inverse) + - "]:[" + to_string(return_counts) + "]:[" + to_string(consecutive) + "]"; +static std::string getUniqueKey(const ScalarType& dtype, + const IntArrayRef& base_shape, + const bool return_inverse, + const bool return_counts, + const bool consecutive, + c10::optional dimOpt) { + return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) + "]:[" + + (dimOpt.has_value() ? to_string(dimOpt.value()) : "None") + "]:[" + to_string(return_inverse) + "]:[" + + to_string(return_counts) + "]:[" + to_string(consecutive) + "]"; } // dim arg not supported when non consecutive, ie sorted -std::array buildUniqueGraph(const Tensor& self, UniqueCachedGraph *uniqueGraph, const bool return_inverse, const bool return_counts, const bool consecutive, c10::optional dimOpt) { +std::array buildUniqueGraph(const Tensor& self, + UniqueCachedGraph* uniqueGraph, + const bool return_inverse, + const bool return_counts, + const bool consecutive, + c10::optional dimOpt) { int64_t dim = dimOpt.has_value() ? maybe_wrap_dim(dimOpt.value(), self.dim()) : 0; - MPSGraph *graph = uniqueGraph->graph(); - MPSGraphTensor *inputTensor = uniqueGraph->inputTensor_; - MPSShape *shape = [inputTensor shape]; - MPSShape *destShape = shape; + MPSGraph* graph = uniqueGraph->graph(); + MPSGraphTensor* inputTensor = uniqueGraph->inputTensor_; + MPSShape* shape = [inputTensor shape]; + MPSShape* destShape = shape; uint64_t length = [shape[dim] unsignedIntValue]; MPSDataType dataType = [inputTensor dataType]; - MPSGraphTensor *resultTensor = nil; - MPSGraphTensor *inverseIndicesTensor = nil; - MPSGraphTensor *countTensor = nil; - MPSGraphTensor *lengthTensor = nil; + MPSGraphTensor* resultTensor = nil; + MPSGraphTensor* inverseIndicesTensor = nil; + MPSGraphTensor* countTensor = nil; + MPSGraphTensor* lengthTensor = nil; if (length <= 1) { // Trivial case, only 1 element everything is unique resultTensor = inputTensor; - lengthTensor = [graph constantWithScalar:0.0f - dataType:MPSDataTypeInt32]; + lengthTensor = [graph constantWithScalar:0.0f dataType:MPSDataTypeInt32]; if (return_inverse) { - inverseIndicesTensor = [graph constantWithScalar:0.0f - dataType:MPSDataTypeInt32]; + inverseIndicesTensor = [graph constantWithScalar:0.0f dataType:MPSDataTypeInt32]; } if (return_counts) { - countTensor = [graph constantWithScalar:1.0f - dataType:MPSDataTypeInt32]; + countTensor = [graph constantWithScalar:1.0f dataType:MPSDataTypeInt32]; } return {resultTensor, inverseIndicesTensor, countTensor, lengthTensor}; } // #issue 104398441 sortWithTensor only supports following types, cast if necessary - if (dataType != MPSDataTypeInt32 && - dataType != MPSDataTypeFloat32 && - dataType != MPSDataTypeFloat16) { + if (dataType != MPSDataTypeInt32 && dataType != MPSDataTypeFloat32 && dataType != MPSDataTypeFloat16) { dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32; - inputTensor = [graph castTensor:inputTensor - toType:dataType - name:@"castInputTensor"]; + inputTensor = [graph castTensor:inputTensor toType:dataType name:@"castInputTensor"]; } bool needsFlatten = !(dimOpt.has_value() || [shape count] == 1); if (needsFlatten) { - inputTensor = [graph reshapeTensor:inputTensor - withShape:@[@-1] - name:nil]; + inputTensor = [graph reshapeTensor:inputTensor withShape:@[ @-1 ] name:nil]; length = 1; - for (const auto i: c10::irange([shape count])) { + for (const auto i : c10::irange([shape count])) { if (c10::mul_overflows(length, [shape[i] unsignedIntValue], &length)) { TORCH_CHECK(false, "RuntimeError: Tensor size overflow"); } } - destShape = @[[NSNumber numberWithUnsignedInteger:length]]; + destShape = @[ [NSNumber numberWithUnsignedInteger:length] ]; } - MPSGraphTensor *sortedInput = nil; + MPSGraphTensor* sortedInput = nil; if (consecutive) { sortedInput = inputTensor; } else { - sortedInput = [graph sortWithTensor:inputTensor - axis:0 - name:nil]; + sortedInput = [graph sortWithTensor:inputTensor axis:0 name:nil]; } - MPSGraphTensor *frontNMinusOne = [graph sliceTensor:sortedInput - dimension:dim - start:0 - length:length-1 - name:nil]; - MPSGraphTensor *backNMinusOne = [graph sliceTensor:sortedInput - dimension:dim - start:1 - length:length-1 - name:nil]; - MPSGraphTensor *notEqualToPreviousElement = [graph notEqualWithPrimaryTensor:backNMinusOne + MPSGraphTensor* frontNMinusOne = [graph sliceTensor:sortedInput dimension:dim start:0 length:length - 1 name:nil]; + MPSGraphTensor* backNMinusOne = [graph sliceTensor:sortedInput dimension:dim start:1 length:length - 1 name:nil]; + MPSGraphTensor* notEqualToPreviousElement = [graph notEqualWithPrimaryTensor:backNMinusOne secondaryTensor:frontNMinusOne name:nil]; - MPSGraphTensor *mask = [graph castTensor:notEqualToPreviousElement - toType:MPSDataTypeInt32 - name:@"castMaskTensor"]; + MPSGraphTensor* mask = [graph castTensor:notEqualToPreviousElement toType:MPSDataTypeInt32 name:@"castMaskTensor"]; // If comparing tensors, not scalars, check if entire tensor matches previos element using reductionOr over tensor if (dimOpt.has_value() && [shape count] != 1) { - NSMutableArray *axes = [[NSMutableArray alloc] initWithCapacity:[shape count]-1]; + NSMutableArray* axes = [[NSMutableArray alloc] initWithCapacity:[shape count] - 1]; for (const auto axis : c10::irange([shape count])) { if (axis != dim) { [axes addObject:[NSNumber numberWithUnsignedInteger:axis]]; } } - mask = [graph reductionOrWithTensor:mask - axes:axes - name:nil]; - mask = [graph squeezeTensor:mask - axes:axes - name:nil]; + mask = [graph reductionOrWithTensor:mask axes:axes name:nil]; + mask = [graph squeezeTensor:mask axes:axes name:nil]; [axes release]; } - MPSGraphTensor *scannedIndices = [graph cumulativeSumWithTensor:mask - axis:0 - name:nil]; - lengthTensor = [graph sliceTensor:scannedIndices - dimension:0 - start:length-2 - length:1 - name:nil]; - - MPSGraphTensor *minusOneTensor = [graph constantWithScalar:-1.0f - dataType:MPSDataTypeInt32]; - MPSGraphTensor *maskedIndices = [graph selectWithPredicateTensor:mask + MPSGraphTensor* scannedIndices = [graph cumulativeSumWithTensor:mask axis:0 name:nil]; + lengthTensor = [graph sliceTensor:scannedIndices dimension:0 start:length - 2 length:1 name:nil]; + + MPSGraphTensor* minusOneTensor = [graph constantWithScalar:-1.0f dataType:MPSDataTypeInt32]; + MPSGraphTensor* maskedIndices = [graph selectWithPredicateTensor:mask truePredicateTensor:scannedIndices falsePredicateTensor:minusOneTensor name:nil]; - MPSGraphTensor *zeroTensor = [graph constantWithScalar:0.0f - shape:@[@1] - dataType:MPSDataTypeInt32]; - MPSGraphTensor *maskedIndicesWithHead = [graph concatTensors:@[zeroTensor, maskedIndices] - dimension:0 - name:nil]; - MPSGraphTensor *scannedIndicesWithHead = [graph concatTensors:@[zeroTensor, scannedIndices] - dimension:0 - name:nil]; + MPSGraphTensor* zeroTensor = [graph constantWithScalar:0.0f shape:@[ @1 ] dataType:MPSDataTypeInt32]; + MPSGraphTensor* maskedIndicesWithHead = [graph concatTensors:@[ zeroTensor, maskedIndices ] dimension:0 name:nil]; + MPSGraphTensor* scannedIndicesWithHead = [graph concatTensors:@[ zeroTensor, scannedIndices ] dimension:0 name:nil]; resultTensor = [graph scatterWithUpdatesTensor:sortedInput - indicesTensor:maskedIndicesWithHead - shape:destShape - axis:dim - mode:MPSGraphScatterModeSet - name:nil]; + indicesTensor:maskedIndicesWithHead + shape:destShape + axis:dim + mode:MPSGraphScatterModeSet + name:nil]; // Cast back if necessary if ([uniqueGraph->inputTensor_ dataType] != dataType) { - resultTensor = [graph castTensor:resultTensor - toType:[uniqueGraph->inputTensor_ dataType] - name:@"castResultTensor"]; + resultTensor = [graph castTensor:resultTensor toType:[uniqueGraph->inputTensor_ dataType] name:@"castResultTensor"]; } // Compute optional returned tensors if requested - if(return_inverse) { - MPSGraphTensor *argSortedInput = nil; + if (return_inverse) { + MPSGraphTensor* argSortedInput = nil; if (consecutive) argSortedInput = [graph coordinateAlongAxis:0 - withShape:@[[NSNumber numberWithUnsignedInteger:length]] + withShape:@[ [NSNumber numberWithUnsignedInteger:length] ] name:nil]; else - argSortedInput = [graph argSortWithTensor:inputTensor - axis:0 - name:nil]; + argSortedInput = [graph argSortWithTensor:inputTensor axis:0 name:nil]; inverseIndicesTensor = [graph scatterWithUpdatesTensor:scannedIndicesWithHead - indicesTensor:argSortedInput - shape:@[[NSNumber numberWithUnsignedInteger:length]] - axis:0 - mode:MPSGraphScatterModeAdd - name:nil]; + indicesTensor:argSortedInput + shape:@[ [NSNumber numberWithUnsignedInteger:length] ] + axis:0 + mode:MPSGraphScatterModeAdd + name:nil]; if (needsFlatten) - inverseIndicesTensor = [graph reshapeTensor:inverseIndicesTensor - withShape:shape - name:nil]; + inverseIndicesTensor = [graph reshapeTensor:inverseIndicesTensor withShape:shape name:nil]; } if (return_counts) { - MPSGraphTensor *unitTensor = [graph constantWithScalar:1.0f - shape:@[[NSNumber numberWithUnsignedInteger:length]] + MPSGraphTensor* unitTensor = [graph constantWithScalar:1.0f + shape:@[ [NSNumber numberWithUnsignedInteger:length] ] dataType:MPSDataTypeInt32]; countTensor = [graph scatterWithUpdatesTensor:unitTensor - indicesTensor:scannedIndicesWithHead - shape:@[[NSNumber numberWithUnsignedInteger:length]] - axis:0 - mode:MPSGraphScatterModeAdd - name:nil]; + indicesTensor:scannedIndicesWithHead + shape:@[ [NSNumber numberWithUnsignedInteger:length] ] + axis:0 + mode:MPSGraphScatterModeAdd + name:nil]; } return {resultTensor, inverseIndicesTensor, countTensor, lengthTensor}; } -static UniqueCachedGraph* getUniqueGraph(const Tensor& self, const bool return_inverse, const bool return_counts, const bool consecutive, c10::optional dim) { +static UniqueCachedGraph* getUniqueGraph(const Tensor& self, + const bool return_inverse, + const bool return_counts, + const bool consecutive, + c10::optional dim) { MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { string key = getUniqueKey(self.scalar_type(), self.sizes(), return_inverse, return_counts, consecutive, dim); - UniqueCachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - UniqueCachedGraph *newCachedGraph = nil; - - @autoreleasepool { - // Initialize graph - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new UniqueCachedGraph(mpsGraph); - - // Workaround for MPSShaderLibrary bug - // TODO: Remove once https://github.com/pytorch/pytorch/issues/82305 is resolved - auto inputType = getMPSScalarType(self.scalar_type()); - newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(self.sizes())); - - auto outputTensors = buildUniqueGraph(self, newCachedGraph, return_inverse, return_counts, consecutive, dim); - - newCachedGraph->outputTensor_ = outputTensors[0]; - newCachedGraph->inverseIndicesTensor_ = outputTensors[1]; - newCachedGraph->countsTensor_ = outputTensors[2]; - newCachedGraph->lengthTensor_ = outputTensors[3]; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } + UniqueCachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + UniqueCachedGraph* newCachedGraph = nil; + + @autoreleasepool { + // Initialize graph + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new UniqueCachedGraph(mpsGraph); + + // Workaround for MPSShaderLibrary bug + // TODO: Remove once https://github.com/pytorch/pytorch/issues/82305 is resolved + auto inputType = getMPSScalarType(self.scalar_type()); + newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(self.sizes())); + + auto outputTensors = buildUniqueGraph(self, newCachedGraph, return_inverse, return_counts, consecutive, dim); + + newCachedGraph->outputTensor_ = outputTensors[0]; + newCachedGraph->inverseIndicesTensor_ = outputTensors[1]; + newCachedGraph->countsTensor_ = outputTensors[2]; + newCachedGraph->lengthTensor_ = outputTensors[3]; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } return cachedGraph; } } -void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor& output, - Tensor& inverse_indices, Tensor& counts, Tensor& length, - bool return_inverse, bool return_counts){ +void runUniqueGraph(UniqueCachedGraph* uniqueGraph, + const Tensor& input, + Tensor& output, + Tensor& inverse_indices, + Tensor& counts, + Tensor& length, + bool return_inverse, + bool return_counts) { Placeholder inputPlaceholder = Placeholder(uniqueGraph->inputTensor_, input); NSDictionary* feeds = @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), @@ -249,10 +219,8 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor& NSMutableDictionary* results = [NSMutableDictionary dictionary]; Placeholder outputPlaceholder = Placeholder(uniqueGraph->outputTensor_, output); Placeholder lengthPlaceholder = Placeholder(uniqueGraph->lengthTensor_, length); - [results setObject:outputPlaceholder.getMPSGraphTensorData() - forKey:outputPlaceholder.getMPSGraphTensor()]; - [results setObject:lengthPlaceholder.getMPSGraphTensorData() - forKey:lengthPlaceholder.getMPSGraphTensor()]; + [results setObject:outputPlaceholder.getMPSGraphTensorData() forKey:outputPlaceholder.getMPSGraphTensor()]; + [results setObject:lengthPlaceholder.getMPSGraphTensorData() forKey:lengthPlaceholder.getMPSGraphTensor()]; if (return_inverse) { Placeholder inverseIndicesPlaceholder = Placeholder(uniqueGraph->inverseIndicesTensor_, inverse_indices); [results setObject:inverseIndicesPlaceholder.getMPSGraphTensorData() @@ -260,8 +228,7 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor& } if (return_counts) { Placeholder countsPlaceholder = Placeholder(uniqueGraph->countsTensor_, counts); - [results setObject:countsPlaceholder.getMPSGraphTensorData() - forKey:countsPlaceholder.getMPSGraphTensor()]; + [results setObject:countsPlaceholder.getMPSGraphTensorData() forKey:countsPlaceholder.getMPSGraphTensor()]; } // Run the graph @@ -271,9 +238,11 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor& } // namespace mps -std::tuple -_unique_impl_mps(const Tensor& self, const bool return_inverse, const bool return_counts, const bool consecutive, c10::optional dimOpt) { - +std::tuple _unique_impl_mps(const Tensor& self, + const bool return_inverse, + const bool return_counts, + const bool consecutive, + c10::optional dimOpt) { const Tensor& input = self.contiguous(); // get flat output size @@ -303,7 +272,7 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor& return std::make_tuple(output, inverse_indices, counts); } - mps::UniqueCachedGraph *uniqueGraph = mps::getUniqueGraph(input, return_inverse, return_counts, consecutive, dimOpt); + mps::UniqueCachedGraph* uniqueGraph = mps::getUniqueGraph(input, return_inverse, return_counts, consecutive, dimOpt); mps::runUniqueGraph(uniqueGraph, input, output, inverse_indices, counts, length, return_inverse, return_counts); int64_t lengthScalar = length.item() + 1; // length actually holds max index, add 1 @@ -316,17 +285,14 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor& return std::make_tuple(output, inverse_indices, counts); } - -static -std::tuple castToMPS(std::tuple out) { - return std::make_tuple( - get<0>(out).to("mps"), - get<1>(out).to("mps"), - get<2>(out).to("mps")); +static std::tuple castToMPS(std::tuple out) { + return std::make_tuple(get<0>(out).to("mps"), get<1>(out).to("mps"), get<2>(out).to("mps")); } -std::tuple -unique_consecutive_mps(const Tensor& self, const bool return_inverse, const bool return_counts, c10::optional dim) { +std::tuple unique_consecutive_mps(const Tensor& self, + const bool return_inverse, + const bool return_counts, + c10::optional dim) { if (!is_macos_13_or_newer()) { TORCH_WARN_ONCE("MPS: unique_consecutive op is supported natively starting from macOS 13.0. ", "Falling back on CPU. This may have performace implications."); @@ -336,8 +302,10 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor& return _unique_impl_mps(self, return_inverse, return_counts, true, dim); } -std::tuple -unique_dim_consecutive_mps(const Tensor& self, int64_t dim, const bool return_inverse, const bool return_counts) { +std::tuple unique_dim_consecutive_mps(const Tensor& self, + int64_t dim, + const bool return_inverse, + const bool return_counts) { if (!is_macos_13_or_newer()) { TORCH_WARN_ONCE("MPS: unique_dim_consecutive op is supported natively starting from macOS 13.0. ", "Falling back on CPU. This may have performace implications."); @@ -347,8 +315,10 @@ void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor& return _unique_impl_mps(self, return_inverse, return_counts, true, c10::make_optional((int64_t)dim)); } -std::tuple -_unique2_mps(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) { +std::tuple _unique2_mps(const Tensor& self, + const bool sorted, + const bool return_inverse, + const bool return_counts) { if (!is_macos_13_or_newer()) { TORCH_WARN_ONCE("MPS: _unique2 op is supported natively starting from macOS 13.0. ", "Falling back on CPU. This may have performace implications."); diff --git a/aten/src/ATen/native/mps/operations/UpSample.mm b/aten/src/ATen/native/mps/operations/UpSample.mm index 4e7a06ab616837..4bf01110dc11b9 100644 --- a/aten/src/ATen/native/mps/operations/UpSample.mm +++ b/aten/src/ATen/native/mps/operations/UpSample.mm @@ -1,8 +1,8 @@ // Copyright © 2023 Apple Inc. -#include -#include #include +#include +#include namespace at::native { namespace mps { @@ -20,7 +20,7 @@ void upsample_out_template(const Tensor& input, if (input.numel() == 0) { return; } - const auto input_dim = input.sizes(); + const auto input_dim = input.sizes(); if (input_dim.size() <= 3) { native::upsample_1d_common_check(input.sizes(), output_size); } else { @@ -34,9 +34,8 @@ void upsample_out_template(const Tensor& input, bool centerResults = false; MPSGraphResizeMode resizeMode = MPSGraphResizeNearest; MPSGraphResizeNearestRoundingMode nearestRoundingMode = MPSGraphResizeNearestRoundingModeFloor; - MPSGraphTensorNamedDataLayout dataLayout = input_dim.size() > 3 ? - MPSGraphTensorNamedDataLayoutNCHW : - MPSGraphTensorNamedDataLayoutCHW; + MPSGraphTensorNamedDataLayout dataLayout = + input_dim.size() > 3 ? MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutCHW; if (resize_mode_str == "nearest") { resizeMode = MPSGraphResizeNearest; } else if (resize_mode_str == "bilinear") { @@ -50,7 +49,7 @@ void upsample_out_template(const Tensor& input, } const bool is_macOS_13_0_or_newer = is_macos_13_or_newer(); - const int64_t output_width = output_size.size() > 1 ? output_size[1] : output_size[0]; + const int64_t output_width = output_size.size() > 1 ? output_size[1] : output_size[0]; const int64_t output_height = output_size.size() > 1 ? output_size[0] : 1; const float scale_w = (scale_w_opt.value_or(0.) > 0.) ? static_cast(scale_w_opt.value()) : 0.; const float scale_h = (scale_h_opt.value_or(0.) > 0.) ? static_cast(scale_h_opt.value()) : 1.; @@ -63,37 +62,37 @@ void upsample_out_template(const Tensor& input, input_size = input_size_opt.value(); } struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor *inputTensor = nil, *outputTensor = nil; - MPSGraphTensor *outputSizeTensor = nil; + MPSGraphTensor* outputSizeTensor = nil; }; MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { string key = "upsample_" + std::string(resize_mode_str) + (align_corners ? "_aligned_corners" : "") + - getTensorsStringKey({input}) + ":[" + to_string(scale_h) + "," + to_string(scale_w) + "]:[" + - (is_backward_pass ? getArrayRefString(input_size) : "Undefined") + "]"; + getTensorsStringKey({input}) + ":[" + to_string(scale_h) + "," + to_string(scale_w) + "]:[" + + (is_backward_pass ? getArrayRefString(input_size) : "Undefined") + "]"; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); - newCachedGraph->outputSizeTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@(2)]); + newCachedGraph->outputSizeTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(2) ]); MPSGraphTensor* scaleOffsetTensor = nullptr; MPSGraphTensor* inputSizeTensor = nullptr; if (scale_w > 0.0) { const float outScales[4] = {scale_h, scale_w, offset_y, offset_x}; - scaleOffsetTensor = [mpsGraph constantWithData: [NSData dataWithBytes: outScales length: sizeof(outScales)] - shape: @[@4] - dataType: MPSDataTypeFloat32]; + scaleOffsetTensor = [mpsGraph constantWithData:[NSData dataWithBytes:outScales length:sizeof(outScales)] + shape:@[ @4 ] + dataType:MPSDataTypeFloat32]; } if (is_backward_pass) { std::vector inputSizeVec(4); @@ -101,118 +100,119 @@ void upsample_out_template(const Tensor& input, inputSizeVec[1] = @(input_size[1]); inputSizeVec[2] = @(input_size[2]); inputSizeVec[3] = @(input_dim.size() > 3 ? input_size[3] : 1); - inputSizeTensor = [mpsGraph constantWithScalar: 0 - shape: [NSArray arrayWithObjects:inputSizeVec.data() count:input_dim.size()] - dataType: getMPSDataType(input)]; + inputSizeTensor = [mpsGraph constantWithScalar:0 + shape:[NSArray arrayWithObjects:inputSizeVec.data() + count:input_dim.size()] + dataType:getMPSDataType(input)]; } if (is_macOS_13_0_or_newer) { if (!is_backward_pass) { if (scaleOffsetTensor && !align_corners) { if (resizeMode == MPSGraphResizeNearest) { - newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor: newCachedGraph->inputTensor - sizeTensor: newCachedGraph->outputSizeTensor - scaleOffsetTensor: scaleOffsetTensor - nearestRoundingMode: nearestRoundingMode - layout: dataLayout - name: nil]; + newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor:newCachedGraph->inputTensor + sizeTensor:newCachedGraph->outputSizeTensor + scaleOffsetTensor:scaleOffsetTensor + nearestRoundingMode:nearestRoundingMode + layout:dataLayout + name:nil]; } else { // bilinear forward - newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor: newCachedGraph->inputTensor - sizeTensor: newCachedGraph->outputSizeTensor - scaleOffsetTensor: scaleOffsetTensor - layout: dataLayout - name: nil]; + newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor:newCachedGraph->inputTensor + sizeTensor:newCachedGraph->outputSizeTensor + scaleOffsetTensor:scaleOffsetTensor + layout:dataLayout + name:nil]; } } else { // scaleOffsetTensor == nil || align_corners if (resizeMode == MPSGraphResizeNearest) { - newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor: newCachedGraph->inputTensor - sizeTensor: newCachedGraph->outputSizeTensor - nearestRoundingMode: nearestRoundingMode - centerResult: centerResults - alignCorners: align_corners - layout: dataLayout - name: nil]; + newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor:newCachedGraph->inputTensor + sizeTensor:newCachedGraph->outputSizeTensor + nearestRoundingMode:nearestRoundingMode + centerResult:centerResults + alignCorners:align_corners + layout:dataLayout + name:nil]; } else { // bilinear forward - newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor: newCachedGraph->inputTensor - sizeTensor: newCachedGraph->outputSizeTensor - centerResult: centerResults - alignCorners: align_corners - layout: dataLayout - name: nil]; + newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor:newCachedGraph->inputTensor + sizeTensor:newCachedGraph->outputSizeTensor + centerResult:centerResults + alignCorners:align_corners + layout:dataLayout + name:nil]; } } } else { // is_backward_pass == true if (scaleOffsetTensor && !align_corners) { if (resizeMode == MPSGraphResizeNearest) { - newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor: newCachedGraph->inputTensor - input: inputSizeTensor - scaleOffsetTensor: scaleOffsetTensor - nearestRoundingMode: nearestRoundingMode - layout: dataLayout - name: nil]; + newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor:newCachedGraph->inputTensor + input:inputSizeTensor + scaleOffsetTensor:scaleOffsetTensor + nearestRoundingMode:nearestRoundingMode + layout:dataLayout + name:nil]; } else { // bilinear backward - newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor: newCachedGraph->inputTensor - input: inputSizeTensor - scaleOffsetTensor: scaleOffsetTensor - layout: dataLayout - name: nil]; + newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor:newCachedGraph->inputTensor + input:inputSizeTensor + scaleOffsetTensor:scaleOffsetTensor + layout:dataLayout + name:nil]; } } else { // scaleOffsetTensor == nil || align_corners if (resizeMode == MPSGraphResizeNearest) { - newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor: newCachedGraph->inputTensor - input: inputSizeTensor - nearestRoundingMode: nearestRoundingMode - centerResult: centerResults - alignCorners: align_corners - layout: dataLayout - name: nil]; + newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor:newCachedGraph->inputTensor + input:inputSizeTensor + nearestRoundingMode:nearestRoundingMode + centerResult:centerResults + alignCorners:align_corners + layout:dataLayout + name:nil]; } else { // bilinear backward - newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor: newCachedGraph->inputTensor - input: inputSizeTensor - centerResult: centerResults - alignCorners: align_corners - layout: dataLayout - name: nil]; + newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor:newCachedGraph->inputTensor + input:inputSizeTensor + centerResult:centerResults + alignCorners:align_corners + layout:dataLayout + name:nil]; } } } } else { // if macOS version < 13.0 (for backwards compatibility) if (!is_backward_pass) { - newCachedGraph->outputTensor = [mpsGraph resizeTensor: newCachedGraph->inputTensor - sizeTensor: newCachedGraph->outputSizeTensor - mode: resizeMode - centerResult: centerResults - alignCorners: align_corners - layout: dataLayout - name: nil]; + newCachedGraph->outputTensor = [mpsGraph resizeTensor:newCachedGraph->inputTensor + sizeTensor:newCachedGraph->outputSizeTensor + mode:resizeMode + centerResult:centerResults + alignCorners:align_corners + layout:dataLayout + name:nil]; } else { - newCachedGraph->outputTensor = [mpsGraph resizeWithGradientTensor: newCachedGraph->inputTensor - input: inputSizeTensor - mode: resizeMode - centerResult: centerResults - alignCorners: align_corners - layout: dataLayout - name: nil]; + newCachedGraph->outputTensor = [mpsGraph resizeWithGradientTensor:newCachedGraph->inputTensor + input:inputSizeTensor + mode:resizeMode + centerResult:centerResults + alignCorners:align_corners + layout:dataLayout + name:nil]; } } } return newCachedGraph; }); } - MPSNDArrayDescriptor *sizeDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(2)]]; - MPSNDArray *sizeNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: sizeDesc] autorelease]; - [sizeNDArray writeBytes: (int32_t[]) {(int32_t)output_height, (int32_t)output_width} strideBytes: nil]; - MPSGraphTensorData* sizeTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: sizeNDArray] autorelease]; + MPSNDArrayDescriptor* sizeDesc = [MPSNDArrayDescriptor descriptorWithDataType:MPSDataTypeInt32 shape:@[ @(2) ]]; + MPSNDArray* sizeNDArray = [[[MPSNDArray alloc] initWithDevice:stream->device() descriptor:sizeDesc] autorelease]; + [sizeNDArray writeBytes:(int32_t[]){(int32_t)output_height, (int32_t)output_width} strideBytes:nil]; + MPSGraphTensorData* sizeTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray:sizeNDArray] autorelease]; - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, out.has_storage() ? out : output, nil, false); + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input); + Placeholder outputPlaceholder = + Placeholder(cachedGraph->outputTensor, out.has_storage() ? out : output, nil, false); NSDictionary* feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - cachedGraph->outputSizeTensor : sizeTensorData, - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + cachedGraph->outputSizeTensor : sizeTensorData, }; + NSDictionary* results = + @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); if (out.has_storage()) { @@ -223,8 +223,7 @@ void upsample_out_template(const Tensor& input, } // namespace mps -static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10::optional scale) -{ +static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10::optional scale) { static const bool is_macOS_13_0_or_newer = is_macos_13_or_newer(); if (!is_macOS_13_0_or_newer) { // passing scale factors to MPS's resize APIs is not supported on macOS < 13 @@ -232,11 +231,13 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10: TORCH_WARN_ONCE("MPS: passing scale factor to upsample ops is supported natively starting from macOS 13.0. ", "Falling back on CPU. This may have performance implications."); return false; - // nearest mode on Monterey uses round() to compute source indices which - // is incompatible with PyTorch that uses floor(). So we fallback to CPU on Monterey. - // The nearest mode should work fine on Ventura. + // nearest mode on Monterey uses round() to compute source indices which + // is incompatible with PyTorch that uses floor(). So we fallback to CPU on Monterey. + // The nearest mode should work fine on Ventura. } else if (resize_mode_str == "nearest" || resize_mode_str == "nearest-exact") { - TORCH_WARN_ONCE("MPS: '", resize_mode_str, "' mode upsampling is supported natively starting from macOS 13.0. ", + TORCH_WARN_ONCE("MPS: '", + resize_mode_str, + "' mode upsampling is supported natively starting from macOS 13.0. ", "Falling back on CPU. This may have performance implications."); return false; } @@ -244,12 +245,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10: return true; } -TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) ( - const Tensor& input, - IntArrayRef output_size, - c10::optional scale, - const Tensor& output) -{ +TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) +(const Tensor& input, IntArrayRef output_size, c10::optional scale, const Tensor& output) { if (check_mps_compatibility("nearest", scale)) { mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest"); } else { @@ -258,27 +255,23 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10: } } -TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_mps) ( - const Tensor& grad_output, - IntArrayRef output_size, - IntArrayRef input_size, - c10::optional scale, - const Tensor& grad_input) -{ +TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_mps) +(const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + c10::optional scale, + const Tensor& grad_input) { if (check_mps_compatibility("nearest", scale)) { mps::upsample_out_template(grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest"); } else { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(grad_input) = at::upsample_nearest1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps"); + const_cast(grad_input) = + at::upsample_nearest1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps"); } } -TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps) ( - const Tensor& input, - IntArrayRef output_size, - c10::optional scale, - const Tensor& output) -{ +TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps) +(const Tensor& input, IntArrayRef output_size, c10::optional scale, const Tensor& output) { if (check_mps_compatibility("nearest-exact", scale)) { mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest-exact"); } else { @@ -287,113 +280,123 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10: } } -TORCH_IMPL_FUNC(_upsample_nearest_exact1d_backward_out_mps) ( - const Tensor& grad_output, - IntArrayRef output_size, - IntArrayRef input_size, - c10::optional scale, - const Tensor& grad_input) -{ +TORCH_IMPL_FUNC(_upsample_nearest_exact1d_backward_out_mps) +(const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + c10::optional scale, + const Tensor& grad_input) { if (check_mps_compatibility("nearest-exact", scale)) { - mps::upsample_out_template(grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest-exact"); + mps::upsample_out_template( + grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest-exact"); } else { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(grad_input) = at::_upsample_nearest_exact1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps"); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(grad_input) = + at::_upsample_nearest_exact1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps"); } } -TORCH_IMPL_FUNC(upsample_nearest2d_out_mps) ( - const Tensor& input, - IntArrayRef output_size, - c10::optional scales_h, - c10::optional scales_w, - const Tensor& output) -{ +TORCH_IMPL_FUNC(upsample_nearest2d_out_mps) +(const Tensor& input, + IntArrayRef output_size, + c10::optional scales_h, + c10::optional scales_w, + const Tensor& output) { if (check_mps_compatibility("nearest", scales_w)) { mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest"); } else { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(output) = at::upsample_nearest2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps"); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(output) = + at::upsample_nearest2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps"); } } -TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_mps) ( - const Tensor& grad_output, - IntArrayRef output_size, - IntArrayRef input_size, - c10::optional scales_h, - c10::optional scales_w, - const Tensor& grad_input) -{ +TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_mps) +(const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + c10::optional scales_h, + c10::optional scales_w, + const Tensor& grad_input) { if (check_mps_compatibility("nearest", scales_w)) { mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest"); } else { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(grad_input) = at::upsample_nearest2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w).clone().to("mps"); + const_cast(grad_input) = + at::upsample_nearest2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w) + .clone() + .to("mps"); } } -TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_mps) ( - const Tensor& input, - IntArrayRef output_size, - c10::optional scales_h, - c10::optional scales_w, - const Tensor& output) -{ +TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_mps) +(const Tensor& input, + IntArrayRef output_size, + c10::optional scales_h, + c10::optional scales_w, + const Tensor& output) { if (check_mps_compatibility("nearest-exact", scales_w)) { mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest-exact"); } else { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(output) = at::_upsample_nearest_exact2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps"); + const_cast(output) = + at::_upsample_nearest_exact2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps"); } } -TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_mps) ( - const Tensor& grad_output, - IntArrayRef output_size, - IntArrayRef input_size, - c10::optional scales_h, - c10::optional scales_w, - const Tensor& grad_input) -{ +TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_mps) +(const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + c10::optional scales_h, + c10::optional scales_w, + const Tensor& grad_input) { if (check_mps_compatibility("nearest-exact", scales_w)) { - mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest-exact"); + mps::upsample_out_template( + grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest-exact"); } else { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(grad_input) = at::_upsample_nearest_exact2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w).clone().to("mps"); + const_cast(grad_input) = + at::_upsample_nearest_exact2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w) + .clone() + .to("mps"); } } -TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps) ( - const Tensor& input, - IntArrayRef output_size, - bool align_corners, - c10::optional scales_h, - c10::optional scales_w, - const Tensor& output) -{ +TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps) +(const Tensor& input, + IntArrayRef output_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w, + const Tensor& output) { if (check_mps_compatibility("bilinear", scales_w)) { mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, align_corners, "bilinear"); } else { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(output) = at::upsample_bilinear2d(input.to("cpu"), output_size, align_corners, scales_h, scales_w).clone().to("mps"); + const_cast(output) = + at::upsample_bilinear2d(input.to("cpu"), output_size, align_corners, scales_h, scales_w).clone().to("mps"); } } -TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps) ( - const Tensor& grad_output, - IntArrayRef output_size, - IntArrayRef input_size, - bool align_corners, - c10::optional scales_h, - c10::optional scales_w, - const Tensor& grad_input) -{ +TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps) +(const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w, + const Tensor& grad_input) { if (check_mps_compatibility("bilinear", scales_w)) { - mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, align_corners, "bilinear"); + mps::upsample_out_template( + grad_output, output_size, input_size, scales_h, scales_w, grad_input, align_corners, "bilinear"); } else { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(grad_input) = at::upsample_bilinear2d_backward(grad_output.to("cpu"), output_size, input_size, align_corners, scales_h, scales_w).clone().to("mps"); + const_cast(grad_input) = + at::upsample_bilinear2d_backward( + grad_output.to("cpu"), output_size, input_size, align_corners, scales_h, scales_w) + .clone() + .to("mps"); } } diff --git a/aten/src/ATen/native/mps/operations/View.mm b/aten/src/ATen/native/mps/operations/View.mm index b6df6e3f654fe5..41b05ce9d2da41 100644 --- a/aten/src/ATen/native/mps/operations/View.mm +++ b/aten/src/ATen/native/mps/operations/View.mm @@ -1,18 +1,17 @@ // Copyright © 2022 Apple Inc. -#include -#include +#include #include +#include +#include #include #include -#include namespace at::native { namespace mps { -struct ViewCachedGraph : public MPSCachedGraph -{ - ViewCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} +struct ViewCachedGraph : public MPSCachedGraph { + ViewCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor = nil; MPSGraphTensor* outputTensor = nil; MPSGraphTensor* updatesTensor = nil; @@ -20,18 +19,20 @@ std::vector strideTensors; }; -static std::string getStridedKey(const ScalarType& self_dtype, const ScalarType& updates_dtype, const IntArrayRef& base_shape, - const IntArrayRef& new_shape, const IntArrayRef& stride, - int64_t storage_offset, bool is_scatter) -{ +static std::string getStridedKey(const ScalarType& self_dtype, + const ScalarType& updates_dtype, + const IntArrayRef& base_shape, + const IntArrayRef& new_shape, + const IntArrayRef& stride, + int64_t storage_offset, + bool is_scatter) { std::string dtype_key = getMPSTypeString(self_dtype); if (is_scatter) { dtype_key += ":" + getMPSTypeString(updates_dtype); } - return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" + - getArrayRefString(base_shape) + "]:[" + getArrayRefString(new_shape) + "]:[" + - getArrayRefString(stride) + "]:[" + to_string(storage_offset) + "]"; + return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" + getArrayRefString(base_shape) + "]:[" + + getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + to_string(storage_offset) + "]"; } // initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op @@ -39,30 +40,31 @@ const id sourceBuffer = getMTLBufferStorage(src); const id outputBuffer = getMTLBufferStorage(output); - const IntArrayRef& strides = needsScatter ? output.strides() : src.strides(); - const IntArrayRef& sizes = needsScatter ? output.sizes() : src.sizes(); + const IntArrayRef& strides = needsScatter ? output.strides() : src.strides(); + const IntArrayRef& sizes = needsScatter ? output.sizes() : src.sizes(); const int64_t storage_offset = needsScatter ? output.storage_offset() : src.storage_offset(); - const MPSDataType inputType = [cachedGraph->inputTensor dataType]; + const MPSDataType inputType = [cachedGraph->inputTensor dataType]; - MPSShape *inputShape = [cachedGraph->inputTensor shape]; - MPSShape *outputShape = needsScatter ? inputShape : getMPSShape(src); + MPSShape* inputShape = [cachedGraph->inputTensor shape]; + MPSShape* outputShape = needsScatter ? inputShape : getMPSShape(src); MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; + NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; // in case of scatter, we use output tensor as input buffer and write the results back to the source buffer - feeds[cachedGraph->inputTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer: needsScatter ? outputBuffer : sourceBuffer - shape: inputShape - dataType: inputType] autorelease]; + feeds[cachedGraph->inputTensor] = + [[[MPSGraphTensorData alloc] initWithMTLBuffer:needsScatter ? outputBuffer : sourceBuffer + shape:inputShape + dataType:inputType] autorelease]; if (needsScatter) { auto updatesType = getMPSScalarType(src.scalar_type()); if (updatesType == MPSDataTypeUInt8 || (updatesType == MPSDataTypeBool && !is_macos_13_or_newer())) { updatesType = MPSDataTypeInt8; } - feeds[cachedGraph->updatesTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer: sourceBuffer - shape: getMPSShape(src.numel()) - dataType: updatesType] autorelease]; + feeds[cachedGraph->updatesTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer:sourceBuffer + shape:getMPSShape(src.numel()) + dataType:updatesType] autorelease]; } MPSScalar storageOffsetScalar = getMPSScalar(storage_offset, ScalarType::Int); feeds[cachedGraph->storageOffsetTensor] = getMPSGraphTensorFromScalar(stream, storageOffsetScalar); @@ -75,59 +77,53 @@ // Workaround for MPSShaderLibrary bug in macOS Monterey // This is fixed in macOS Ventura auto outputType = getMPSScalarType(output.scalar_type()); - if (outputType == MPSDataTypeUInt8 || (outputType == MPSDataTypeBool && !is_macos_13_or_newer())) { - outputType = MPSDataTypeInt8; + if (outputType == MPSDataTypeUInt8 || (outputType == MPSDataTypeBool && !is_macos_13_or_newer())) { + outputType = MPSDataTypeInt8; } - MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: outputBuffer - shape: outputShape - dataType: outputType] autorelease]; - NSDictionary* results = @{ - cachedGraph->outputTensor : outputTensorData - }; + MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:outputBuffer + shape:outputShape + dataType:outputType] autorelease]; + NSDictionary* results = @{cachedGraph->outputTensor : outputTensorData}; runMPSGraph(stream, cachedGraph->graph(), feeds, results); } return output; } -MPSGraphTensor *permuteTensor(MPSGraph *graph, MPSGraphTensor *inputTensor, NSArray *permuteOrder) { +MPSGraphTensor* permuteTensor(MPSGraph* graph, MPSGraphTensor* inputTensor, NSArray* permuteOrder) { NSUInteger srcRank = [[inputTensor shape] count]; if (srcRank != [permuteOrder count]) { return nil; } - MPSGraphTensor *outputTensor = inputTensor; + MPSGraphTensor* outputTensor = inputTensor; std::vector dimensionOrder(srcRank); - std::iota (std::begin(dimensionOrder), std::end(dimensionOrder), 0); + std::iota(std::begin(dimensionOrder), std::end(dimensionOrder), 0); - for (const auto i : c10::irange(srcRank)) { + for (const auto i : c10::irange(srcRank)) { NSUInteger axis = [permuteOrder[i] integerValue]; auto axisIter = std::find(dimensionOrder.begin(), dimensionOrder.end(), axis); NSUInteger axis1 = i; NSUInteger axis2 = axisIter - dimensionOrder.begin(); iter_swap(dimensionOrder.begin() + i, axisIter); - outputTensor = [graph transposeTensor:outputTensor - dimension:axis1 - withDimension:axis2 - name:nil]; + outputTensor = [graph transposeTensor:outputTensor dimension:axis1 withDimension:axis2 name:nil]; } return outputTensor; } -NSDictionary *getStrideToDimLengthOffsetDict(MPSGraphTensor *tensor, NSUInteger rank, NSUInteger offset) { +NSDictionary* getStrideToDimLengthOffsetDict(MPSGraphTensor* tensor, NSUInteger rank, NSUInteger offset) { // Assuming input tensor has default strides NSInteger stride = 1; - NSMutableDictionary *strideToDimLengthOffset = [[NSMutableDictionary alloc] init]; + NSMutableDictionary* strideToDimLengthOffset = [[NSMutableDictionary alloc] init]; for (NSInteger srcDim = rank - 1; srcDim >= 0; srcDim--) { NSUInteger size = [[tensor shape][srcDim] integerValue]; - NSDictionary *entry = - @{ - @"dim": [NSNumber numberWithInteger:srcDim], - @"length": [tensor shape][srcDim], - @"offset": [NSNumber numberWithInteger:offset % size] // offset is determined traversing backwards through stride + NSDictionary* entry = @{ + @"dim" : [NSNumber numberWithInteger:srcDim], + @"length" : [tensor shape][srcDim], + @"offset" : [NSNumber numberWithInteger:offset % size] // offset is determined traversing backwards through stride }; - [strideToDimLengthOffset setValue:entry forKey:[NSString stringWithFormat:@"%ld",stride]]; + [strideToDimLengthOffset setValue:entry forKey:[NSString stringWithFormat:@"%ld", stride]]; offset /= size; stride *= size; } @@ -135,14 +131,18 @@ } // Detect only expand dims, allows for duplicate strides -MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) { - +MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph* graph, + MPSGraphTensor* inputTensor, + int dstRank, + const IntArrayRef& dstSizes, + const IntArrayRef& dstStrides, + int offset) { NSUInteger srcRank = [[inputTensor shape] count]; // Not an expand dims if (srcRank >= dstRank) return nil; - NSMutableArray *expandAxes = [[NSMutableArray alloc] init]; + NSMutableArray* expandAxes = [[NSMutableArray alloc] init]; BOOL isValidExpand = YES; NSInteger currSrcDim = (NSInteger)srcRank - 1; @@ -152,7 +152,7 @@ NSUInteger currStride = dstStrides[dstDim]; NSUInteger currSrcDimLength = currSrcDim >= 0 ? [[inputTensor shape][currSrcDim] integerValue] : 1; - NSUInteger targetDimLength = currSrcDimLength; + NSUInteger targetDimLength = currSrcDimLength; if (currDimLength != targetDimLength) { targetDimLength = 1; } @@ -173,11 +173,9 @@ return nil; } - MPSGraphTensor *expandTensor = inputTensor; + MPSGraphTensor* expandTensor = inputTensor; if ([expandAxes count]) { - expandTensor = [graph expandDimsOfTensor:expandTensor - axes:expandAxes - name:nil]; + expandTensor = [graph expandDimsOfTensor:expandTensor axes:expandAxes name:nil]; } [expandAxes release]; @@ -185,13 +183,18 @@ } // Detect contiguous reshapes, no slicing -MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) { +MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph* graph, + MPSGraphTensor* inputTensor, + int dstRank, + const IntArrayRef& dstSizes, + const IntArrayRef& dstStrides, + int offset) { NSUInteger srcRank = [[inputTensor shape] count]; // Not a reshape if (srcRank <= dstRank) return nil; - NSMutableArray *dstShape = [[NSMutableArray alloc] init]; + NSMutableArray* dstShape = [[NSMutableArray alloc] init]; BOOL isValidReshape = YES; NSInteger srcDim = srcRank - 1; @@ -199,7 +202,7 @@ for (NSInteger dstDim = dstRank - 1; dstDim >= 0 && isValidReshape; dstDim--) { NSUInteger currDimLength = dstSizes[dstDim]; NSUInteger currStride = dstStrides[dstDim]; - [dstShape insertObject:[NSNumber numberWithInteger:currDimLength] atIndex: 0]; + [dstShape insertObject:[NSNumber numberWithInteger:currDimLength] atIndex:0]; NSUInteger targetDimLength = currDimLength; NSUInteger currReshapeSize = 1; @@ -216,26 +219,28 @@ } isValidReshape &= (srcDim < 0); - MPSGraphTensor *outputTensor = nil; + MPSGraphTensor* outputTensor = nil; if (isValidReshape) - outputTensor = [graph reshapeTensor: inputTensor - withShape: dstShape - name: nil]; + outputTensor = [graph reshapeTensor:inputTensor withShape:dstShape name:nil]; [dstShape release]; return outputTensor; } -MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) { - +MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph, + MPSGraphTensor* inputTensor, + int dstRank, + const IntArrayRef& dstSizes, + const IntArrayRef& dstStrides, + int offset) { // Duplicate strides cannot be done { BOOL allUnique = YES; - NSMutableSet *uniqueStrides = [[NSMutableSet alloc] init]; + NSMutableSet* uniqueStrides = [[NSMutableSet alloc] init]; for (NSInteger dstDim = 0; (dstDim < dstRank) && allUnique; dstDim++) { int stride = dstStrides[dstDim]; - NSNumber *strideObj = [NSNumber numberWithInt:stride]; + NSNumber* strideObj = [NSNumber numberWithInt:stride]; allUnique &= (stride == 0 || ![uniqueStrides containsObject:strideObj]); - [uniqueStrides addObject: strideObj]; + [uniqueStrides addObject:strideObj]; } [uniqueStrides release]; if (!allUnique) @@ -243,31 +248,31 @@ // Skip for zero in dst shape for (NSInteger dstDim = 0; dstDim < dstRank; dstDim++) - if (dstSizes[dstDim] == 0) { return nil; } + if (dstSizes[dstDim] == 0) { + return nil; + } } // 1. Flatten the inputTensor if necessary - MPSGraphTensor *flatInputTensor = inputTensor; + MPSGraphTensor* flatInputTensor = inputTensor; { // Flatten inputs to remove duplicate strides. - NSMutableArray *squeezeAxes = [[NSMutableArray alloc] init]; - for(NSUInteger srcDim = 1; srcDim < [[flatInputTensor shape] count]; srcDim++) { - if ([[flatInputTensor shape][srcDim] intValue] == 1) - [squeezeAxes addObject:[NSNumber numberWithInteger:srcDim]]; + NSMutableArray* squeezeAxes = [[NSMutableArray alloc] init]; + for (NSUInteger srcDim = 1; srcDim < [[flatInputTensor shape] count]; srcDim++) { + if ([[flatInputTensor shape][srcDim] intValue] == 1) + [squeezeAxes addObject:[NSNumber numberWithInteger:srcDim]]; } // We have to leave at least 1 dimension, if all input dims are 1 if ([squeezeAxes count]) - flatInputTensor = [graph squeezeTensor:flatInputTensor - axes:squeezeAxes - name:nil]; + flatInputTensor = [graph squeezeTensor:flatInputTensor axes:squeezeAxes name:nil]; [squeezeAxes release]; } int srcRank = (int)[[flatInputTensor shape] count]; - NSDictionary *srcStrideToDimLengthOffset = getStrideToDimLengthOffsetDict(flatInputTensor, srcRank, offset); + NSDictionary* srcStrideToDimLengthOffset = getStrideToDimLengthOffsetDict(flatInputTensor, srcRank, offset); // Populate the dimension order, slice info, and broadcast info - NSMutableArray *dstDimOrder = [[NSMutableArray alloc] init]; + NSMutableArray* dstDimOrder = [[NSMutableArray alloc] init]; std::vector dstDimToSliceLength(dstRank); std::vector dstDimToSliceOffset(dstRank); bool needsBroadcast = false; @@ -280,31 +285,33 @@ dstDimToSliceOffset[dstDim] = 0; } else { // Find what dimension and native length was for the specified stride - NSDictionary *srcDimLengthOffset = srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%lld",dstStrides[dstDim]]]; + NSDictionary* srcDimLengthOffset = + srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%lld", dstStrides[dstDim]]]; dstDimToSliceLength[dstDim] = dstSizes[dstDim]; dstDimToSliceOffset[dstDim] = [srcDimLengthOffset[@"offset"] intValue]; // Stride does not exist in source tensor, or the specified size is too long. Not possible - // TODO: Longer length with same stride + removal of dim(s) above this is a flatten/reshape. Consider adding support + // TODO: Longer length with same stride + removal of dim(s) above this is a flatten/reshape. Consider adding + // support if (!srcDimLengthOffset || // the offset + length of destination should not be larger than source's length when slicing dstDimToSliceOffset[dstDim] + dstDimToSliceLength[dstDim] > [srcDimLengthOffset[@"length"] intValue]) { return nil; } // Get the src dimension corresponding to the requested stride - NSNumber *srcDim = srcDimLengthOffset[@"dim"]; + NSNumber* srcDim = srcDimLengthOffset[@"dim"]; [dstDimOrder insertObject:srcDim atIndex:0]; } } } // 2. Slice out any unused dimensions - NSMutableArray *missingSrcDims = [[NSMutableArray alloc] init]; - MPSGraphTensor *slicedUnusedTensor = flatInputTensor; + NSMutableArray* missingSrcDims = [[NSMutableArray alloc] init]; + MPSGraphTensor* slicedUnusedTensor = flatInputTensor; { // Find any src strides/dims that are not present in the dst - NSMutableArray *missingSrcStrides = [[NSMutableArray alloc] init]; + NSMutableArray* missingSrcStrides = [[NSMutableArray alloc] init]; { NSUInteger stride = 1; for (NSInteger srcDim = [[flatInputTensor shape] count] - 1; srcDim >= 0; srcDim--) { @@ -317,8 +324,8 @@ } for (NSUInteger i = 0; i < [missingSrcStrides count]; i++) { NSUInteger stride = [missingSrcStrides[i] integerValue]; - NSDictionary *srcDimLengthOffset = srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%ld",stride]]; - NSNumber *missingSrcDim = srcDimLengthOffset[@"dim"]; + NSDictionary* srcDimLengthOffset = srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%ld", stride]]; + NSNumber* missingSrcDim = srcDimLengthOffset[@"dim"]; [missingSrcDims addObject:missingSrcDim]; [dstDimOrder insertObject:missingSrcDim atIndex:0]; @@ -332,35 +339,33 @@ } // 3. Transpose if necessary - MPSGraphTensor *transposedTensor = slicedUnusedTensor; + MPSGraphTensor* transposedTensor = slicedUnusedTensor; { // TODO: Use Transpose API BOOL needsTranspose = NO; - for(NSUInteger dstDim = 0; dstDim < [dstDimOrder count] && !needsTranspose; dstDim++ ) + for (NSUInteger dstDim = 0; dstDim < [dstDimOrder count] && !needsTranspose; dstDim++) needsTranspose |= ([dstDimOrder[dstDim] intValue] != dstDim); if (needsTranspose) transposedTensor = permuteTensor(graph, transposedTensor, dstDimOrder); } // 4. Squeeze any unused dimensions following transpose - MPSGraphTensor *squeezedTensor = transposedTensor; + MPSGraphTensor* squeezedTensor = transposedTensor; { // Transpose the missing dims back - NSMutableArray *transposedMissingSrcDims = [[NSMutableArray alloc] init]; + NSMutableArray* transposedMissingSrcDims = [[NSMutableArray alloc] init]; for (NSUInteger dstDim = 0; dstDim < [dstDimOrder count]; dstDim++) { - NSNumber *srcDim = dstDimOrder[dstDim]; + NSNumber* srcDim = dstDimOrder[dstDim]; if ([missingSrcDims containsObject:srcDim]) [transposedMissingSrcDims addObject:[NSNumber numberWithInt:dstDim]]; } if ([transposedMissingSrcDims count]) - squeezedTensor = [graph squeezeTensor:squeezedTensor - axes:transposedMissingSrcDims - name:nil]; + squeezedTensor = [graph squeezeTensor:squeezedTensor axes:transposedMissingSrcDims name:nil]; [transposedMissingSrcDims release]; } // 5. Slice - MPSGraphTensor *slicedTensor = squeezedTensor; + MPSGraphTensor* slicedTensor = squeezedTensor; { NSUInteger currDstDim = 0; for (NSUInteger dstDim = 0; dstDim < dstRank; dstDim++) { @@ -369,34 +374,26 @@ int start = dstDimToSliceOffset[dstDim]; int length = dstDimToSliceLength[dstDim]; if (length != [[slicedTensor shape][currDstDim] intValue]) - slicedTensor = [graph sliceTensor:slicedTensor - dimension:currDstDim - start:start - length:length - name:nil]; + slicedTensor = [graph sliceTensor:slicedTensor dimension:currDstDim start:start length:length name:nil]; currDstDim++; } } } // 6. Expand then broadcast the source tensor - MPSGraphTensor *broadcastTensor = slicedTensor; + MPSGraphTensor* broadcastTensor = slicedTensor; if (needsBroadcast) { - NSMutableArray *broadcastShape = [[NSMutableArray alloc] init]; - NSMutableArray *expandAxes = [[NSMutableArray alloc] init]; - for(NSInteger dstDim = 0; dstDim < dstRank; dstDim++) { + NSMutableArray* broadcastShape = [[NSMutableArray alloc] init]; + NSMutableArray* expandAxes = [[NSMutableArray alloc] init]; + for (NSInteger dstDim = 0; dstDim < dstRank; dstDim++) { [broadcastShape addObject:[NSNumber numberWithInt:dstSizes[dstDim]]]; if (dstStrides[dstDim] == 0) [expandAxes addObject:[NSNumber numberWithInt:dstDim]]; } if ([expandAxes count]) { - MPSGraphTensor *expandTensor = [graph expandDimsOfTensor:broadcastTensor - axes:expandAxes - name:nil]; - broadcastTensor = [graph broadcastTensor:expandTensor - toShape:broadcastShape - name:nil]; + MPSGraphTensor* expandTensor = [graph expandDimsOfTensor:broadcastTensor axes:expandAxes name:nil]; + broadcastTensor = [graph broadcastTensor:expandTensor toShape:broadcastShape name:nil]; } [broadcastShape release]; [expandAxes release]; @@ -409,11 +406,16 @@ return broadcastTensor; } -MPSGraphTensor* asStridedLayer_pattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) { +MPSGraphTensor* asStridedLayer_pattern(MPSGraph* graph, + MPSGraphTensor* inputTensor, + int dstRank, + const IntArrayRef& dstSizes, + const IntArrayRef& dstStrides, + int offset) { if (!dstRank) return nil; - MPSGraphTensor *outputTensor = nil; + MPSGraphTensor* outputTensor = nil; outputTensor = asStridedLayer_expandDimsPattern(graph, inputTensor, dstRank, dstSizes, dstStrides, offset); if (!outputTensor) outputTensor = asStridedLayer_reshapePattern(graph, inputTensor, dstRank, dstSizes, dstStrides, offset); @@ -423,8 +425,7 @@ return outputTensor; } -static -std::vector getViewShape(const Tensor& src, MPSShape *mpsShape, const bool squeeze) { +static std::vector getViewShape(const Tensor& src, MPSShape* mpsShape, const bool squeeze) { bool hasMPSShape = (mpsShape != nil); std::vector src_view_shape; if (hasMPSShape) { @@ -459,7 +460,6 @@ return src_view_shape; } - std::vector getSqueezedBaseShape(const Tensor& src, IntArrayRef shape) { std::vector src_base_shape; for (const auto i : c10::irange(shape.size())) { @@ -471,8 +471,7 @@ return src_base_shape; } - -bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) { +bool canSliceViewTensor(const Tensor& src, MPSShape* mpsShape) { if (!src.is_contiguous()) { return false; } @@ -486,23 +485,23 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) { return false; } - for (const auto i: c10::irange(src_ndim_base)) { - if (src_view_shape[i] > src_base_shape[i]) { - return false; - } - } + for (const auto i : c10::irange(src_ndim_base)) { + if (src_view_shape[i] > src_base_shape[i]) { + return false; + } + } return true; } -MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType) { +MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape* mpsShape, const MPSDataType mpsDataType) { IntArrayRef src_base_shape = getIMPSAllocator()->getBufferShape(src.storage().data()); size_t src_ndim_base = src_base_shape.size(); std::vector src_view_shape = getViewShape(src, mpsShape, false); size_t src_ndim_view = src_view_shape.size(); - MPSNDArray *srcTensorNDArrayView = nil; - MPSNDArrayDescriptor *srcTensorNDArrayDesc = nil; - MPSNDArray *srcTensorNDArray = nil; + MPSNDArray* srcTensorNDArrayView = nil; + MPSNDArrayDescriptor* srcTensorNDArrayDesc = nil; + MPSNDArray* srcTensorNDArray = nil; id commandBuffer = getCurrentMPSStream()->commandBuffer(); int64_t base_idx = 0; @@ -537,19 +536,21 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) { } int64_t sliceOffset = src.storage_offset() / view_numel; - [srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - firstDimToSlice - withSubrange:{static_cast(sliceOffset), static_cast(src.sizes()[firstDimToSlice])}]; + [srcTensorNDArrayDesc + sliceDimension:src_ndim_base - 1 - firstDimToSlice + withSubrange:{static_cast(sliceOffset), static_cast(src.sizes()[firstDimToSlice])}]; // Slice any remaining dimensions - for (const auto crtSliceOffset: c10::irange(firstDimToSlice + 1, src_base_shape.size())) { + for (const auto crtSliceOffset : c10::irange(firstDimToSlice + 1, src_base_shape.size())) { if (src_view_shape[crtSliceOffset] != src_base_shape[crtSliceOffset]) { if (crtSliceOffset == src_base_shape.size() - 1) { sliceOffset = src.storage_offset() % src_base_shape[src_base_shape.size() - 1]; } else { sliceOffset = (src.storage_offset() % view_numel) / (view_numel / src_base_shape[crtSliceOffset]); } - [srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - crtSliceOffset - withSubrange:{static_cast(sliceOffset), static_cast(src.sizes()[crtSliceOffset])}]; + [srcTensorNDArrayDesc + sliceDimension:src_ndim_base - 1 - crtSliceOffset + withSubrange:{static_cast(sliceOffset), static_cast(src.sizes()[crtSliceOffset])}]; } } srcTensorNDArrayView = [srcTensorNDArray arrayViewWithCommandBuffer:commandBuffer @@ -559,13 +560,15 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) { return [[[MPSGraphTensorData alloc] initWithMPSNDArray:srcTensorNDArrayView] autorelease]; } -static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const IntArrayRef& size, - const IntArrayRef& stride, int64_t offset, - const IntArrayRef& base_shape, bool needsScatter, - MPSGraphTensor* updatesTensor) -{ +static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, + const IntArrayRef& size, + const IntArrayRef& stride, + int64_t offset, + const IntArrayRef& base_shape, + bool needsScatter, + MPSGraphTensor* updatesTensor) { MPSGraph* mpsGraph = cachedGraph->graph(); - MPSGraphTensor *outputTensor = nil; + MPSGraphTensor* outputTensor = nil; const size_t shape_size = size.size(); @autoreleasepool { @@ -575,87 +578,74 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) { TORCH_CHECK(size[i] <= int_max); sizeArray[i] = static_cast(size[i]); } - NSData* shapeData = [NSData dataWithBytes: sizeArray.data() - length: shape_size * sizeof(int32_t)]; - MPSGraphTensor* shapeTensor = [mpsGraph constantWithData: shapeData - shape: @[[NSNumber numberWithUnsignedInteger: shape_size]] - dataType: MPSDataTypeInt32]; + NSData* shapeData = [NSData dataWithBytes:sizeArray.data() length:shape_size * sizeof(int32_t)]; + MPSGraphTensor* shapeTensor = [mpsGraph constantWithData:shapeData + shape:@[ [NSNumber numberWithUnsignedInteger:shape_size] ] + dataType:MPSDataTypeInt32]; MPSGraphTensor* indicesTensor = nil; // create stride Tensors for each rank of the input tensor for (int i = 0; i < shape_size; i++) { - MPSGraphTensor* rangeTensor = [mpsGraph coordinateAlongAxis: (-i - 1) - withShapeTensor: shapeTensor - name: nil]; + MPSGraphTensor* rangeTensor = [mpsGraph coordinateAlongAxis:(-i - 1) withShapeTensor:shapeTensor name:nil]; MPSGraphTensor* strideTensor = cachedGraph->strideTensors[shape_size - i - 1]; - MPSGraphTensor* indexTensor = [mpsGraph multiplicationWithPrimaryTensor: rangeTensor - secondaryTensor: strideTensor - name: nil]; + MPSGraphTensor* indexTensor = [mpsGraph multiplicationWithPrimaryTensor:rangeTensor + secondaryTensor:strideTensor + name:nil]; if (!indicesTensor) { indicesTensor = indexTensor; } else { - indicesTensor = [mpsGraph additionWithPrimaryTensor: indexTensor - secondaryTensor: indicesTensor - name: nil]; + indicesTensor = [mpsGraph additionWithPrimaryTensor:indexTensor secondaryTensor:indicesTensor name:nil]; } } - indicesTensor = [mpsGraph additionWithPrimaryTensor: indicesTensor - secondaryTensor: cachedGraph->storageOffsetTensor - name: nil]; - MPSGraphTensor *inputTensor = cachedGraph->inputTensor; + indicesTensor = [mpsGraph additionWithPrimaryTensor:indicesTensor + secondaryTensor:cachedGraph->storageOffsetTensor + name:nil]; + MPSGraphTensor* inputTensor = cachedGraph->inputTensor; if (!needsScatter) { - MPSGraphTensor *outputTensor = asStridedLayer_pattern(mpsGraph, inputTensor, shape_size, size, stride, offset); + MPSGraphTensor* outputTensor = asStridedLayer_pattern(mpsGraph, inputTensor, shape_size, size, stride, offset); if (outputTensor) { return outputTensor; } } - MPSGraphTensor *reshapedInputTensor = [mpsGraph reshapeTensor: inputTensor - withShape: @[@-1] - name: nil]; - MPSGraphTensor *reshapedIndicesTensor = [mpsGraph reshapeTensor: indicesTensor - withShape: @[@-1] - name: nil]; + MPSGraphTensor* reshapedInputTensor = [mpsGraph reshapeTensor:inputTensor withShape:@[ @-1 ] name:nil]; + MPSGraphTensor* reshapedIndicesTensor = [mpsGraph reshapeTensor:indicesTensor withShape:@[ @-1 ] name:nil]; if (needsScatter) { #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wobjc-method-access" - MPSGraphTensor* scatteredTensor = [mpsGraph scatterAlongAxis: (NSInteger) 0 - withDataTensor: reshapedInputTensor - updatesTensor: updatesTensor - indicesTensor: reshapedIndicesTensor - mode: MPSGraphScatterModeSet - name: nil]; + MPSGraphTensor* scatteredTensor = [mpsGraph scatterAlongAxis:(NSInteger)0 + withDataTensor:reshapedInputTensor + updatesTensor:updatesTensor + indicesTensor:reshapedIndicesTensor + mode:MPSGraphScatterModeSet + name:nil]; #pragma clang diagnostic pop - outputTensor = [mpsGraph reshapeTensor: scatteredTensor - withShape: getMPSShape(base_shape) - name: nil]; + outputTensor = [mpsGraph reshapeTensor:scatteredTensor withShape:getMPSShape(base_shape) name:nil]; } else { // Call gather to coalesce the needed values. Result will be of same shape as flattened indices tensor - MPSGraphTensor *gatheredTensor = [mpsGraph gatherWithUpdatesTensor: reshapedInputTensor - indicesTensor: reshapedIndicesTensor - axis: 0 - batchDimensions: 0 - name: nil]; + MPSGraphTensor* gatheredTensor = [mpsGraph gatherWithUpdatesTensor:reshapedInputTensor + indicesTensor:reshapedIndicesTensor + axis:0 + batchDimensions:0 + name:nil]; // Reshape the data to desired size - outputTensor = [mpsGraph reshapeTensor: gatheredTensor - withShapeTensor: shapeTensor - name: nil]; + outputTensor = [mpsGraph reshapeTensor:gatheredTensor withShapeTensor:shapeTensor name:nil]; } } return outputTensor; } -static IntArrayRef updateTensorBaseShape(const Tensor& self) -{ +static IntArrayRef updateTensorBaseShape(const Tensor& self) { IntArrayRef base_shape = getIMPSAllocator()->getBufferShape(self.storage().data()); // if there's no base_shape stored in MPSAllocator, then infer it from tensor's size and store it if (base_shape.size() == 0) { // IntArrayRef wouldn't own the data, so we use a static storage static const int64_t shape_1d = 1; // self.sizes().size() could be zero - base_shape = self.sizes().size() ? self.sizes() : - ((self.is_view() && self._base().sizes().size()) ? self._base().sizes() : IntArrayRef(&shape_1d, 1)); + base_shape = self.sizes().size() + ? self.sizes() + : ((self.is_view() && self._base().sizes().size()) ? self._base().sizes() : IntArrayRef(&shape_1d, 1)); // base_shape will be retained in MPSAllocator until buffer gets recycled if (self.storage().data()) @@ -681,49 +671,53 @@ static IntArrayRef updateTensorBaseShape(const Tensor& self) // | / \ | // | / \ | // NonView T NonView T -static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &updates, IntArrayRef size, IntArrayRef stride, int64_t storage_offset, bool needsScatter) -{ +static ViewCachedGraph* createViewGraph(const Tensor& self, + const Tensor& updates, + IntArrayRef size, + IntArrayRef stride, + int64_t storage_offset, + bool needsScatter) { IntArrayRef base_shape = updateTensorBaseShape(self); @autoreleasepool { - string key = getStridedKey(self.scalar_type(), updates.scalar_type(), base_shape, size, stride, storage_offset, needsScatter); + string key = getStridedKey( + self.scalar_type(), updates.scalar_type(), base_shape, size, stride, storage_offset, needsScatter); MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - ViewCachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + ViewCachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if (!cachedGraph) { - cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - ViewCachedGraph *newCachedGraph = nil; + cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + ViewCachedGraph* newCachedGraph = nil; @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - MPSGraphTensor* updatesTensor = nil; - newCachedGraph = new ViewCachedGraph(mpsGraph); - // Workaround for MPSShaderLibrary bug in macOS Monterey - // This is fixed in macOS Ventura - auto inputType = getMPSScalarType(self.scalar_type()); - if (inputType == MPSDataTypeUInt8 || (inputType == MPSDataTypeBool && !is_macos_13_or_newer())) { - inputType = MPSDataTypeInt8; - } - - // Self is the input tensor we are creating view of - newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(base_shape)); - newCachedGraph->storageOffsetTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@1]); - for (int i = 0; i < size.size(); i++) { - newCachedGraph->strideTensors.push_back(mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@1])); + MPSGraph* mpsGraph = make_mps_graph(); + MPSGraphTensor* updatesTensor = nil; + newCachedGraph = new ViewCachedGraph(mpsGraph); + // Workaround for MPSShaderLibrary bug in macOS Monterey + // This is fixed in macOS Ventura + auto inputType = getMPSScalarType(self.scalar_type()); + if (inputType == MPSDataTypeUInt8 || (inputType == MPSDataTypeBool && !is_macos_13_or_newer())) { + inputType = MPSDataTypeInt8; + } + + // Self is the input tensor we are creating view of + newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(base_shape)); + newCachedGraph->storageOffsetTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @1 ]); + for (int i = 0; i < size.size(); i++) { + newCachedGraph->strideTensors.push_back(mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @1 ])); + } + if (needsScatter) { + auto updatesType = getMPSScalarType(updates.scalar_type()); + if (updatesType == MPSDataTypeUInt8 || (updatesType == MPSDataTypeBool && !is_macos_13_or_newer())) { + updatesType = MPSDataTypeInt8; } - if (needsScatter) { - auto updatesType = getMPSScalarType(updates.scalar_type()); - if (updatesType == MPSDataTypeUInt8 || (updatesType == MPSDataTypeBool && !is_macos_13_or_newer())) { - updatesType = MPSDataTypeInt8; - } - newCachedGraph->updatesTensor = mpsGraphRankedPlaceHolder(mpsGraph, updatesType, getMPSShape(self.numel())); - updatesTensor = newCachedGraph->updatesTensor; - if (inputType != updatesType) { - updatesTensor = [mpsGraph castTensor:updatesTensor - toType:inputType - name:@"castUpdatesTensor"]; - } + newCachedGraph->updatesTensor = mpsGraphRankedPlaceHolder(mpsGraph, updatesType, getMPSShape(self.numel())); + updatesTensor = newCachedGraph->updatesTensor; + if (inputType != updatesType) { + updatesTensor = [mpsGraph castTensor:updatesTensor toType:inputType name:@"castUpdatesTensor"]; } - newCachedGraph->outputTensor = chainViewOperation(newCachedGraph, size, stride, storage_offset, base_shape, needsScatter, updatesTensor); + } + newCachedGraph->outputTensor = + chainViewOperation(newCachedGraph, size, stride, storage_offset, base_shape, needsScatter, updatesTensor); } return newCachedGraph; })); @@ -732,11 +726,7 @@ static IntArrayRef updateTensorBaseShape(const Tensor& self) } } -static -std::string getGatherScatterFunctionName( - ScalarType scalarType, - int64_t dim, - bool needsScatter) { +static std::string getGatherScatterFunctionName(ScalarType scalarType, int64_t dim, bool needsScatter) { std::string kernelName = needsScatter ? "scatter" : "gather"; return kernelName + "_kernel_" + std::to_string(dim == 0 ? 1 : dim); } @@ -744,14 +734,14 @@ static IntArrayRef updateTensorBaseShape(const Tensor& self) const std::string& getGatherScatterScalarType(const Tensor& t) { auto scalar_type = t.scalar_type(); static std::unordered_map scalarToMetalType = { - {c10::ScalarType::Float, "float"}, - {c10::ScalarType::Half, "half"}, - {c10::ScalarType::Long, "long"}, - {c10::ScalarType::Int, "int"}, - {c10::ScalarType::Short, "short"}, - {c10::ScalarType::Char, "char"}, - {c10::ScalarType::Byte, "uchar"}, - {c10::ScalarType::Bool, "bool"}, + {c10::ScalarType::Float, "float"}, + {c10::ScalarType::Half, "half"}, + {c10::ScalarType::Long, "long"}, + {c10::ScalarType::Int, "int"}, + {c10::ScalarType::Short, "short"}, + {c10::ScalarType::Char, "char"}, + {c10::ScalarType::Byte, "uchar"}, + {c10::ScalarType::Bool, "bool"}, }; auto it = scalarToMetalType.find(scalar_type); @@ -759,24 +749,30 @@ static IntArrayRef updateTensorBaseShape(const Tensor& self) return it->second; } -static -id compileGatherScatterOpsLibrary(id device, - const std::string& dtypeSrc, - const std::string& dtypeDst, - bool needsScatter) { +static id compileGatherScatterOpsLibrary(id device, + const std::string& dtypeSrc, + const std::string& dtypeDst, + bool needsScatter) { auto key = std::to_string(needsScatter) + dtypeSrc + dtypeDst; static std::unordered_map> _libCache; auto it = _libCache.find(key); if (it != _libCache.end()) { return it->second; } - NSError *error = nil; - MTLCompileOptions *options = [[MTLCompileOptions new] autorelease]; - [options setLanguageVersion: MTLLanguageVersion2_3]; - auto gatherScatterLib = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(needsScatter ? SCATTER_OPS_TEMPLATE : GATHER_OPS_TEMPLATE, dtypeSrc, dtypeDst).c_str()] - options:options - error:&error]; - TORCH_CHECK(gatherScatterLib != nil && error == nil, "Failed to compile gather-scatter library, error: ", [[error description] UTF8String]); + NSError* error = nil; + MTLCompileOptions* options = [[MTLCompileOptions new] autorelease]; + [options setLanguageVersion:MTLLanguageVersion2_3]; + auto gatherScatterLib = + [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(needsScatter ? SCATTER_OPS_TEMPLATE + : GATHER_OPS_TEMPLATE, + dtypeSrc, + dtypeDst) + .c_str()] + options:options + error:&error]; + TORCH_CHECK(gatherScatterLib != nil && error == nil, + "Failed to compile gather-scatter library, error: ", + [[error description] UTF8String]); _libCache[key] = gatherScatterLib; return gatherScatterLib; } @@ -790,15 +786,16 @@ static IntArrayRef updateTensorBaseShape(const Tensor& self) static std::unordered_map> _mtlPipelineCache; auto it = _mtlPipelineCache.find(key); if (it != _mtlPipelineCache.end()) { - return it->second; + return it->second; } - NSError *error = nil; + NSError* error = nil; id library = compileGatherScatterOpsLibrary(device, dtypeSrc, dtypeDst, needsScatter); id func = [library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]]; TORCH_CHECK(func, "Failed to load the Metal Shader function: ", kernel); id pso = [device newComputePipelineStateWithFunction:func error:&error]; - TORCH_CHECK(pso != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]); + TORCH_CHECK( + pso != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]); _mtlPipelineCache[key] = pso; return pso; } @@ -814,8 +811,8 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) { } if (src.dim() > 5) { - ViewCachedGraph* cachedGraph = createViewGraph(src, dst, src.sizes(), src.strides(), - src.storage_offset(), /*needsScatter*/ false); + ViewCachedGraph* cachedGraph = + createViewGraph(src, dst, src.sizes(), src.strides(), src.storage_offset(), /*needsScatter*/ false); return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false); } @@ -824,7 +821,7 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) { uint32_t numThreads = output.numel(); MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^(){ + dispatch_sync(mpsStream->queue(), ^() { id computeEncoder = [mpsStream->commandBuffer() computeCommandEncoder]; std::string functionName = getGatherScatterFunctionName(output.scalar_type(), output.dim(), /*needsScatter=*/false); id gatherPSO = getPipelineState(MPSDevice::getInstance()->device(), @@ -846,7 +843,7 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) { } } - [computeEncoder setComputePipelineState: gatherPSO]; + [computeEncoder setComputePipelineState:gatherPSO]; [computeEncoder setBuffer:getMTLBufferStorage(src) offset:src.storage_offset() * src.element_size() atIndex:0]; [computeEncoder setBuffer:outputBuffer offset:outputStorageOffset atIndex:1]; [computeEncoder setBytes:&src_sizes[0] length:sizeof(uint32_t) * kernel_size atIndex:2]; @@ -856,7 +853,7 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) { MTLSize gridSize = MTLSizeMake(numThreads, 1, 1); NSUInteger threadsPerThreadgroup_ = gatherPSO.maxTotalThreadsPerThreadgroup; if (threadsPerThreadgroup_ > numThreads) { - threadsPerThreadgroup_ = numThreads; + threadsPerThreadgroup_ = numThreads; } MTLSize threadsPerThreadgroup = MTLSizeMake(threadsPerThreadgroup_, 1, 1); @@ -868,11 +865,14 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) { return (dst.has_storage()) ? dst : output; } -Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output){ +Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output) { if (output.dim() > 5) { - ViewCachedGraph* cachedGraph = createViewGraph(output.is_complex() ? at::view_as_real(output) : output, - src, output.sizes(), output.strides(), - output.storage_offset(), /*needsScatter*/ true); + ViewCachedGraph* cachedGraph = createViewGraph(output.is_complex() ? at::view_as_real(output) : output, + src, + output.sizes(), + output.strides(), + output.storage_offset(), + /*needsScatter*/ true); return runViewGraph(cachedGraph, src, output, /*needsScatter*/ true); } if (src.numel() == 0 || output.numel() == 0) { @@ -884,11 +884,12 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) { uint32_t numThreads = src.numel(); int64_t outputStorageOffset = output.storage_offset() * output.element_size(); MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^(){ + dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { id commandBuffer = mpsStream->commandBuffer(); id computeEncoder = [commandBuffer computeCommandEncoder]; - std::string functionName = getGatherScatterFunctionName(output.scalar_type(), output.dim(), /*needsScatter=*/true); + std::string functionName = + getGatherScatterFunctionName(output.scalar_type(), output.dim(), /*needsScatter=*/true); id scatterPSO = getPipelineState(MPSDevice::getInstance()->device(), functionName, getGatherScatterScalarType(src), @@ -908,7 +909,7 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) { } } - [computeEncoder setComputePipelineState: scatterPSO]; + [computeEncoder setComputePipelineState:scatterPSO]; [computeEncoder setBuffer:sourceBuffer offset:src.storage_offset() * src.element_size() atIndex:0]; [computeEncoder setBuffer:outputBuffer offset:outputStorageOffset atIndex:1]; [computeEncoder setBytes:&output_sizes[0] length:sizeof(uint32_t) * kernel_size atIndex:2]; @@ -934,16 +935,21 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) { } // namespace mps // implementation of as_strided() op -Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, IntArrayRef stride, c10::optional storage_offset_) { +Tensor as_strided_tensorimpl_mps(const Tensor& self, + IntArrayRef size, + IntArrayRef stride, + c10::optional storage_offset_) { auto storage_offset = storage_offset_.value_or(self.storage_offset()); - auto result = detail::make_tensor(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype()); + auto result = + detail::make_tensor(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype()); setStrided(result, size, stride, storage_offset); // creating the view graph will be deferred until gatherViewTensor() or scatterViewTensor() are called. // In as_strided, we just update the base shape of the buffer in order to retrieve it later // when we create/run the view graph. IntArrayRef base_shape = mps::updateTensorBaseShape(self); - TORCH_INTERNAL_ASSERT(base_shape.size() > 0, "Failed to update the base shape of tensor's buffer at ", self.storage().data()); + TORCH_INTERNAL_ASSERT( + base_shape.size() > 0, "Failed to update the base shape of tensor's buffer at ", self.storage().data()); return result; } diff --git a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.h b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.h index d488f5b6e71dfe..23c8e5a7bae9bd 100644 --- a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.h +++ b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.h @@ -10,8 +10,7 @@ NS_ASSUME_NONNULL_BEGIN + (NSString*)cacheDirectory; -+ (BOOL)compileModel:(const std::string&)modelSpecs - modelID:(const std::string&)modelID; ++ (BOOL)compileModel:(const std::string&)modelSpecs modelID:(const std::string&)modelID; + (nullable MLModel*)loadModel:(const std::string&)modelID backend:(const std::string&)backend