diff --git a/xla/stream_executor/device_description.h b/xla/stream_executor/device_description.h index 195dd058aa64d..ae8798a5d2c4f 100644 --- a/xla/stream_executor/device_description.h +++ b/xla/stream_executor/device_description.h @@ -202,6 +202,10 @@ class RocmComputeCapability { bool gfx11_rx7900() const { return gfx_version() == "gfx1100"; } + bool gfx12_rx8600() const { return gfx_version() == "gfx1200"; } + + bool gfx12_rx8800() const { return gfx_version() == "gfx1201"; } + bool has_nhwc_layout_support() const { return gfx9_mi100_or_later(); } bool has_bf16_dtype_support() const { return gfx9_mi100_or_later(); } @@ -229,7 +233,7 @@ class RocmComputeCapability { bool has_hipblaslt() const { return gfx9_mi200_or_later(); } - bool has_fp8_support() const { return gfx9_mi300(); } + bool has_fp8_support() const { return gfx9_mi300() || gfx12_rx8600() || gfx12_rx8800(); } RocmComputeCapabilityProto ToProto() const { RocmComputeCapabilityProto proto; @@ -251,7 +255,9 @@ class RocmComputeCapability { "gfx90a", // MI200 "gfx940", "gfx941", "gfx942", // MI300 "gfx1030", // RX68xx / RX69xx - "gfx1100" // RX7900 + "gfx1100", // RX7900 + "gfx1200", // RX8600 + "gfx1201", // RX8800 }; };