Skip to content

Commit

Permalink
Add LinearRegressor operator (#184)
Browse files Browse the repository at this point in the history
* WIP on LinearRegressor

* Added tests for linear regressor

* Added test descriptions and docstring

* Do not export constants

---------

Co-authored-by: Swopper050 <[email protected]>
  • Loading branch information
Swopper050 and Swopper050 authored Dec 14, 2023
1 parent 5c6b259 commit c02923d
Show file tree
Hide file tree
Showing 3 changed files with 314 additions and 0 deletions.
117 changes: 117 additions & 0 deletions ops/opset13/linear_regressor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package opset13

import (
	"fmt"

	"github.com/advancedclimatesystems/gonnx/onnx"
	"github.com/advancedclimatesystems/gonnx/ops"
	"gorgonia.org/tensor"
)

// The LinearRegressor operator takes exactly one input tensor (the feature
// matrix X), so the minimum and maximum input counts are both 1.
const (
	MinLinearRegressorInputs = 1
	MaxLinearRegressorInputs = 1
)

// postTransformOption describes all possible post transform options for the
// linear regressor operator.
type postTransformOption string

// All post-transform option values. Only noTransform is used in this file
// (as the default set by newLinearRegressor); Init rejects the
// "post_transform" attribute, so the remaining values are listed for
// completeness only.
const (
	noTransform          postTransformOption = "NONE"
	softmaxTransform     postTransformOption = "SOFTMAX"
	logisticTransform    postTransformOption = "LOGISTIC"
	softmaxZeroTransform postTransformOption = "SOFTMAX_ZERO"
	probitTransform      postTransformOption = "PROBIT"
)

// LinearRegressor represents the ONNX-ml linearRegressor operator.
type LinearRegressor struct {
	coefficients  tensor.Tensor       // weights; set flat by Init, then reshaped to (targets, nFeatures) and transposed
	intercepts    tensor.Tensor       // per-target bias terms, added to the matmul result in Apply
	postTransform postTransformOption // always noTransform; other options are rejected by Init
	targets       int                 // number of output variables per sample (defaults to 1 in newLinearRegressor)
}

// newLinearRegressor creates a new linearRegressor operator with its
// defaults: no post transform and a single output target.
func newLinearRegressor() ops.Operator {
	regressor := &LinearRegressor{}
	regressor.postTransform = noTransform
	regressor.targets = 1

	return regressor
}

// Init initializes the linearRegressor operator. It parses the node
// attributes, then reshapes the flat coefficients to
// (targets, nCoefficients/targets) and transposes them so that Apply can
// matrix-multiply the input with them directly.
//
// The "post_transform" attribute is only supported with its default value
// "NONE" (the ONNX-ml default); any other transform is rejected.
func (l *LinearRegressor) Init(n *onnx.NodeProto) error {
	for _, attr := range n.GetAttribute() {
		switch attr.GetName() {
		case "coefficients":
			floats := attr.GetFloats()
			l.coefficients = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats))
		case "intercepts":
			floats := attr.GetFloats()
			l.intercepts = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats))
		case "post_transform":
			// An explicit "NONE" matches the default behavior and is
			// accepted; every other transform is unimplemented.
			if postTransformOption(attr.GetS()) != noTransform {
				return ops.ErrUnsupportedAttribute(attr.GetName(), l)
			}

			l.postTransform = noTransform
		case "targets":
			l.targets = int(attr.GetI())
		default:
			return ops.ErrInvalidAttribute(attr.GetName(), l)
		}
	}

	// The reshape below would panic on a nil tensor and divide by zero for
	// targets == 0, so validate the parsed attributes first.
	if l.coefficients == nil {
		return fmt.Errorf("%v: required attribute 'coefficients' is not set", l)
	}

	if l.targets < 1 {
		return fmt.Errorf("%v: attribute 'targets' must be at least 1, got %d", l, l.targets)
	}

	err := l.coefficients.Reshape(l.targets, ops.NElements(l.coefficients.Shape()...)/l.targets)
	if err != nil {
		return err
	}

	return l.coefficients.T()
}

// Apply applies the linearRegressor operator: it multiplies the input with
// the (already reshaped and transposed) coefficients and adds the
// intercepts, broadcast over the result.
func (l *LinearRegressor) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
	input := inputs[0]

	product, err := tensor.MatMul(input, l.coefficients)
	if err != nil {
		return nil, err
	}

	sumLeft, sumRight, err := ops.UnidirectionalBroadcast(product, l.intercepts)
	if err != nil {
		return nil, err
	}

	output, err := tensor.Add(sumLeft, sumRight)
	if err != nil {
		return nil, err
	}

	return []tensor.Tensor{output}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this
// operator, delegating to the generic ops validation helper.
func (l *LinearRegressor) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
	validated, err := ops.ValidateInputs(l, inputs)

	return validated, err
}

