Skip to content

Commit

Permalink
Fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Swopper050 committed Jan 9, 2024
1 parent 5465c2c commit 18ca63b
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions ops/opset13/batch_normalization.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
)

const (
MinBatchNormalizationInputs = 5
MaxBatchNormalizationInputs = 5
MinBatchNormalizationInputs = 5
MaxBatchNormalizationInputs = 5
BatchNormalizationDefaultEpsilon = 1e-5
BatchNormalizationDefaultMomentum = 0.9
)

// BatchNormalization represents the ONNX batchNormalization operator.
Expand All @@ -21,14 +23,15 @@ type BatchNormalization struct {
// newBatchNormalization creates a new batchNormalization operator.
func newBatchNormalization() ops.Operator {
return &BatchNormalization{
epsilon: 1e-5,
momentum: 0.9,
epsilon: BatchNormalizationDefaultEpsilon,
momentum: BatchNormalizationDefaultMomentum,
}
}

// Init initializes the batchNormalization operator.
func (b *BatchNormalization) Init(n *onnx.NodeProto) error {
hasMomentum := false

for _, attr := range n.GetAttribute() {
switch attr.GetName() {
case "epsilon":
Expand Down Expand Up @@ -102,7 +105,9 @@ func (b *BatchNormalization) String() string {
}

func (b *BatchNormalization) reshapeTensors(X, scale, bias, mean, variance tensor.Tensor) (newScale, newBias, newMean, newVariance tensor.Tensor, err error) {
nSpatialDims := len(X.Shape()) - 2
nNonSpatialDims := 2

nSpatialDims := len(X.Shape()) - nNonSpatialDims
if nSpatialDims <= 0 {
return scale, bias, mean, variance, nil
}
Expand Down

0 comments on commit 18ca63b

Please sign in to comment.