Skip to content

Commit

Permalink
Add fp8 and int4 types in supported list for Onnxruntime EP
Browse files Browse the repository at this point in the history
  • Loading branch information
TedThemistokleous committed Nov 29, 2024
1 parent 061c493 commit c39eb4a
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,17 @@ static bool IsTypeSupported(const NodeArg* node_arg) {
switch (type_proto->tensor_type().elem_type()) {
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32:
Expand All @@ -261,6 +267,21 @@ static bool getMIGraphXType(ONNXTensorElementDataType type,
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
mgx_type = migraphx_shape_double_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ:
mgx_type = migraphx_shape_fp8e4m3fnuz_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:
mgx_type = migraphx_shape_fp8e4m3fn_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2:
mgx_type = migraphx_shape_fp8e5m2_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ:
mgx_type = migraphx_shape_fp8e5m2nuz_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
mgx_type = migraphx_shape_int4_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
mgx_type = migraphx_shape_int8_type;
break;
Expand All @@ -273,6 +294,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type,
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
mgx_type = migraphx_shape_int64_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4:
mgx_type = migraphx_shape_uint4_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
mgx_type = migraphx_shape_uint8_type;
break;
Expand Down

0 comments on commit c39eb4a

Please sign in to comment.