Skip to content

Commit

Permalink
Refactored Transpose operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Swopper050 committed Dec 3, 2024
1 parent e98fc1b commit 67a7796
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 176 deletions.
59 changes: 59 additions & 0 deletions ops/transpose/transpose.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package transpose

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var transposeTypeConstraint = [][]tensor.Dtype{ops.AllTypes}

// Transpose represents the ONNX transpose operator.
type Transpose struct {
ops.BaseOperator

perm []int
}

// newTranspose creates a new transpose operator.
func newTranspose(version int, typeConstraint [][]tensor.Dtype) *Transpose {
return &Transpose{
BaseOperator: ops.NewBaseOperator(
version,
1,
1,
typeConstraint,
"transpose",
),
}
}

// Init initializes the transpose operator.
func (t *Transpose) Init(n *onnx.NodeProto) error {
attributes := n.GetAttribute()

if len(attributes) == 1 {
attr := attributes[0]

if attr.GetName() != "perm" {
return ops.ErrInvalidAttribute(attr.GetName(), t)
}

attrPerm := attr.GetInts()
for _, val := range attrPerm {
t.perm = append(t.perm, int(val))
}
}

return nil
}

// Apply applies the transpose operator.
func (t *Transpose) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
out, err := tensor.Transpose(inputs[0], t.perm...)
if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}
80 changes: 0 additions & 80 deletions ops/transpose/transpose_1.go

This file was deleted.

78 changes: 0 additions & 78 deletions ops/transpose/transpose_13.go

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,37 @@ import (
"gorgonia.org/tensor"
)

func TestTranspose13Init(t *testing.T) {
trans := &Transpose13{}
err := trans.Init(Transpose13OnnxNodeProtoFixture())
func TestTransposeInit(t *testing.T) {
trans := &Transpose{}
err := trans.Init(TransposeOnnxNodeProtoFixture())

assert.Nil(t, err)
assert.Equal(t, []int{1, 0}, trans.perm)
}

func TestTranspose13InitFailWrongAttribute(t *testing.T) {
trans := &Transpose13{}
func TestTransposeInitFailWrongAttribute(t *testing.T) {
trans := &Transpose{}
err := trans.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "unknownAttribute"}}})

expected := ops.ErrInvalidAttribute("unknownAttribute", trans)
assert.Equal(t, expected, err)
}

func TestTranspose13(t *testing.T) {
func TestTranspose(t *testing.T) {
tests := []struct {
trans *Transpose13
trans *Transpose
shape []int
expectedShape tensor.Shape
expectedBacking []float32
}{
{
&Transpose13{},
&Transpose{},
[]int{3, 2},
[]int{2, 3},
[]float32{0, 2, 4, 1, 3, 5},
},
{
&Transpose13{perm: []int{0, 2, 1}},
&Transpose{perm: []int{0, 2, 1}},
[]int{1, 2, 3},
[]int{1, 3, 2},
[]float32{0, 3, 1, 4, 2, 5},
Expand All @@ -57,35 +57,56 @@ func TestTranspose13(t *testing.T) {
}
}

func TestInputValidationTranspose13(t *testing.T) {
func TestInputValidationTranspose(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
version int64
inputs []tensor.Tensor
err error
}{
{
1,
[]tensor.Tensor{ops.TensorWithBackingFixture([]uint32{1, 2}, 2)},
nil,
},
{
13,
[]tensor.Tensor{ops.TensorWithBackingFixture([]uint32{1, 2}, 2)},
nil,
},
{
13,
[]tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)},
nil,
},
{
13,
[]tensor.Tensor{ops.TensorWithBackingFixture([]float64{1, 2}, 2)},
nil,
},
{
1,
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Transpose13{}),
ops.ErrInvalidInputCount(0, transpose1BaseOpFixture()),
},
{
13,
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, transpose13BaseOpFixture()),
},
{
1,
[]tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)},
ops.ErrInvalidInputType(0, "int", transpose1BaseOpFixture()),
},
{
13,
[]tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)},
ops.ErrInvalidInputType(0, "int", &Transpose13{}),
ops.ErrInvalidInputType(0, "int", transpose13BaseOpFixture()),
},
}

for _, test := range tests {
transpose := &Transpose13{}
transpose := TransposeVersions[test.version]()
validated, err := transpose.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)
Expand All @@ -96,10 +117,20 @@ func TestInputValidationTranspose13(t *testing.T) {
}
}

func Transpose13OnnxNodeProtoFixture() *onnx.NodeProto {
func TransposeOnnxNodeProtoFixture() *onnx.NodeProto {
return &onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "perm", Ints: []int64{1, 0}},
},
}
}

func transpose1BaseOpFixture() ops.BaseOperator {
return ops.NewBaseOperator(1, 1, 1, transposeTypeConstraint, "transpose")

}

func transpose13BaseOpFixture() ops.BaseOperator {
return ops.NewBaseOperator(13, 1, 1, transposeTypeConstraint, "transpose")

}
4 changes: 2 additions & 2 deletions ops/transpose/versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ package transpose
import "github.com/advancedclimatesystems/gonnx/ops"

var TransposeVersions = ops.OperatorVersions{
1: newTranspose1, // Only bfloat16 type differs
13: newTranspose13,
1: ops.NewOperatorConstructor(newTranspose(1, transposeTypeConstraint)),
13: ops.NewOperatorConstructor(newTranspose(13, transposeTypeConstraint)),
}

0 comments on commit 67a7796

Please sign in to comment.