Skip to content

Commit

Permalink
Rewrote abs operator so it shares code
Browse files Browse the repository at this point in the history
  • Loading branch information
wisse committed Dec 3, 2024
1 parent a7eb0a9 commit 1ae421c
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 141 deletions.
44 changes: 44 additions & 0 deletions ops/abs/abs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package abs

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var absTypeConstraint = [][]tensor.Dtype{
{tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64},
}

// Abs represents the ONNX abs operator.
type Abs struct {
ops.BaseOperator
}

// newAbs creates a new abs operator.
func newAbs(version int, typeConstraint [][]tensor.Dtype) *Abs {
return &Abs{
BaseOperator: ops.NewBaseOperator(
version,
1,
1,
typeConstraint,
"abs",
),
}
}

// Init initializes the abs operator.
func (a *Abs) Init(*onnx.NodeProto) error {
return nil
}

// Apply applies the abs operator.
func (a *Abs) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
out, err := tensor.Abs(inputs[0])
if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}
63 changes: 0 additions & 63 deletions ops/abs/abs_13.go

This file was deleted.

63 changes: 0 additions & 63 deletions ops/abs/abs_6.go

This file was deleted.

Loading

0 comments on commit 1ae421c

Please sign in to comment.