Skip to content

Commit

Permalink
POC: new design for multiple operator versions
Browse files Browse the repository at this point in the history
  • Loading branch information
wisse committed Dec 3, 2024
1 parent 8100156 commit a7eb0a9
Show file tree
Hide file tree
Showing 12 changed files with 449 additions and 317 deletions.
55 changes: 55 additions & 0 deletions ops/base.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package ops

import (
"fmt"

"gorgonia.org/tensor"
)

// Concrete implementation for shared operator methods
type BaseOperator struct {
name string
version int
minInputs int
maxInputs int
inputTypeConstraints [][]tensor.Dtype
}

func NewBaseOperator(version, minInputs, maxInputs int, inputTypeConstraints [][]tensor.Dtype, name string) BaseOperator {
return BaseOperator{
name: name,
version: version,
minInputs: minInputs,
maxInputs: maxInputs,
inputTypeConstraints: inputTypeConstraints,
}
}

// ValidateInputs validates the inputs for the operator.
func (f BaseOperator) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ValidateInputs(f, inputs)
}

// Version returns the
func (f BaseOperator) Version() int {
return f.version
}

// GetMinInputs returns the minimum number of input tensors.
func (f BaseOperator) GetMinInputs() int {
return f.minInputs
}

// GetMaxInputs returns the maximum number of input tensors.
func (f BaseOperator) GetMaxInputs() int {
return f.maxInputs
}

// GetInputTypeConstraints returns allowed input types.
func (f BaseOperator) GetInputTypeConstraints() [][]tensor.Dtype {
return f.inputTypeConstraints
}

func (f BaseOperator) String() string {
return fmt.Sprintf("%s v%d", f.name, f.version)
}
34 changes: 17 additions & 17 deletions ops/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ const (

type InputError struct {
kind InputErrorKind
operator Operator
Operator BaseOperator
reason string

// Attributes for input type error.
Expand All @@ -116,61 +116,61 @@ type InputError struct {
func (i *InputError) Error() string {
switch i.kind {
case InputErrorType:
return fmt.Sprintf("input %d for op %v does not allow dtype %v", i.inputNumber, i.operator, i.actualType)
return fmt.Sprintf("input %d for op %v does not allow dtype %v", i.inputNumber, i.Operator, i.actualType)
case InputErrorCount:
if i.hasOptionalInputs {
return fmt.Sprintf(InvalidOptionalInputCountErrTemplate, i.operator, i.operator.GetMinInputs(), i.operator.GetMaxInputs(), i.actualCount)
return fmt.Sprintf(InvalidOptionalInputCountErrTemplate, i.Operator, i.Operator.GetMinInputs(), i.Operator.GetMaxInputs(), i.actualCount)
}

return fmt.Sprintf(InvalidInputCountErrTemplate, i.operator, i.operator.GetMinInputs(), i.actualCount)
return fmt.Sprintf(InvalidInputCountErrTemplate, i.Operator, i.Operator.GetMinInputs(), i.actualCount)
case InputErrorUnsupported:
return fmt.Sprintf(UnsupportedInputErrTemplate, i.operator, i.inputName)
return fmt.Sprintf(UnsupportedInputErrTemplate, i.Operator, i.inputName)
case InputErrorInvalid:
return fmt.Sprintf(InvalidInputErrTemplate, i.operator, i.reason)
return fmt.Sprintf(InvalidInputErrTemplate, i.Operator, i.reason)
default:
return fmt.Sprintf("%s unknown error input error kind %s", i.operator.String(), i.kind)
return fmt.Sprintf("%s unknown error input error kind %s", i.Operator.String(), i.kind)
}
}

func ErrInvalidInputType(inputNumber int, dType string, operator Operator) error {
func ErrInvalidInputType(inputNumber int, dType string, operator BaseOperator) error {
return &InputError{
kind: InputErrorType,
operator: operator,
Operator: operator,
inputNumber: inputNumber,
actualType: dType,
}
}

func ErrInvalidInputCount(actual int, operator Operator) error {
func ErrInvalidInputCount(actual int, operator BaseOperator) error {
return &InputError{
kind: InputErrorCount,
actualCount: actual,
operator: operator,
Operator: operator,
}
}

func ErrInvalidOptionalInputCount(actual int, operator Operator) error {
func ErrInvalidOptionalInputCount(actual int, operator BaseOperator) error {
return &InputError{
kind: InputErrorCount,
hasOptionalInputs: true,
actualCount: actual,
operator: operator,
Operator: operator,
}
}

func ErrUnsupportedInput(inputName string, operator Operator) error {
func ErrUnsupportedInput(inputName string, operator BaseOperator) error {
return &InputError{
kind: InputErrorUnsupported,
inputName: inputName,
operator: operator,
Operator: operator,
}
}

func ErrInvalidInput(reason string, operator Operator) error {
func ErrInvalidInput(reason string, operator BaseOperator) error {
return &InputError{
kind: InputErrorInvalid,
reason: reason,
operator: operator,
Operator: operator,
}
}

Expand Down
24 changes: 0 additions & 24 deletions ops/flatten/flatten_1.go

This file was deleted.

24 changes: 0 additions & 24 deletions ops/flatten/flatten_11.go

This file was deleted.

24 changes: 0 additions & 24 deletions ops/flatten/flatten_13.go

This file was deleted.

162 changes: 0 additions & 162 deletions ops/flatten/flatten_13_test.go

This file was deleted.

Loading

0 comments on commit a7eb0a9

Please sign in to comment.