Skip to content

Commit

Permalink
Refactored unsqueeze operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Swopper050 committed Dec 3, 2024
1 parent b5ce7d9 commit e98fc1b
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 136 deletions.
61 changes: 25 additions & 36 deletions ops/unsqueeze/unsqueeze_13.go → ops/unsqueeze/unsqueeze.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,41 @@ import (
"gorgonia.org/tensor"
)

var unsqueezeTypeConstraints = [][]tensor.Dtype{
ops.AllTypes,
{tensor.Int64},
}

const (
MinUnsqueeze13Inputs = 2
MaxUnsqueeze13Inputs = 2
MinUnsqueezeInputs = 2
MaxUnsqueezeInputs = 2
)

// Unsqueeze13 represents the ONNX unsqueeze operator.
type Unsqueeze13 struct{}
// Unsqueeze represents the ONNX unsqueeze operator.
type Unsqueeze struct {
ops.BaseOperator
}

// newUnsqueeze13 creates a new unsqueeze operator.
func newUnsqueeze13() ops.Operator {
return &Unsqueeze13{}
// newUnsqueeze creates a new unsqueeze operator.
func newUnsqueeze(version int, typeConstraint [][]tensor.Dtype) *Unsqueeze {
return &Unsqueeze{
BaseOperator: ops.NewBaseOperator(
version,
MinUnsqueezeInputs,
MaxUnsqueezeInputs,
typeConstraint,
"unsqueeze",
),
}
}

// Init initializes the unsqueeze operator.
func (u *Unsqueeze13) Init(*onnx.NodeProto) error {
func (u *Unsqueeze) Init(*onnx.NodeProto) error {
return nil
}

// Apply applies the unsqueeze operator.
func (u *Unsqueeze13) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
func (u *Unsqueeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
dataShape := inputs[0].Shape()

axes, err := ops.AnyToIntSlice(inputs[1].Data())
Expand All @@ -48,7 +63,7 @@ func (u *Unsqueeze13) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
sort.Ints(axes)

if ops.HasDuplicates(axes) {
return nil, ops.ErrInvalidInput("axes cannot have duplicate entries after offset", u)
return nil, ops.ErrInvalidInput("axes cannot have duplicate entries after offset", u.BaseOperator)
}

newShape := insertOnes(dataShape, axes)
Expand All @@ -63,32 +78,6 @@ func (u *Unsqueeze13) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return []tensor.Tensor{out}, err
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (u *Unsqueeze13) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(u, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (u *Unsqueeze13) GetMinInputs() int {
return MinUnsqueeze13Inputs
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (u *Unsqueeze13) GetMaxInputs() int {
return MaxUnsqueeze13Inputs
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (u *Unsqueeze13) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{ops.AllTypes, {tensor.Int64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (u *Unsqueeze13) String() string {
return "unsqueeze13 operator"
}

// Creates a new array, which is `original` with ones added at the indices specified by `indices`
// `indices` may not contain duplicates, the elements are assumed to be in the range 0 <= x < N
// and should be sorted in increasing order.
Expand Down
55 changes: 15 additions & 40 deletions ops/unsqueeze/unsqueeze_1.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,24 @@ import (
"gorgonia.org/tensor"
)

const (
MinUnsqueeze1Inputs = 2
MaxUnsqueeze1Inputs = 2
)

// Unsqueeze1 represents the ONNX unsqueeze operator.
// Unsqueeze1 represents version 1 of the ONNX unsqueeze operator.
type Unsqueeze1 struct {
ops.BaseOperator

axes []int
}

// newUnsqueeze1 creates a new unsqueeze operator.
func newUnsqueeze1() ops.Operator {
return &Unsqueeze1{}
func newUnsqueeze1() *Unsqueeze1 {
return &Unsqueeze1{
BaseOperator: ops.NewBaseOperator(
1,
1,
1,
[][]tensor.Dtype{ops.AllTypes},
"unsqueeze",
),
}
}

// Init initializes the unsqueeze operator.
Expand All @@ -46,18 +51,14 @@ func (u *Unsqueeze1) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {

outputRank := len(dataShape) + len(u.axes)

if !ops.AllInRange(u.axes, -outputRank, outputRank-1) {
if !ops.AllInRange(u.axes, 0, outputRank-1) {
return nil, ops.ErrNotAllAxesInRange(outputRank, outputRank)
}

// negative entries should be offset by the rank of the output tensor
// i.e. -1 -> outputRank - 1, -outputrank -> 0
ops.OffsetArrayIfNegative(u.axes, outputRank)

sort.Ints(u.axes)

if ops.HasDuplicates(u.axes) {
return nil, ops.ErrInvalidInput("axes cannot have duplicate entries after offset", u)
return nil, ops.ErrInvalidInput("axes cannot have duplicate entries after offset", u.BaseOperator)
}

newShape := insertOnes(dataShape, u.axes)
Expand All @@ -71,29 +72,3 @@ func (u *Unsqueeze1) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {

return []tensor.Tensor{out}, err
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (u *Unsqueeze1) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(u, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (u *Unsqueeze1) GetMinInputs() int {
return MinUnsqueeze1Inputs
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (u *Unsqueeze1) GetMaxInputs() int {
return MaxUnsqueeze1Inputs
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (u *Unsqueeze1) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{ops.AllTypes, {tensor.Int64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (u *Unsqueeze1) String() string {
return "unsqueeze1 operator"
}
49 changes: 14 additions & 35 deletions ops/unsqueeze/unsqueeze_11.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,24 @@ import (
"gorgonia.org/tensor"
)

const (
MinUnsqueeze11Inputs = 2
MaxUnsqueeze11Inputs = 2
)

// Unsqueeze11 represents the ONNX unsqueeze operator.
// Unsqueeze11 represents version 11 of the ONNX unsqueeze operator.
type Unsqueeze11 struct {
ops.BaseOperator

axes []int
}

// newUnsqueeze11 creates a new unsqueeze operator.
func newUnsqueeze11() ops.Operator {
return &Unsqueeze11{}
func newUnsqueeze11() *Unsqueeze11 {
return &Unsqueeze11{
BaseOperator: ops.NewBaseOperator(
11,
1,
1,
[][]tensor.Dtype{ops.AllTypes},
"unsqueeze",
),
}
}

// Init initializes the unsqueeze operator.
Expand Down Expand Up @@ -57,7 +62,7 @@ func (u *Unsqueeze11) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
sort.Ints(u.axes)

if ops.HasDuplicates(u.axes) {
return nil, ops.ErrInvalidInput("axes cannot have duplicate entries after offset", u)
return nil, ops.ErrInvalidInput("axes cannot have duplicate entries after offset", u.BaseOperator)
}

newShape := insertOnes(dataShape, u.axes)
Expand All @@ -71,29 +76,3 @@ func (u *Unsqueeze11) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {

return []tensor.Tensor{out}, err
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (u *Unsqueeze11) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(u, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (u *Unsqueeze11) GetMinInputs() int {
return MinUnsqueeze11Inputs
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (u *Unsqueeze11) GetMaxInputs() int {
return MaxUnsqueeze11Inputs
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (u *Unsqueeze11) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{ops.AllTypes, {tensor.Int64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (u *Unsqueeze11) String() string {
return "unsqueeze11 operator"
}
Loading

0 comments on commit e98fc1b

Please sign in to comment.