From 18ca63b466b06231df26e049084d11d72d16355d Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Tue, 9 Jan 2024 14:58:18 +0100 Subject: [PATCH] Fix lint --- ops/opset13/batch_normalization.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/ops/opset13/batch_normalization.go b/ops/opset13/batch_normalization.go index c548fb5..6caadbd 100644 --- a/ops/opset13/batch_normalization.go +++ b/ops/opset13/batch_normalization.go @@ -7,8 +7,10 @@ import ( ) const ( - MinBatchNormalizationInputs = 5 - MaxBatchNormalizationInputs = 5 + MinBatchNormalizationInputs = 5 + MaxBatchNormalizationInputs = 5 + BatchNormalizationDefaultEpsilon = 1e-5 + BatchNormalizationDefaultMomentum = 0.9 ) // BatchNormalization represents the ONNX batchNormalization operator. @@ -21,14 +23,15 @@ type BatchNormalization struct { // newBatchNormalization creates a new batchNormalization operator. func newBatchNormalization() ops.Operator { return &BatchNormalization{ - epsilon: 1e-5, - momentum: 0.9, + epsilon: BatchNormalizationDefaultEpsilon, + momentum: BatchNormalizationDefaultMomentum, } } // Init initializes the batchNormalization operator. func (b *BatchNormalization) Init(n *onnx.NodeProto) error { hasMomentum := false + for _, attr := range n.GetAttribute() { switch attr.GetName() { case "epsilon": @@ -102,7 +105,9 @@ func (b *BatchNormalization) String() string { } func (b *BatchNormalization) reshapeTensors(X, scale, bias, mean, variance tensor.Tensor) (newScale, newBias, newMean, newVariance tensor.Tensor, err error) { - nSpatialDims := len(X.Shape()) - 2 + nNonSpatialDims := 2 + + nSpatialDims := len(X.Shape()) - nNonSpatialDims if nSpatialDims <= 0 { return scale, bias, mean, variance, nil }