Skip to content

Commit

Permalink
WIP on operator migration
Browse files Browse the repository at this point in the history
  • Loading branch information
Swopper050 committed Dec 1, 2024
1 parent de51332 commit 5406257
Show file tree
Hide file tree
Showing 31 changed files with 1,820 additions and 364 deletions.
72 changes: 72 additions & 0 deletions ops/constant/constant_1.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package constant

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Constant1 represents the ONNX constant operator.
type Constant1 struct {
value tensor.Tensor
}

// newConstant1 creates a new constant operator.
func NewConstant1() ops.Operator {
return &Constant1{}
}

// Init initializes the constant operator.
func (c *Constant1) Init(n *onnx.NodeProto) error {
attributes := n.GetAttribute()
if len(attributes) != 1 {
return ops.ErrInvalidAttributeCount(1, len(attributes), c)
}

attr := attributes[0]

switch attr.GetName() {
case "value":
t, err := onnx.TensorFromProto(attr.GetT())
if err != nil {
return err
}

c.value = t
default:
return ops.ErrUnsupportedAttribute(attr.GetName(), c)
}

return nil
}

// Apply applies the constant operator.
func (c *Constant1) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) {
return []tensor.Tensor{c.value}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (c *Constant1) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(c, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (c *Constant1) GetMinInputs() int {
return 0
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (c *Constant1) GetMaxInputs() int {
return 0
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (c *Constant1) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (c *Constant1) String() string {
return "constant1 operator"
}
75 changes: 75 additions & 0 deletions ops/constant/constant_11.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package constant

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Constant11 represents the ONNX constant operator.
type Constant11 struct {
value tensor.Tensor
}

// newConstant11 creates a new constant operator.
func NewConstant11() ops.Operator {
return &Constant11{}
}

// Init initializes the constant operator. It supports all constant types except
// `sparse_value`.
func (c *Constant11) Init(n *onnx.NodeProto) error {
attributes := n.GetAttribute()
if len(attributes) != 1 {
return ops.ErrInvalidAttributeCount(1, len(attributes), c)
}

attr := attributes[0]

switch attr.GetName() {
case "sparse_value":
return ops.ErrUnsupportedAttribute(attr.GetName(), c)
case "value":
t, err := onnx.TensorFromProto(attr.GetT())
if err != nil {
return err
}

c.value = t
default:
return ops.ErrUnsupportedAttribute(attr.GetName(), c)
}

return nil
}

// Apply applies the constant operator.
func (c *Constant11) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) {
return []tensor.Tensor{c.value}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (c *Constant11) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(c, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (c *Constant11) GetMinInputs() int {
return 0
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (c *Constant11) GetMaxInputs() int {
return 0
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (c *Constant11) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (c *Constant11) String() string {
return "constant11 operator"
}
28 changes: 14 additions & 14 deletions ops/opset13/constant.go → ops/constant/constant_12.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
package opset13
package constant

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Constant represents the ONNX constant operator.
type Constant struct {
// Constant12 represents the ONNX constant operator.
type Constant12 struct {
value tensor.Tensor
}

// newConstant creates a new constant operator.
func newConstant() ops.Operator {
return &Constant{}
// newConstant12 creates a new constant operator.
func NewConstant12() ops.Operator {
return &Constant12{}
}

// Init initializes the constant operator. It supports all constant types except
// `sparse_value`, `value_string`, and `value_strings`.
func (c *Constant) Init(n *onnx.NodeProto) error {
func (c *Constant12) Init(n *onnx.NodeProto) error {
attributes := n.GetAttribute()
if len(attributes) != 1 {
return ops.ErrInvalidAttributeCount(1, len(attributes), c)
Expand Down Expand Up @@ -54,32 +54,32 @@ func (c *Constant) Init(n *onnx.NodeProto) error {
}

// Apply applies the constant operator.
func (c *Constant) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) {
func (c *Constant12) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) {
return []tensor.Tensor{c.value}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (c *Constant) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
func (c *Constant12) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(c, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (c *Constant) GetMinInputs() int {
func (c *Constant12) GetMinInputs() int {
return 0
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (c *Constant) GetMaxInputs() int {
func (c *Constant12) GetMaxInputs() int {
return 0
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (c *Constant) GetInputTypeConstraints() [][]tensor.Dtype {
func (c *Constant12) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (c *Constant) String() string {
return "constant operator"
func (c *Constant12) String() string {
return "constant12 operator"
}
85 changes: 85 additions & 0 deletions ops/constant/constant_13.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package constant

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Constant13 represents the ONNX constant operator.
type Constant13 struct {
value tensor.Tensor
}

// newConstant13 creates a new constant operator.
func NewConstant13() ops.Operator {
return &Constant13{}
}

// Init initializes the constant operator. It supports all constant types except
// `sparse_value`, `value_string`, and `value_strings`.
func (c *Constant13) Init(n *onnx.NodeProto) error {
attributes := n.GetAttribute()
if len(attributes) != 1 {
return ops.ErrInvalidAttributeCount(1, len(attributes), c)
}

attr := attributes[0]

switch attr.GetName() {
case "sparse_value", "value_string", "value_strings":
return ops.ErrUnsupportedAttribute(attr.GetName(), c)
case "value":
t, err := onnx.TensorFromProto(attr.GetT())
if err != nil {
return err
}

c.value = t
case "value_float":
c.value = tensor.New(tensor.FromScalar(attr.GetF()))
case "value_floats":
floats := attr.GetFloats()
c.value = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats))
case "value_int":
c.value = tensor.New(tensor.FromScalar(attr.GetI()))
case "value_ints":
ints := attr.GetInts()
c.value = tensor.New(tensor.WithShape(len(ints)), tensor.WithBacking(ints))
default:
return ops.ErrUnsupportedAttribute(attr.GetName(), c)
}

return nil
}

// Apply applies the constant operator.
func (c *Constant13) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) {
return []tensor.Tensor{c.value}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (c *Constant13) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(c, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (c *Constant13) GetMinInputs() int {
return 0
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (c *Constant13) GetMaxInputs() int {
return 0
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (c *Constant13) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (c *Constant13) String() string {
return "constant13 operator"
}
Loading

0 comments on commit 5406257

Please sign in to comment.