diff --git a/ops/opset13/batch_normalization.go b/ops/opset13/batch_normalization.go index c548fb5..6caadbd 100644 --- a/ops/opset13/batch_normalization.go +++ b/ops/opset13/batch_normalization.go @@ -7,8 +7,10 @@ import ( ) const ( - MinBatchNormalizationInputs = 5 - MaxBatchNormalizationInputs = 5 + MinBatchNormalizationInputs = 5 + MaxBatchNormalizationInputs = 5 + BatchNormalizationDefaultEpsilon = 1e-5 + BatchNormalizationDefaultMomentum = 0.9 ) // BatchNormalization represents the ONNX batchNormalization operator. @@ -21,14 +23,15 @@ type BatchNormalization struct { // newBatchNormalization creates a new batchNormalization operator. func newBatchNormalization() ops.Operator { return &BatchNormalization{ - epsilon: 1e-5, - momentum: 0.9, + epsilon: BatchNormalizationDefaultEpsilon, + momentum: BatchNormalizationDefaultMomentum, } } // Init initializes the batchNormalization operator. func (b *BatchNormalization) Init(n *onnx.NodeProto) error { hasMomentum := false + for _, attr := range n.GetAttribute() { switch attr.GetName() { case "epsilon": @@ -102,7 +105,9 @@ func (b *BatchNormalization) String() string { } func (b *BatchNormalization) reshapeTensors(X, scale, bias, mean, variance tensor.Tensor) (newScale, newBias, newMean, newVariance tensor.Tensor, err error) { - nSpatialDims := len(X.Shape()) - 2 + nNonSpatialDims := 2 + + nSpatialDims := len(X.Shape()) - nNonSpatialDims if nSpatialDims <= 0 { return scale, bias, mean, variance, nil }