-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* WIP on LinearRegressor * Added tests for linear regressor * Added test descriptions and docstring * Do not export constants --------- Co-authored-by: Swopper050 <[email protected]>
- Loading branch information
1 parent
5c6b259
commit c02923d
Showing
3 changed files
with
314 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
package opset13 | ||
|
||
import ( | ||
"github.com/advancedclimatesystems/gonnx/onnx" | ||
"github.com/advancedclimatesystems/gonnx/ops" | ||
"gorgonia.org/tensor" | ||
) | ||
|
||
const ( | ||
MinLinearRegressorInputs = 1 | ||
MaxLinearRegressorInputs = 1 | ||
) | ||
|
||
// PostTransformOption describes all possible post transform options for the | ||
// linear regressor operator. | ||
type postTransformOption string | ||
|
||
const ( | ||
noTransform postTransformOption = "NONE" | ||
softmaxTransform postTransformOption = "SOFTMAX" | ||
logisticTransform postTransformOption = "LOGISTIC" | ||
softmaxZeroTransform postTransformOption = "SOFTMAX_ZERO" | ||
probitTransform postTransformOption = "PROBIT" | ||
) | ||
|
||
// LinearRegressor represents the ONNX-ml linearRegressor operator. | ||
type LinearRegressor struct { | ||
coefficients tensor.Tensor | ||
intercepts tensor.Tensor | ||
postTransform postTransformOption | ||
targets int | ||
} | ||
|
||
// newLinearRegressor creates a new linearRegressor operator. | ||
func newLinearRegressor() ops.Operator { | ||
return &LinearRegressor{ | ||
postTransform: noTransform, | ||
targets: 1, | ||
} | ||
} | ||
|
||
// Init initializes the linearRegressor operator. | ||
func (l *LinearRegressor) Init(n *onnx.NodeProto) error { | ||
for _, attr := range n.GetAttribute() { | ||
switch attr.GetName() { | ||
case "coefficients": | ||
floats := attr.GetFloats() | ||
l.coefficients = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) | ||
case "intercepts": | ||
floats := attr.GetFloats() | ||
l.intercepts = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) | ||
case "post_transform": | ||
return ops.ErrUnsupportedAttribute(attr.GetName(), l) | ||
case "targets": | ||
l.targets = int(attr.GetI()) | ||
default: | ||
return ops.ErrInvalidAttribute(attr.GetName(), l) | ||
} | ||
} | ||
|
||
err := l.coefficients.Reshape(l.targets, ops.NElements(l.coefficients.Shape()...)/l.targets) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
return l.coefficients.T() | ||
} | ||
|
||
// Apply applies the linearRegressor operator. | ||
func (l *LinearRegressor) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { | ||
X := inputs[0] | ||
|
||
result, err := tensor.MatMul(X, l.coefficients) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
result, intercepts, err := ops.UnidirectionalBroadcast(result, l.intercepts) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
Y, err := tensor.Add(result, intercepts) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return []tensor.Tensor{Y}, nil | ||
} | ||
|
||
// ValidateInputs validates the inputs that will be given to Apply for this operator. | ||
func (l *LinearRegressor) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { | ||
return ops.ValidateInputs(l, inputs) | ||
} | ||
|
||
// GetMinInputs returns the minimum number of input tensors this operator expects. | ||
func (l *LinearRegressor) GetMinInputs() int { | ||
return MinLinearRegressorInputs | ||
} | ||
|
||
// GetMaxInputs returns the maximum number of input tensors this operator expects. | ||
func (l *LinearRegressor) GetMaxInputs() int { | ||
return MaxLinearRegressorInputs | ||
} | ||
|
||
// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes | ||
// for the corresponding input tensor. | ||
func (l *LinearRegressor) GetInputTypeConstraints() [][]tensor.Dtype { | ||
return [][]tensor.Dtype{ | ||
{tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, | ||
} | ||
} | ||
|
||
// String implements the stringer interface, and can be used to format errors or messages. | ||
func (l *LinearRegressor) String() string { | ||
return "linearRegressor operator" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
package opset13 | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/advancedclimatesystems/gonnx/onnx" | ||
"github.com/advancedclimatesystems/gonnx/ops" | ||
"github.com/stretchr/testify/assert" | ||
"gorgonia.org/tensor" | ||
) | ||
|
||
func TestLinearRegressorInit(t *testing.T) { | ||
linearRegressor := &LinearRegressor{} | ||
err := linearRegressor.Init(LinearRegressorOnnxNodeProtoFixture()) | ||
|
||
assert.Nil(t, err) | ||
assert.Equal(t, []float32{1.5, 2.5, 3.5}, linearRegressor.coefficients.Data()) | ||
assert.Equal(t, []float32{0.5}, linearRegressor.intercepts.Data()) | ||
assert.Equal(t, 1, linearRegressor.targets) | ||
} | ||
|
||
func TestLinearRegressorInitFailUnsupportedAttribute(t *testing.T) { | ||
linearRegressor := &LinearRegressor{} | ||
err := linearRegressor.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "post_transform"}, {Name: "Another"}}}) | ||
|
||
expected := ops.ErrUnsupportedAttribute("post_transform", linearRegressor) | ||
assert.Equal(t, expected, err) | ||
} | ||
|
||
func TestLinearRegressorInitFailInvalidAttribute(t *testing.T) { | ||
linearRegressor := &LinearRegressor{} | ||
err := linearRegressor.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "much_invalid"}}}) | ||
|
||
expected := ops.ErrInvalidAttribute("much_invalid", linearRegressor) | ||
assert.Equal(t, expected, err) | ||
} | ||
|
||
func TestLinearRegressor(t *testing.T) { | ||
tests := []struct { | ||
attrs []*onnx.AttributeProto | ||
shape []int | ||
backing []float32 | ||
expectedShape tensor.Shape | ||
expectedBacking []float32 | ||
description string | ||
}{ | ||
{ | ||
[]*onnx.AttributeProto{ | ||
{Name: "coefficients", Floats: []float32{-0.45977323}}, | ||
{Name: "intercepts", Floats: []float32{0.21509616}}, | ||
{Name: "targets", I: 1}, | ||
}, | ||
[]int{1, 1}, | ||
[]float32{0.7777024}, | ||
[]int{1, 1}, | ||
[]float32{-0.14247058}, | ||
"linear regressor with 1 input and 1 output variable, 1 sample", | ||
}, | ||
{ | ||
[]*onnx.AttributeProto{ | ||
{Name: "coefficients", Floats: []float32{-0.45977323}}, | ||
{Name: "intercepts", Floats: []float32{0.21509616}}, | ||
{Name: "targets", I: 1}, | ||
}, | ||
[]int{5, 1}, | ||
[]float32{0.7777024, 0.23754121, 0.82427853, 0.9657492, 0.9726011}, | ||
[]int{5, 1}, | ||
[]float32{-0.14247058, 0.105881065, -0.16388504, -0.22892947, -0.23207982}, | ||
"linear regressor with 1 input and 1 output variable, 5 samples", | ||
}, | ||
{ | ||
[]*onnx.AttributeProto{ | ||
{Name: "coefficients", Floats: []float32{0.24118852, 0.22617804, 0.27858477}}, | ||
{Name: "intercepts", Floats: []float32{-0.43156273}}, | ||
{Name: "targets", I: 1}, | ||
}, | ||
[]int{1, 3}, | ||
[]float32{0.7777024, 0.23754121, 0.82427853}, | ||
[]int{1, 1}, | ||
[]float32{0.039368242}, | ||
"linear regressor with 3 inputs and 1 output variable, 1 sample", | ||
}, | ||
{ | ||
[]*onnx.AttributeProto{ | ||
{Name: "coefficients", Floats: []float32{0.24118852, 0.22617804, 0.27858477}}, | ||
{Name: "intercepts", Floats: []float32{-0.43156273}}, | ||
{Name: "targets", I: 1}, | ||
}, | ||
[]int{2, 3}, | ||
[]float32{0.7777024, 0.23754121, 0.82427853, 0.9657492, 0.9726011, 0.45344925}, | ||
[]int{2, 1}, | ||
[]float32{0.039368242, 0.14766997}, | ||
"linear regressor with 3 inputs and 1 output variable, 2 samples", | ||
}, | ||
{ | ||
[]*onnx.AttributeProto{ | ||
{Name: "coefficients", Floats: []float32{ | ||
0.5384742, 0.36729308, 0.13292366, -0.03843413, | ||
0.28054297, -0.27832435, 0.4381632, 0.00726224, | ||
-0.64418418, -0.35812317, 0.69767598, 0.12989015, | ||
}}, | ||
{Name: "intercepts", Floats: []float32{-0.37036705, -0.34072968, 0.05487297}}, | ||
{Name: "targets", I: 3}, | ||
}, | ||
[]int{1, 4}, | ||
[]float32{0.7777024, 0.23754121, 0.82427853, 0.9657492}, | ||
[]int{1, 3}, | ||
[]float32{0.20810121, 0.17951778, 0.16934107}, | ||
"linear regressor with 4 input and 3 output variables, 1 samples", | ||
}, | ||
{ | ||
[]*onnx.AttributeProto{ | ||
{Name: "coefficients", Floats: []float32{ | ||
0.5384742, 0.36729308, 0.13292366, -0.03843413, | ||
0.28054297, -0.27832435, 0.4381632, 0.00726224, | ||
-0.64418418, -0.35812317, 0.69767598, 0.12989015, | ||
}}, | ||
{Name: "intercepts", Floats: []float32{-0.37036705, -0.34072968, 0.05487297}}, | ||
{Name: "targets", I: 3}, | ||
}, | ||
[]int{2, 4}, | ||
[]float32{0.7777024, 0.23754121, 0.82427853, 0.9657492, 0.9726011, 0.45344925, 0.60904247, 0.7755265}, | ||
[]int{2, 3}, | ||
[]float32{0.20810121, 0.17951778, 0.16934107, 0.37105185, 0.0784128, -0.20840444}, | ||
"linear regressor with 4 input and 3 output variables, 2 samples", | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
inputs := []tensor.Tensor{ | ||
ops.TensorWithBackingFixture(test.backing, test.shape...), | ||
} | ||
|
||
linearRegressor := newLinearRegressor() | ||
err := linearRegressor.Init(&onnx.NodeProto{Attribute: test.attrs}) | ||
assert.Nil(t, err, test.description) | ||
|
||
res, err := linearRegressor.Apply(inputs) | ||
assert.Nil(t, err, test.description) | ||
assert.Equal(t, test.expectedShape, res[0].Shape(), test.description) | ||
assert.Equal(t, test.expectedBacking, res[0].Data(), test.description) | ||
} | ||
} | ||
|
||
func TestInputValidationLinearRegressor(t *testing.T) { | ||
tests := []struct { | ||
inputs []tensor.Tensor | ||
err error | ||
}{ | ||
{ | ||
[]tensor.Tensor{ops.TensorWithBackingFixture([]int32{1, 2}, 2)}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{ops.TensorWithBackingFixture([]int64{1, 2}, 2)}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{ops.TensorWithBackingFixture([]float64{1, 2}, 2)}, | ||
nil, | ||
}, | ||
{ | ||
[]tensor.Tensor{}, | ||
ops.ErrInvalidInputCount(0, &LinearRegressor{}), | ||
}, | ||
{ | ||
[]tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, | ||
ops.ErrInvalidInputType(0, "int", &LinearRegressor{}), | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
linearRegressor := &LinearRegressor{} | ||
validated, err := linearRegressor.ValidateInputs(test.inputs) | ||
|
||
assert.Equal(t, test.err, err) | ||
|
||
if test.err == nil { | ||
assert.Equal(t, test.inputs, validated) | ||
} | ||
} | ||
} | ||
|
||
func LinearRegressorOnnxNodeProtoFixture() *onnx.NodeProto { | ||
return &onnx.NodeProto{ | ||
Attribute: []*onnx.AttributeProto{ | ||
{Name: "coefficients", Floats: []float32{1.5, 2.5, 3.5}}, | ||
{Name: "intercepts", Floats: []float32{0.5}}, | ||
{Name: "targets", I: 1}, | ||
}, | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters