Skip to content

Commit

Permalink
Add fp8 and int4 types in supported list for Onnxruntime EP
Browse files Browse the repository at this point in the history
  • Loading branch information
TedThemistokleous committed Nov 30, 2024
1 parent 061c493 commit 531f35c
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ static bool IsTypeSupported(const NodeArg* node_arg) {
switch (type_proto->tensor_type().elem_type()) {
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16:
Expand Down Expand Up @@ -261,6 +265,18 @@ static bool getMIGraphXType(ONNXTensorElementDataType type,
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
mgx_type = migraphx_shape_double_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ:
mgx_type = migraphx_shape_fp8e4m3fnuz_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:
mgx_type = migraphx_shape_fp8e4m3fn_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2:
mgx_type = migraphx_shape_fp8e5m2_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ:
mgx_type = migraphx_shape_fp8e5m2fnuz_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
mgx_type = migraphx_shape_int8_type;
break;
Expand Down

0 comments on commit 531f35c

Please sign in to comment.