From 788300cc2aa096d8d5c1e7fbfc87e5439a338251 Mon Sep 17 00:00:00 2001
From: Xiaodong Wang
Date: Thu, 23 Mar 2023 01:41:04 +0000
Subject: [PATCH] [cudnn] Support v8 API in fbcode (#96512)

Summary: It turns out we never turn on the cuDNN v8 API, which blocks bf16 conv. Enable the new v8 API.

Test Plan: buck run mode/dev-nosan scripts/xdwang/example:fc_pytorch

Reviewed By: ngimel

Differential Revision: D43784279

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96512
Approved by: https://github.com/malfet
---
 aten/src/ATen/native/cudnn/Conv_v8.cpp | 14 +++++++-------
 aten/src/ATen/native/cudnn/Macros.h    |  2 +-
 defs.bzl                               |  1 +
 3 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp
index ba5c7a28cc91e..cb07ef335e56f 100644
--- a/aten/src/ATen/native/cudnn/Conv_v8.cpp
+++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp
@@ -45,7 +45,7 @@ namespace at { namespace native {
 namespace {
 
 // TODO: remove duplicate code in Conv_v7.cpp
-constexpr size_t operator "" _TiB(unsigned long long n) {
+constexpr int64_t operator "" _TiB(unsigned long long n) {
   return size_t(n) << 40;
 }
 
@@ -323,12 +323,12 @@ auto get_generator_sources(const cudnnBackendDescriptorType_t& desc, const Tenso
   }
 }
 
-size_t get_available_workspace() {
+int64_t get_available_workspace() {
   int device;
   C10_CUDA_CHECK(cudaGetDevice(&device));
   size_t max_block_size = 0;
   c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
-  return max_block_size;
+  return static_cast<int64_t>(max_block_size);
 }
 
 static nlohmann::json errata_json_handle;
@@ -347,10 +347,10 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera
     return plan_errata_exception(handle, plan.getTag());
   };
   auto plans = generator.cudnnGetPlan(handle, opGraph, initial_predicate_function);
-  size_t max_block_size = get_available_workspace();
-  size_t max_workspace_size = 0u;
+  int64_t max_block_size = get_available_workspace();
+  int64_t max_workspace_size = 0;
   std::for_each(plans.begin(), plans.end(), [&] (cudnn_frontend::ExecutionPlan& plan) {
-    size_t curr_workspace_size = plan.getWorkspaceSize();
+    int64_t curr_workspace_size = plan.getWorkspaceSize();
     if (curr_workspace_size <= max_block_size) {
       if (curr_workspace_size > max_workspace_size) {
         max_workspace_size = plan.getWorkspaceSize();
@@ -373,7 +373,7 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera
   if (remove_invalid) {
     cudnn_frontend::executionPlans_t new_valid_plans;
     for (auto &plan : valid_plans) {
-      if (static_cast<size_t>(plan.getWorkspaceSize()) <= max_workspace_size) {
+      if (plan.getWorkspaceSize() <= max_workspace_size) {
         new_valid_plans.emplace_back(std::move(plan));
       }
     }
diff --git a/aten/src/ATen/native/cudnn/Macros.h b/aten/src/ATen/native/cudnn/Macros.h
index fdc6552432818..d38a6c6c695fb 100644
--- a/aten/src/ATen/native/cudnn/Macros.h
+++ b/aten/src/ATen/native/cudnn/Macros.h
@@ -5,7 +5,7 @@
 // Note: The version below should not actually be 8000. Instead, it should
 // be whatever version of cuDNN that v8 API work with PyTorch correctly.
 // The version is set to 8000 today for convenience of debugging.
-#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
+#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8300
 #define HAS_CUDNN_V8() true
 #else
 #define HAS_CUDNN_V8() false
diff --git a/defs.bzl b/defs.bzl
index 00cf0fa8f0610..e8838b8aa2cbb 100644
--- a/defs.bzl
+++ b/defs.bzl
@@ -34,6 +34,7 @@ default_compiler_flags = [
     "-DTH_INDEX_BASE=0",
     "-DMAGMA_V2",
     "-DNO_CUDNN_DESTROY_HANDLE",
+    "-DUSE_EXPERIMENTAL_CUDNN_V8_API", # enable cudnn v8 api
     "-DUSE_FBGEMM",
     "-DUSE_QNNPACK",
    "-DUSE_PYTORCH_QNNPACK",
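
For context, a minimal sketch of how the HAS_CUDNN_V8() gate above is
typically consumed from C++; the dispatch function and the two entry
points named here are hypothetical, not part of this patch:

    #include <ATen/native/cudnn/Macros.h>  // defines HAS_CUDNN_V8()

    void conv_forward_dispatch() {
    #if HAS_CUDNN_V8()
      // Compiled only when -DUSE_EXPERIMENTAL_CUDNN_V8_API is passed (as
      // the defs.bzl change above now does in fbcode) and cuDNN is >= 8.3,
      // which lets the bf16-capable v8 frontend path be built.
      raw_cudnn_convolution_forward_v8();  // hypothetical v8 entry point
    #else
      // Older toolchains fall back to the legacy v7 path.
      raw_cudnn_convolution_forward_v7();  // hypothetical v7 entry point
    #endif
    }

Since HAS_CUDNN_V8() expands to the literal true or false, the unused
branch is discarded at preprocessing time rather than checked at runtime.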