-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
de51332
commit 5406257
Showing
31 changed files
with
1,820 additions
and
364 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package constant | ||
|
||
import ( | ||
"github.com/advancedclimatesystems/gonnx/onnx" | ||
"github.com/advancedclimatesystems/gonnx/ops" | ||
"gorgonia.org/tensor" | ||
) | ||
|
||
// Constant1 represents the ONNX constant operator. | ||
type Constant1 struct { | ||
value tensor.Tensor | ||
} | ||
|
||
// newConstant1 creates a new constant operator. | ||
func NewConstant1() ops.Operator { | ||
return &Constant1{} | ||
} | ||
|
||
// Init initializes the constant operator. | ||
func (c *Constant1) Init(n *onnx.NodeProto) error { | ||
attributes := n.GetAttribute() | ||
if len(attributes) != 1 { | ||
return ops.ErrInvalidAttributeCount(1, len(attributes), c) | ||
} | ||
|
||
attr := attributes[0] | ||
|
||
switch attr.GetName() { | ||
case "value": | ||
t, err := onnx.TensorFromProto(attr.GetT()) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
c.value = t | ||
default: | ||
return ops.ErrUnsupportedAttribute(attr.GetName(), c) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// Apply applies the constant operator. | ||
func (c *Constant1) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { | ||
return []tensor.Tensor{c.value}, nil | ||
} | ||
|
||
// ValidateInputs validates the inputs that will be given to Apply for this operator. | ||
func (c *Constant1) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { | ||
return ops.ValidateInputs(c, inputs) | ||
} | ||
|
||
// GetMinInputs returns the minimum number of input tensors this operator expects. | ||
func (c *Constant1) GetMinInputs() int { | ||
return 0 | ||
} | ||
|
||
// GetMaxInputs returns the maximum number of input tensors this operator expects. | ||
func (c *Constant1) GetMaxInputs() int { | ||
return 0 | ||
} | ||
|
||
// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes | ||
// for the corresponding input tensor. | ||
func (c *Constant1) GetInputTypeConstraints() [][]tensor.Dtype { | ||
return [][]tensor.Dtype{} | ||
} | ||
|
||
// String implements the stringer interface, and can be used to format errors or messages. | ||
func (c *Constant1) String() string { | ||
return "constant1 operator" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package constant | ||
|
||
import ( | ||
"github.com/advancedclimatesystems/gonnx/onnx" | ||
"github.com/advancedclimatesystems/gonnx/ops" | ||
"gorgonia.org/tensor" | ||
) | ||
|
||
// Constant11 represents the ONNX constant operator. | ||
type Constant11 struct { | ||
value tensor.Tensor | ||
} | ||
|
||
// newConstant11 creates a new constant operator. | ||
func NewConstant11() ops.Operator { | ||
return &Constant11{} | ||
} | ||
|
||
// Init initializes the constant operator. It supports all constant types except | ||
// `sparse_value`. | ||
func (c *Constant11) Init(n *onnx.NodeProto) error { | ||
attributes := n.GetAttribute() | ||
if len(attributes) != 1 { | ||
return ops.ErrInvalidAttributeCount(1, len(attributes), c) | ||
} | ||
|
||
attr := attributes[0] | ||
|
||
switch attr.GetName() { | ||
case "sparse_value": | ||
return ops.ErrUnsupportedAttribute(attr.GetName(), c) | ||
case "value": | ||
t, err := onnx.TensorFromProto(attr.GetT()) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
c.value = t | ||
default: | ||
return ops.ErrUnsupportedAttribute(attr.GetName(), c) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// Apply applies the constant operator. | ||
func (c *Constant11) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { | ||
return []tensor.Tensor{c.value}, nil | ||
} | ||
|
||
// ValidateInputs validates the inputs that will be given to Apply for this operator. | ||
func (c *Constant11) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { | ||
return ops.ValidateInputs(c, inputs) | ||
} | ||
|
||
// GetMinInputs returns the minimum number of input tensors this operator expects. | ||
func (c *Constant11) GetMinInputs() int { | ||
return 0 | ||
} | ||
|
||
// GetMaxInputs returns the maximum number of input tensors this operator expects. | ||
func (c *Constant11) GetMaxInputs() int { | ||
return 0 | ||
} | ||
|
||
// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes | ||
// for the corresponding input tensor. | ||
func (c *Constant11) GetInputTypeConstraints() [][]tensor.Dtype { | ||
return [][]tensor.Dtype{} | ||
} | ||
|
||
// String implements the stringer interface, and can be used to format errors or messages. | ||
func (c *Constant11) String() string { | ||
return "constant11 operator" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
package constant | ||
|
||
import ( | ||
"github.com/advancedclimatesystems/gonnx/onnx" | ||
"github.com/advancedclimatesystems/gonnx/ops" | ||
"gorgonia.org/tensor" | ||
) | ||
|
||
// Constant13 represents the ONNX constant operator. | ||
type Constant13 struct { | ||
value tensor.Tensor | ||
} | ||
|
||
// newConstant13 creates a new constant operator. | ||
func NewConstant13() ops.Operator { | ||
return &Constant13{} | ||
} | ||
|
||
// Init initializes the constant operator. It supports all constant types except | ||
// `sparse_value`, `value_string`, and `value_strings`. | ||
func (c *Constant13) Init(n *onnx.NodeProto) error { | ||
attributes := n.GetAttribute() | ||
if len(attributes) != 1 { | ||
return ops.ErrInvalidAttributeCount(1, len(attributes), c) | ||
} | ||
|
||
attr := attributes[0] | ||
|
||
switch attr.GetName() { | ||
case "sparse_value", "value_string", "value_strings": | ||
return ops.ErrUnsupportedAttribute(attr.GetName(), c) | ||
case "value": | ||
t, err := onnx.TensorFromProto(attr.GetT()) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
c.value = t | ||
case "value_float": | ||
c.value = tensor.New(tensor.FromScalar(attr.GetF())) | ||
case "value_floats": | ||
floats := attr.GetFloats() | ||
c.value = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) | ||
case "value_int": | ||
c.value = tensor.New(tensor.FromScalar(attr.GetI())) | ||
case "value_ints": | ||
ints := attr.GetInts() | ||
c.value = tensor.New(tensor.WithShape(len(ints)), tensor.WithBacking(ints)) | ||
default: | ||
return ops.ErrUnsupportedAttribute(attr.GetName(), c) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// Apply applies the constant operator. | ||
func (c *Constant13) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { | ||
return []tensor.Tensor{c.value}, nil | ||
} | ||
|
||
// ValidateInputs validates the inputs that will be given to Apply for this operator. | ||
func (c *Constant13) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { | ||
return ops.ValidateInputs(c, inputs) | ||
} | ||
|
||
// GetMinInputs returns the minimum number of input tensors this operator expects. | ||
func (c *Constant13) GetMinInputs() int { | ||
return 0 | ||
} | ||
|
||
// GetMaxInputs returns the maximum number of input tensors this operator expects. | ||
func (c *Constant13) GetMaxInputs() int { | ||
return 0 | ||
} | ||
|
||
// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes | ||
// for the corresponding input tensor. | ||
func (c *Constant13) GetInputTypeConstraints() [][]tensor.Dtype { | ||
return [][]tensor.Dtype{} | ||
} | ||
|
||
// String implements the stringer interface, and can be used to format errors or messages. | ||
func (c *Constant13) String() string { | ||
return "constant13 operator" | ||
} |
Oops, something went wrong.