diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp
index ba5c7a28cc91e2..cb07ef335e56f0 100644
--- a/aten/src/ATen/native/cudnn/Conv_v8.cpp
+++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp
@@ -45,7 +45,7 @@ namespace at { namespace native {
 namespace {
 
 // TODO: remove duplicate code in Conv_v7.cpp
-constexpr size_t operator "" _TiB(unsigned long long n) {
+constexpr int64_t operator "" _TiB(unsigned long long n) {
   return size_t(n) << 40;
 }
 
@@ -323,12 +323,12 @@ auto get_generator_sources(const cudnnBackendDescriptorType_t& desc, const Tenso
   }
 }
 
-size_t get_available_workspace() {
+int64_t get_available_workspace() {
   int device;
   C10_CUDA_CHECK(cudaGetDevice(&device));
   size_t max_block_size = 0;
   c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
-  return max_block_size;
+  return static_cast<int64_t>(max_block_size);
 }
 
 static nlohmann::json errata_json_handle;
@@ -347,10 +347,10 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera
     return plan_errata_exception(handle, plan.getTag());
   };
   auto plans = generator.cudnnGetPlan(handle, opGraph, initial_predicate_function);
-  size_t max_block_size = get_available_workspace();
-  size_t max_workspace_size = 0u;
+  int64_t max_block_size = get_available_workspace();
+  int64_t max_workspace_size = 0;
   std::for_each(plans.begin(), plans.end(), [&] (cudnn_frontend::ExecutionPlan& plan) {
-    size_t curr_workspace_size = plan.getWorkspaceSize();
+    int64_t curr_workspace_size = plan.getWorkspaceSize();
     if (curr_workspace_size <= max_block_size) {
       if (curr_workspace_size > max_workspace_size) {
         max_workspace_size = plan.getWorkspaceSize();
@@ -373,7 +373,7 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera
   if (remove_invalid) {
     cudnn_frontend::executionPlans_t new_valid_plans;
     for (auto &plan : valid_plans) {
-      if (static_cast<size_t>(plan.getWorkspaceSize()) <= max_workspace_size) {
+      if (plan.getWorkspaceSize() <= max_workspace_size) {
         new_valid_plans.emplace_back(std::move(plan));
       }
     }
diff --git a/aten/src/ATen/native/cudnn/Macros.h b/aten/src/ATen/native/cudnn/Macros.h
index fdc6552432818c..d38a6c6c695fb6 100644
--- a/aten/src/ATen/native/cudnn/Macros.h
+++ b/aten/src/ATen/native/cudnn/Macros.h
@@ -5,7 +5,7 @@
 // Note: The version below should not actually be 8000. Instead, it should
 // be whatever version of cuDNN that v8 API work with PyTorch correctly.
 // The version is set to 8000 today for convenience of debugging.
-#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
+#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8300
 #define HAS_CUDNN_V8() true
 #else
 #define HAS_CUDNN_V8() false
diff --git a/defs.bzl b/defs.bzl
index 00cf0fa8f06109..e8838b8aa2cbb8 100644
--- a/defs.bzl
+++ b/defs.bzl
@@ -34,6 +34,7 @@ default_compiler_flags = [
     "-DTH_INDEX_BASE=0",
     "-DMAGMA_V2",
     "-DNO_CUDNN_DESTROY_HANDLE",
+    "-DUSE_EXPERIMENTAL_CUDNN_V8_API",  # enable cudnn v8 api
     "-DUSE_FBGEMM",
     "-DUSE_QNNPACK",
     "-DUSE_PYTORCH_QNNPACK",