// GetMinInputs returns the minimum number of input tensors this operator
// expects.
func (l *LinearRegressor) GetMinInputs() int {
	minInputs := MinLinearRegressorInputs

	return minInputs
}

// GetMaxInputs returns the maximum number of input tensors this operator
// expects.
func (l *LinearRegressor) GetMaxInputs() int {
	maxInputs := MaxLinearRegressorInputs

	return maxInputs
}

// GetInputTypeConstraints returns a list. Every element represents a set of
// allowed tensor dtypes for the corresponding input tensor.
func (l *LinearRegressor) GetInputTypeConstraints() [][]tensor.Dtype {
	allowedTypes := []tensor.Dtype{tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}

	return [][]tensor.Dtype{allowedTypes}
}

// String implements the stringer interface, and can be used to format errors
// or messages.
func (l *LinearRegressor) String() string {
	const description = "linearRegressor operator"

	return description
}
196 changes: 196 additions & 0 deletions ops/opset13/linear_regressor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

// TestLinearRegressorInit checks that Init parses the coefficients,
// intercepts and targets attributes from a valid node proto.
func TestLinearRegressorInit(t *testing.T) {
	op := &LinearRegressor{}
	err := op.Init(LinearRegressorOnnxNodeProtoFixture())
	assert.Nil(t, err)

	assert.Equal(t, []float32{1.5, 2.5, 3.5}, op.coefficients.Data())
	assert.Equal(t, []float32{0.5}, op.intercepts.Data())
	assert.Equal(t, 1, op.targets)
}

// TestLinearRegressorInitFailUnsupportedAttribute checks that Init fails on
// the first unsupported attribute it encounters.
func TestLinearRegressorInitFailUnsupportedAttribute(t *testing.T) {
	op := &LinearRegressor{}
	node := &onnx.NodeProto{
		Attribute: []*onnx.AttributeProto{{Name: "post_transform"}, {Name: "Another"}},
	}

	err := op.Init(node)
	assert.Equal(t, ops.ErrUnsupportedAttribute("post_transform", op), err)
}

// TestLinearRegressorInitFailInvalidAttribute checks that Init fails on an
// attribute the operator does not know at all.
func TestLinearRegressorInitFailInvalidAttribute(t *testing.T) {
	op := &LinearRegressor{}
	node := &onnx.NodeProto{
		Attribute: []*onnx.AttributeProto{{Name: "much_invalid"}},
	}

	err := op.Init(node)
	assert.Equal(t, ops.ErrInvalidAttribute("much_invalid", op), err)
}

