-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Swopper050 <[email protected]>
- Loading branch information
1 parent
c02923d
commit fb70feb
Showing
5 changed files
with
272 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
package opset13 | ||
|
||
import ( | ||
"github.com/advancedclimatesystems/gonnx/onnx" | ||
"github.com/advancedclimatesystems/gonnx/ops" | ||
"gorgonia.org/tensor" | ||
) | ||
|
||
const ( | ||
MinFlattenInputs = 1 | ||
MaxFlattenInputs = 1 | ||
) | ||
|
||
// Flatten represents the ONNX flatten operator. | ||
type Flatten struct { | ||
axis int | ||
} | ||
|
||
// newFlatten creates a new flatten operator. | ||
func newFlatten() ops.Operator { | ||
return &Flatten{ | ||
axis: 1, | ||
} | ||
} | ||
|
||
// Init initializes the flatten operator. | ||
func (f *Flatten) Init(n *onnx.NodeProto) error { | ||
for _, attr := range n.GetAttribute() { | ||
switch attr.GetName() { | ||
case "axis": | ||
f.axis = int(attr.GetI()) | ||
default: | ||
return ops.ErrInvalidAttribute(attr.GetName(), f) | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// Apply applies the flatten operator. | ||
func (f *Flatten) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { | ||
inputShape := inputs[0].Shape() | ||
rank := len(inputShape) | ||
|
||
axis := f.axis | ||
if axis < 0 { | ||
axis = rank + axis | ||
} | ||
|
||
out, ok := inputs[0].Clone().(tensor.Tensor) | ||
if !ok { | ||
return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) | ||
} | ||
|
||
var err error | ||
// In the special case where axis is 0, we reshape the tensor to shape | ||
// (1, <n_elements>). This is ONNX defined behaviour. | ||
if axis == 0 { | ||
err = out.Reshape(1, ops.NElements(inputShape...)) | ||
} else { | ||
err = out.Reshape(ops.NElements(inputShape[:axis]...), ops.NElements(inputShape[axis:]...)) | ||
} | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return []tensor.Tensor{out}, nil | ||
} | ||
|
||
// ValidateInputs validates the inputs that will be given to Apply for this operator. | ||
func (f *Flatten) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { | ||
return ops.ValidateInputs(f, inputs) | ||
} | ||
|
||
// GetMinInputs returns the minimum number of input tensors this operator expects. | ||
func (f *Flatten) GetMinInputs() int { | ||
return MinFlattenInputs | ||
} | ||
|
||
// GetMaxInputs returns the maximum number of input tensors this operator expects. | ||
func (f *Flatten) GetMaxInputs() int { | ||
return MaxFlattenInputs | ||
} | ||
|
||
// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes | ||
// for the corresponding input tensor. | ||
func (f *Flatten) GetInputTypeConstraints() [][]tensor.Dtype { | ||
return [][]tensor.Dtype{ops.AllTypes} | ||
} | ||
|
||
// String implements the stringer interface, and can be used to format errors or messages. | ||
func (f *Flatten) String() string { | ||
return "flatten operator" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
package opset13 | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/advancedclimatesystems/gonnx/onnx" | ||
"github.com/advancedclimatesystems/gonnx/ops" | ||
"github.com/stretchr/testify/assert" | ||
"gorgonia.org/tensor" | ||
) | ||
|
||
func TestFlattenInit(t *testing.T) { | ||
f := &Flatten{} | ||
|
||
err := f.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis", I: 2}}}) | ||
assert.Nil(t, err) | ||
|
||
assert.Equal(t, 2, f.axis) | ||
} | ||
|
||
func TestFlatten(t *testing.T) { | ||
tests := []struct { | ||
flatten *Flatten | ||
backing []float32 | ||
shape []int | ||
expectedShape tensor.Shape | ||
}{ | ||
{ | ||
&Flatten{}, | ||
[]float32{0, 1, 2, 3}, | ||
[]int{2, 2}, | ||
[]int{1, 4}, | ||
}, | ||
{ | ||
&Flatten{}, | ||
[]float32{0, 1, 2, 3, 4, 5}, | ||
[]int{2, 3}, | ||
[]int{1, 6}, | ||
}, | ||
{ | ||
&Flatten{axis: 1}, | ||
[]float32{0, 1, 2, 3, 4, 5, 6, 7}, | ||
[]int{2, 2, 2}, | ||
[]int{2, 4}, | ||
}, | ||
{ | ||
&Flatten{axis: 2}, | ||
[]float32{0, 1, 2, 3, 4, 5, 6, 7}, | ||
[]int{2, 2, 2}, | ||
[]int{4, 2}, | ||
}, | ||
{ | ||
&Flatten{axis: -1}, | ||
[]float32{0, 1, 2, 3, 4, 5, 6, 7}, | ||
[]int{2, 2, 2}, | ||
[]int{4, 2}, | ||
}, | ||
{ | ||
&Flatten{axis: -2}, | ||
[]float32{0, 1, 2, 3, 4, 5, 6, 7}, | ||
[]int{2, 2, 2}, | ||
[]int{2, 4}, | ||
}, | ||
{ | ||
&Flatten{axis: -3}, | ||
[]float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, | ||
[]int{3, 2, 3}, | ||
[]int{1, 18}, | ||
}, | ||
{ | ||
&Flatten{axis: 2}, | ||
[]float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, | ||
[]int{3, 2, 3}, | ||
[]int{6, 3}, | ||
}, | ||
{ | ||
&Flatten{axis: 1}, | ||
[]float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, | ||
[]int{3, 2, 3}, | ||
[]int{3, 6}, | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
inputs := []tensor.Tensor{ | ||
ops.TensorWithBackingFixture(test.backing, test.shape...), | ||
} | ||
|
||
res, err := test.flatten.Apply(inputs) | ||
assert.Nil(t, err) | ||
|
||
assert.Equal(t, test.expectedShape, res[0].Shape()) | ||
} | ||
} | ||
|
||
func TestInputValidationFlatten(t *testing.T) { | ||
tests := []struct { | ||
inputs []tensor.Tensor | ||
err error | ||
}{ | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]uint32{1, 2}, 2), | ||
}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]uint64{1, 2}, 2), | ||
}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]int32{1, 2}, 2), | ||
}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]int64{1, 2}, 2), | ||
}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]float32{1, 2}, 2), | ||
}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]float64{1, 2}, 2), | ||
}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]float32{1, 2}, 2), | ||
ops.TensorWithBackingFixture([]float32{1, 2}, 2), | ||
}, | ||
ops.ErrInvalidInputCount(2, &Flatten{}), | ||
}, | ||
{ | ||
[]tensor.Tensor{ | ||
ops.TensorWithBackingFixture([]int{1, 2}, 2), | ||
}, | ||
ops.ErrInvalidInputType(0, "int", &Flatten{}), | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
flatten := &Flatten{} | ||
validated, err := flatten.ValidateInputs(test.inputs) | ||
|
||
assert.Equal(t, test.err, err) | ||
|
||
if test.err == nil { | ||
assert.Equal(t, test.inputs, validated) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters