Skip to content

Commit

Permalink
implement gridsample 20 (microsoft#17744)
Browse files Browse the repository at this point in the history
  • Loading branch information
liqunfu authored Nov 7, 2023
1 parent 83c0275 commit 6127dd1
Show file tree
Hide file tree
Showing 9 changed files with 1,383 additions and 158 deletions.
3 changes: 2 additions & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ Do not modify directly.*
|||[7, 8]|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(bool)|
|GreaterOrEqual|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T1** = tensor(bool)|
|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T1** = tensor(bool)|
|GridSample|*in* X:**T1**<br> *in* grid:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GridSample|*in* X:**T1**<br> *in* grid:**T2**<br> *out* Y:**T1**|20+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(double), tensor(float)|
|||[16, 19]|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|HammingWindow|*in* size:**T1**<br> *out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)<br/> **T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|HannWindow|*in* size:**T1**<br> *out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)<br/> **T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|HardSigmoid|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(float)|
Expand Down
10 changes: 7 additions & 3 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 18, If);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, RoiAlign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, RoiAlign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, GridSample);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 19, float, GridSample);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, string, Where);
Expand Down Expand Up @@ -960,6 +960,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh

// Opset 20
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, AffineGrid);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN);
Expand Down Expand Up @@ -2183,8 +2185,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double,
RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float,
GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 19, float,
GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, string, Where)>,
Expand Down Expand Up @@ -2401,6 +2403,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {

// Opset 20
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, AffineGrid)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN)>,
Expand Down
Loading

0 comments on commit 6127dd1

Please sign in to comment.