diff --git a/xla/stream_executor/device_description.h b/xla/stream_executor/device_description.h index 3b5267bf305d8..3cdd311893c2b 100644 --- a/xla/stream_executor/device_description.h +++ b/xla/stream_executor/device_description.h @@ -204,16 +204,19 @@ class RocmComputeCapability { return absl::c_count(kList, gfx_version()) != 0; } - bool navi21() const { return gfx_version() == "gfx1030"; } + bool gfx10_rx68xx() const { return gfx_version() == "gfx1030"; } - bool navi31() const { return gfx_version() == "gfx1100"; } + bool gfx10_rx69xx() const { return gfx_version() == "gfx1030"; } + + bool gfx11_rx7900() const { return gfx_version() == "gfx1100"; } bool has_nhwc_layout_support() const { return gfx9_mi100_or_later(); } bool has_bf16_dtype_support() const { return gfx9_mi100_or_later(); } bool has_fast_fp16_support() const { - return gfx9_mi100_or_later() || navi21() || navi31(); + return gfx9_mi100_or_later() || gfx10_rx68xx() || gfx10_rx69xx() || + gfx11_rx7900(); } bool has_mfma_instr_support() const { return gfx9_mi100_or_later(); } @@ -250,8 +253,8 @@ class RocmComputeCapability { "gfx908", // MI100 "gfx90a", // MI200 "gfx940", "gfx941", "gfx942", // MI300 - "gfx1030", // Navi21 - "gfx1100" // Navi31 + "gfx1030", // RX68xx / RX69xx + "gfx1100" // RX7900 }; }; diff --git a/xla/stream_executor/rocm/rocm_driver.cc b/xla/stream_executor/rocm/rocm_driver.cc index cb21af93e6da5..c822daf40a462 100644 --- a/xla/stream_executor/rocm/rocm_driver.cc +++ b/xla/stream_executor/rocm/rocm_driver.cc @@ -2014,12 +2014,16 @@ static absl::StatusOr GetSimpleAttribute(hipDevice_t device, const uint64_t RESERVED_GFX908 = 1048576 * 512; const uint64_t RESERVED_GFX9_X = 1048576 * 1024; const uint64_t RESERVED_GFX10_X = 1048576 * 512; - if (compute_capability.gfx_version() == "gfx908") { + const uint64_t RESERVED_GFX11_X = 1048576 * 512; + if (compute_capability.gfx9_mi100()) { *reserve = RESERVED_GFX908; } else if (compute_capability.gfx9_mi200_or_later()) { *reserve = RESERVED_GFX9_X; - } else if (compute_capability.navi21() || compute_capability.navi31()) { + } else if (compute_capability.gfx10_rx68xx() || + compute_capability.gfx10_rx69xx()) { *reserve = RESERVED_GFX10_X; + } else if (compute_capability.gfx11_rx7900()) { + *reserve = RESERVED_GFX11_X; } return true;