// TestLinearRegressor runs the operator end-to-end (Init + Apply) on a table
// of fixtures with varying numbers of input features, output targets and
// samples. Expected values presumably come from a reference implementation —
// TODO confirm which one (e.g. sklearn or onnxruntime).
func TestLinearRegressor(t *testing.T) {
	tests := []struct {
		attrs           []*onnx.AttributeProto // node attributes passed to Init
		shape           []int                  // shape of the input tensor X
		backing         []float32              // backing data of X
		expectedShape   tensor.Shape           // expected output shape
		expectedBacking []float32              // expected output data
		description     string                 // used in assertion messages
	}{
		{
			[]*onnx.AttributeProto{
				{Name: "coefficients", Floats: []float32{-0.45977323}},
				{Name: "intercepts", Floats: []float32{0.21509616}},
				{Name: "targets", I: 1},
			},
			[]int{1, 1},
			[]float32{0.7777024},
			[]int{1, 1},
			[]float32{-0.14247058},
			"linear regressor with 1 input and 1 output variable, 1 sample",
		},
		{
			[]*onnx.AttributeProto{
				{Name: "coefficients", Floats: []float32{-0.45977323}},
				{Name: "intercepts", Floats: []float32{0.21509616}},
				{Name: "targets", I: 1},
			},
			[]int{5, 1},
			[]float32{0.7777024, 0.23754121, 0.82427853, 0.9657492, 0.9726011},
			[]int{5, 1},
			[]float32{-0.14247058, 0.105881065, -0.16388504, -0.22892947, -0.23207982},
			"linear regressor with 1 input and 1 output variable, 5 samples",
		},
		{
			[]*onnx.AttributeProto{
				{Name: "coefficients", Floats: []float32{0.24118852, 0.22617804, 0.27858477}},
				{Name: "intercepts", Floats: []float32{-0.43156273}},
				{Name: "targets", I: 1},
			},
			[]int{1, 3},
			[]float32{0.7777024, 0.23754121, 0.82427853},
			[]int{1, 1},
			[]float32{0.039368242},
			"linear regressor with 3 inputs and 1 output variable, 1 sample",
		},
		{
			[]*onnx.AttributeProto{
				{Name: "coefficients", Floats: []float32{0.24118852, 0.22617804, 0.27858477}},
				{Name: "intercepts", Floats: []float32{-0.43156273}},
				{Name: "targets", I: 1},
			},
			[]int{2, 3},
			[]float32{0.7777024, 0.23754121, 0.82427853, 0.9657492, 0.9726011, 0.45344925},
			[]int{2, 1},
			[]float32{0.039368242, 0.14766997},
			"linear regressor with 3 inputs and 1 output variable, 2 samples",
		},
		{
			[]*onnx.AttributeProto{
				// 12 coefficients: 3 targets x 4 features.
				{Name: "coefficients", Floats: []float32{
					0.5384742, 0.36729308, 0.13292366, -0.03843413,
					0.28054297, -0.27832435, 0.4381632, 0.00726224,
					-0.64418418, -0.35812317, 0.69767598, 0.12989015,
				}},
				{Name: "intercepts", Floats: []float32{-0.37036705, -0.34072968, 0.05487297}},
				{Name: "targets", I: 3},
			},
			[]int{1, 4},
			[]float32{0.7777024, 0.23754121, 0.82427853, 0.9657492},
			[]int{1, 3},
			[]float32{0.20810121, 0.17951778, 0.16934107},
			"linear regressor with 4 input and 3 output variables, 1 samples",
		},
		{
			[]*onnx.AttributeProto{
				// Same 3x4 coefficients as the previous case, now with two samples.
				{Name: "coefficients", Floats: []float32{
					0.5384742, 0.36729308, 0.13292366, -0.03843413,
					0.28054297, -0.27832435, 0.4381632, 0.00726224,
					-0.64418418, -0.35812317, 0.69767598, 0.12989015,
				}},
				{Name: "intercepts", Floats: []float32{-0.37036705, -0.34072968, 0.05487297}},
				{Name: "targets", I: 3},
			},
			[]int{2, 4},
			[]float32{0.7777024, 0.23754121, 0.82427853, 0.9657492, 0.9726011, 0.45344925, 0.60904247, 0.7755265},
			[]int{2, 3},
			[]float32{0.20810121, 0.17951778, 0.16934107, 0.37105185, 0.0784128, -0.20840444},
			"linear regressor with 4 input and 3 output variables, 2 samples",
		},
	}

	for _, test := range tests {
		inputs := []tensor.Tensor{
			ops.TensorWithBackingFixture(test.backing, test.shape...),
		}

		linearRegressor := newLinearRegressor()
		err := linearRegressor.Init(&onnx.NodeProto{Attribute: test.attrs})
		assert.Nil(t, err, test.description)

		res, err := linearRegressor.Apply(inputs)
		assert.Nil(t, err, test.description)
		assert.Equal(t, test.expectedShape, res[0].Shape(), test.description)
		assert.Equal(t, test.expectedBacking, res[0].Data(), test.description)
	}
}

// TestInputValidationLinearRegressor checks ValidateInputs for all allowed
// dtypes, for an empty input list and for an unsupported dtype.
func TestInputValidationLinearRegressor(t *testing.T) {
	tests := []struct {
		inputs []tensor.Tensor
		err    error
	}{
		{[]tensor.Tensor{ops.TensorWithBackingFixture([]int32{1, 2}, 2)}, nil},
		{[]tensor.Tensor{ops.TensorWithBackingFixture([]int64{1, 2}, 2)}, nil},
		{[]tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, nil},
		{[]tensor.Tensor{ops.TensorWithBackingFixture([]float64{1, 2}, 2)}, nil},
		{[]tensor.Tensor{}, ops.ErrInvalidInputCount(0, &LinearRegressor{})},
		{[]tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, ops.ErrInvalidInputType(0, "int", &LinearRegressor{})},
	}

	for _, test := range tests {
		op := &LinearRegressor{}
		validated, err := op.ValidateInputs(test.inputs)
		assert.Equal(t, test.err, err)

		// On success the validated inputs are the inputs, unchanged.
		if test.err == nil {
			assert.Equal(t, test.inputs, validated)
		}
	}
}

// LinearRegressorOnnxNodeProtoFixture returns a node proto carrying a valid
// set of LinearRegressor attributes.
func LinearRegressorOnnxNodeProtoFixture() *onnx.NodeProto {
	attrs := []*onnx.AttributeProto{
		{Name: "coefficients", Floats: []float32{1.5, 2.5, 3.5}},
		{Name: "intercepts", Floats: []float32{0.5}},
		{Name: "targets", I: 1},
	}

	return &onnx.NodeProto{Attribute: attrs}
}
1 change: 1 addition & 0 deletions ops/opset13/opset13.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ var operators13 = map[string]func() ops.Operator{
"GRU": newGRU,
"Less": newLess,
"LessOrEqual": newLessOrEqual,
"LinearRegressor": newLinearRegressor,
"LSTM": newLSTM,
"MatMul": newMatMul,
"Mul": newMul,
Expand Down

0 comments on commit c02923d

Please sign in to comment.