diff --git a/OperatorFormulas.html b/OperatorFormulas.html
index 4949e07..b2e0beb 100644
--- a/OperatorFormulas.html
+++ b/OperatorFormulas.html
@@ -142,7 +142,7 @@
f(x) = expₑ(x) / sum(expₑ(X))
@@ -3478,22 +3478,23 @@
-function convolve(input, filterWeights, dilations, strides, kernel_shape, pads; output)
+function convolve(input, filterWeights, windowDimensions, padding, dilations, strides)
- startPads = pads[0..pads.size/2]
- endPads = pads[pads.size/2..pads.size]
+ startPads = padding[0..padding.size/2]
+ endPads = padding[padding.size/2..padding.size]
- // todo: compute output size
- // output.shape = (input.shape + startPads + endPads) // todo: consider strides and kernel size
+ // TODO: compute output size
+ // output.dimensions = (input.dimensions + startPads + endPads) // TODO: consider strides and kernel size
for each outputCoordinate in output coordinates
output[outputCoordinate] = convolveKernel(input, filterWeights, outputCoordinate * strides - startPads, dilations)
endfor
endfunction
function convolveKernel(input, filterWeights, firstInputCoordinate, dilations)
- // 2D example only todo:Figure out what 'group' does and what 'M' is?
+ // 2D example only
+ // TODO: Figure out what 'group' does and what 'M' is?
result = 0
// todo: How do 'M' and 'C' factor into this?
- for y=0..<filterWeights.shape[2]
- for x=0..<filterWeights.shape[3]
+ for y=0..<filterWeights.dimensions[2]
+ for x=0..<filterWeights.dimensions[3]
inputCoordinate = firstInputCoordinate + ([y,x] * dilations)
if (input.contains(inputCoordinate)) // check coordinates within tensor
result += filterWeights[y,x] * input[inputCoordinate]
@@ -3590,7 +3591,7 @@ Operators
function transpose(input, permutationAxes /*gather semantics*/)
assert(permutationAxes.size == input.rank)
rank = input.rank
- for i=0..<rank do output.shape[i] = input.shape[permutationAxes[i]]
+ for i=0..<rank do output.dimensions[i] = input.dimensions[permutationAxes[i]]
outputCoordinate = repeat(rank, 0)
for each inputCoordinate in input coordinates
for i=0..<rank do outputCoordinate[i] = inputCoordinate[permutationAxes[i]]
@@ -3626,12 +3627,11 @@ Operators
Broadcast any size-1 dimensions up to the corresponding target dimensions.
Similar to NumPy broadcast_to; a worked example follows the function below.
-function broadcast(input, shape)
- output.shape = broadcastShape(input.shape, shape)
- N = output.rank
- inputShape = PadLeadingValues(input.shape, N, 1)
+function broadcast(input, targetDimensions)
+ output.dimensions = broadcastDimensions(input.dimensions, targetDimensions)
+ inputDimensions = padLeadingValues(input.dimensions, output.rank, 1)
for each outputCoordinate in output coordinates
- for i=0..<N do inputCoordinate[i] = iif(inputShape[i] > 1), outputCoordinate[i], 0)
+ for i=0..<output.rank do inputCoordinate[i] = iif(inputDimensions[i] > 1, outputCoordinate[i], 0)
- output[outputCoordinate] = inputData[inputCoordinate]
+ output[outputCoordinate] = input[inputCoordinate]
endfor
endfunction
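Worked example (illustrative): broadcasting dimensions [3,1] to [2,3,4]:
input = [[1], [2], [3]] // dimensions [3,1]
inputDimensions is padded to [1,3,1]; every size-1 axis reads coordinate 0, so output[n,y,x] = input[y,0].
output.dimensions == [2,3,4], e.g. output[0] == [[1,1,1,1], [2,2,2,2], [3,3,3,3]]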
@@ -3682,9 +3682,9 @@ Operators
// e.g. paddedSize=4 with [H,W] -> [1,1,H,W]
function padLeadingValues(values, paddedSize, padValue)
// Right align. e.g. original dimensions=[H,W], paddedSize=4, padValue=1 -> [1,1,H,W]
- paddedSize = max(padddedSize, values.size)
+ paddingCount = max(paddedSize, values.size) - values.size
paddedValues = values
- paddedValues.prepend(paddingSize - values.size, padValue)
+ paddedValues.prepend(paddingCount, padValue)
return paddedValues
endfunction
-function removeOnesFromShape(shape, axes)
+function reshapeDeletingOnes(input, axes)
+ output = input
+ output.dimensions = deleteOnesInDimensions(input.dimensions, axes)
+ return output
+endfunction
+
+function deleteOnesInDimensions(dimensions, axes)
if axes is undefined
- axes = arange(0, shape.size)
- else // !axes.empty
- for i in axes do assert(shape[i] == 1)
- axes = removeDupes(sortDescending(axes))
+ axes = increasingSequence(0, dimensions.size) // Remove all 1's.
+ else
+ assert(allOf(axes, (i) => (dimensions[i] == 1))) // Each given axis must have size 1.
+ axes = removeDuplicates(sortDescending(axes))
endif
- newShape = shape
+
+ newDimensions = dimensions
for i in axes // work from back to front
- if newShape[i] == 1
- newShape.deleteAt(i)
+ if newDimensions[i] == 1
+ newDimensions.deleteAt(i)
endif
endfor
- return newShape
-endfunction
-
-function reshapeDeletingOnes(input, axes)
- outputShape = removeOnesFromShape(input.shape, axes)
- output = input
- output.shape = outputShape
- return output
+ return newDimensions
endfunction
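Worked example (illustrative):
deleteOnesInDimensions([1,3,1,4], axes=[0,2]) -> [3,4]
deleteOnesInDimensions([1,3,1,4], axes=undefined) -> [3,4] // every size-1 dimension removed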
-function reshapeInsertingOnes(input, axes)
+function reshapeInsertingOnes(input, axes)
output = input
- outputShape = input.shape
- axes = removeDupes(sort(axes))
+ output.dimensions = insertOnesInDimensions(input.dimensions, axes)
+ return output
+endfunction
+
+function insertOnesInDimensions(dimensions, axes)
+ // Note the axes are relative to their *final* index.
+ // So dimensions = [3,4] with axes = [0,2] yields new dimensions = [1,3,1,4].
+ newDimensions = dimensions
+ axes = removeDuplicates(sort(axes))
for i in axes
- outputShape.insertAt(i, 1)
+ newDimensions.insertAt(i, 1)
endfor
- output.shape = outputShape
+ return newDimensions
+endfunction
+
function reshapeFromAxes(input, newRank, axes)
+ assert(input.rank == axes.size)
+ assert(containsUniqueValues(axes))
+ output = input
+ output.dimensions = gatherValues(input.dimensions, axes, newRank, 1)
return output
+endfunction
+
+function gatherValues(values, indices, newValuesSize, fillerValue)
+ newValues = repeat(newValuesSize, fillerValue)
+ for (i, gatherIndex) in indices
+ newValues[i] = values[gatherIndex]
+ endfor
+ return newValues
endfunction
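Worked example (illustrative): gatherValues([5,7], indices=[1,0], newValuesSize=3, fillerValue=1)
newValues starts as [1,1,1]; newValues[0] = values[1] = 7; newValues[1] = values[0] = 5 -> [7,5,1]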
function reshapeToAxes(input, newRank, axes)
+ assert(input.rank == axes.size)
+ assert(containsUniqueValues(axes))
+ output = input
+ output.dimensions = scatterValues(input.dimensions, axes, newRank, 1)
+ return output
+endfunction
+
+function scatterValues(values, indices, newValuesSize, fillerValue)
+ newValues = repeat(newValuesSize, fillerValue)
+ for (i, scatterIndex) in indices
+ newValues[scatterIndex] = values[i]
+ endfor
+ return newValues
+endfunction
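Worked example (illustrative): reshaping a 1D bias of length C to rank 4 at axis 1, as reshapeToAxes does:
scatterValues([C], indices=[1], newValuesSize=4, fillerValue=1) -> [1,C,1,1]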
function concatenate(inputs, axis)
sizesAlongAxis = []
for each input in inputs
- sizesAlongAxis.append(input.shape[axis])
+ sizesAlongAxis.append(input.dimensions[axis])
endfor
- outputOffsets = cumulativeSum(axisSizes)
+ outputOffsets = cumulativeSum(sizesAlongAxis)
for each inputIndex from 0 up to inputs.count
@@ -4054,12 +4145,14 @@ Operators
Gather Elements
Return an output tensor the same size as indices, filled with values read from input along the given axis at those indices (worked example after the index formulas below).
-output = input
-for each coordinate in indices tensor
- inputCoordinate = coordinate
- inputCoordinate[axis] = indices[coordinate]
- output[coordinate] = input[inputCoordinate]
-endfor
+function gatherElements(input, indices, axis)
+ output = new Tensor(input.dataType, indices.dimensions)
+ for each coordinate in indices tensor
+ inputCoordinate = coordinate
+ inputCoordinate[axis] = indices[coordinate]
+ output[coordinate] = input[inputCoordinate]
+ endfor
+ return output
+endfunction
output[i][j][k] = input[ index[i][j][k] ][j][k] # if dim == 0
output[i][j][k] = input[i][ index[i][j][k] ][k] # if dim == 1
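Worked example (illustrative):
input = [[1, 2], [3, 4]]
indices = [[0, 0], [1, 0]]
axis = 0
output = [[1, 2], [3, 2]] // e.g. output[1][1] = input[indices[1][1]][1] = input[0][1] = 2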
@@ -4132,12 +4225,14 @@ Operators
Scatter Elements
Opposite of gather. Write values from updates into data at the given indices.
If two output element indices overlap, the last write wins in practice.
-output = input
-for each coordinate in indices tensor
- outputCoordinate = coordinate
- outputCoordinate[axis] = indices[coordinate]
- output[outputCoordinate] = updates[coordinate]
-endfor
+function scatterElements(input, indices, updates, axis)
+ output = input
+ for each coordinate in indices tensor
+ outputCoordinate = coordinate
+ outputCoordinate[axis] = indices[coordinate]
+ output[outputCoordinate] = updates[coordinate]
+ endfor
+ return output
+endfunction
output[ index[i][j][k] ][j][k] = input[i][j][k] # if dim == 0
output[i][ index[i][j][k] ][k] = input[i][j][k] # if dim == 1
@@ -4149,7 +4244,7 @@ Operators
data = [[1, 2, 3, 4, 5]] // data == input
indices = [[1, 3]]
updates = [[11, 21]]
-axis = 0
+axis = 1
output = [[1, 11, 3, 21, 5]]
@@ -4315,7 +4410,7 @@ Operators
function dimensions(input) = input.dimensions
function oneHot(indices, axis, axisLength, values)
+ // Indices and output are broadcast compatible.
+ // Indices has a dimension of size 1 at axis.
+ // Output has a dimension of size axisLength at axis (opposite of reduction).
+ // 1D values[2] contains {offValue, onValue}.
+ assert(indices.dimensions[axis] == 1)
outputDimensions = indices.dimensions
outputDimensions[axis] = axisLength
- output = broadcast(values[0], outputDimensions)
- TODO: scatter(output, indices, values[1])
+ defaultValues = broadcast(values[0], outputDimensions)
+ return scatterElements(defaultValues, indices, broadcast(values[1], indices.dimensions), axis)
endfunction
-function oneHotExpandedOutput(indices, axis, values)
- // output is 1 dimension bigger than input, inserted at the axis.
+function oneHotExpandedOutput(indices, axis, axisLength, values)
+ // Output is 1 dimension bigger than input, inserted at the axis.
+ broadcastCompatibleIndices = reshapeInsertingOnes(indices, [axis])
+ return oneHot(broadcastCompatibleIndices, axis, axisLength, values)
endfunction
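Worked example for oneHot (illustrative):
indices = [[2], [0]] // dimensions [2,1], axis = 1, axisLength = 3
values = [0, 1] // {offValue, onValue}
output = [[0, 0, 1], [1, 0, 0]]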
-function topK(input, K, axis) = slice(sortDecreasing(input, axis), starts=0, ends=K, axes=[axis])
+function topK(input, axis, axisLength)
+ // Sort entries in decreasing order along the axis, keeping only the first axisLength (the top K).
+ return slice(sortDecreasing(input, axis), starts=[0], ends=[axisLength], axes=[axis])
+endfunction
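Worked example (illustrative): topK([4,1,3,2], axis=0, axisLength=2)
sortDecreasing -> [4,3,2,1]; slice [0..2) along axis 0 -> [4,3]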
- function poolGeneric(input, axes, windowDimensions, padding, dilations, reductionFunction, initialValue)
- // TODO:
- // Determine output tensor dimensions.
- // output = new Tensor(input.type, outputDimensions)
+ function poolGeneric(input, axes, windowDimensions, padding, strides, dilations, reductionFunction, initialValue)
+ // Massage all axes-relative parameters to be directly compatible with input/output rank.
+ expandedWindowDimensions = scatterValues(windowDimensions, axes, input.rank, 1)
+ expandedPadding = scatterValues(padding, axes, input.rank, 0)
+ expandedStrides = scatterValues(strides, axes, input.rank, 1)
+ expandedDilations = scatterValues(dilations, axes, input.rank, 1)
+
+ // Compute the output tensor size based on window size/padding/strides/dilations.
+ filterExtents = ((expandedWindowDimensions - 1) * expandedDilations) + 1
+ paddedDimensions = input.dimensions + expandedPadding.leading + expandedPadding.trailing
+ outputDimensions = floor((paddedDimensions - filterExtents) / expandedStrides) + 1
+ output = new Tensor(input.type, outputDimensions)
// Reduce input along active axes.
- for each (outputCoordinate, value) in output
- // TODO: for each input in the window, apply the reduction function
+ for each outputCoordinate in output coordinates
+ // For each input in the window, apply the reduction function
+ outputValue = initialValue
+ for each (inputCoordinate, inputValue) in local input window
+ outputValue = reductionFunction(outputValue, inputValue)
+ endfor
+ output[outputCoordinate] = outputValue
endfor
return output
endfunction
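Worked size computation (illustrative) for one pooled axis:
input size 5, window 3, padding (1,1), stride 2, dilation 1
filterExtent = (3-1)*1 + 1 = 3; paddedSize = 5 + 1 + 1 = 7
outputSize = floor((7 - 3) / 2) + 1 = 3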
-function poolGenericWithIndices(input, axes, windowDimensions, padding, dilations, indicesDataType, reductionFunction, initialValue)
- // TODO:
- // Determine output tensor dimensions.
+function poolGenericWithIndices(input, axes, windowDimensions, padding, strides, dilations, indicesDataType, reductionFunction, initialValue)
+ // TODO: Complete...
// output = new Tensor(input.type, outputDimensions)
// indices = new Tensor(indicesDataType, outputDimensions)
- // Reduce input along active axes.
- for each (outputCoordinate, value) in output
- // TODO: for each input in the window, apply the reduction function
- endfor
return (output, indices)
endfunction
@@ -4611,8 +4725,9 @@ Operators
Pooling
Pool Sum
- function poolSum(input, axes, windowDimensions, padding, dilations)
- return poolGeneric(input, axes, windowDimensions, padding, dilations, add, 0)
+ function poolSum(input, axes, windowDimensions, padding, strides, dilations)
+ return poolGeneric(input, axes, windowDimensions, padding, strides, dilations, add, 0)
+ // Equivalently: convolve(input, filterWeights = ones(windowDimensions), windowDimensions, padding, dilations, strides)
endfunction
?
?
@@ -4639,12 +4754,10 @@ Operators
Pooling
Pool Average
- function poolAverage(input, axes, windowDimensions, padding, dilations)
+ function poolAverage(input, axes, windowDimensions, padding, strides, dilations)
- windowElementCount = elementCountAlongAxes(input.dimensions, axes)
+ windowElementCount = product(windowDimensions) // element count of one pooling window
- return div(poolSum(input, axes, windowDimensions, padding, dilations), windowElementCount)
-endfunction
-
-function poolAverage(input, ...) = conv(input, filter = ones(windowDimensions) / windowElementCount, ...)
+ return div(poolSum(input, axes, windowDimensions, padding, strides, dilations), windowElementCount)
+endfunction
averagePool2d
AveragePool
DML_OPERATOR_AVERAGE_POOLING
@@ -4709,8 +4822,8 @@ Operators
Pooling
Pool Maximum
- function poolMaximum(input, axes, windowDimensions, padding, dilations)
- return poolGeneric(input, axes, windowDimensions, padding, dilations, max, -infinity)
+ function poolMaximum(input, axes, windowDimensions, padding, strides, dilations)
+ return poolGeneric(input, axes, windowDimensions, padding, strides, dilations, max, -infinity)
endfunction
maxPool2d
MaxPool
@@ -4795,8 +4908,8 @@ Operators
Pooling
Pool Lebesgue
- function poolLebesgue(input, axes, windowDimensions, padding, dilations, exponent)
- return root(poolSum(pow(input, exponent), axes, windowDimensions, padding, dilations), exponent)
+ function poolLebesgue(input, axes, windowDimensions, padding, strides, dilations, exponent)
+ return root(poolSum(pow(input, exponent), axes, windowDimensions, padding, strides, dilations), exponent)
// y = (x1^p + x2^p + ... + xn^p) ^ (1/p)
endfunction
l2Pool2d
@@ -4908,7 +5021,7 @@ Operators
// Remove reduced dimensions (size 1) from output tensor if desired.
if keepDimensions == false
- outputDimensions = removeOnesFromShape(outputDimensions, axes)
+ outputDimensions = deleteOnesInDimensions(outputDimensions, axes)
endif
output.dimensions = outputDimensions
return output
@@ -5412,10 +5525,19 @@ Operators
Mean Variance Normalization
For each output element, subtract the mean, and divide by standard deviation.
-MeanVarianceNormalization(input) = (input - mean) / standardDeviation
-MeanVarianceNormalization(input) = (input - mean) / sqrt(variance + epsilon)
-MeanVarianceNormalization(input) = (input - mean(X)) / sqrt(mean((input - mean(input))^2))
-MeanVarianceNormalization(input) = (input - mean(X)) / sqrt(mean(input^2) - mean(input)^2)
+function meanVarianceNormalization(input, axes)
+ // = (input - mean) / standardDeviation
+ // = (input - mean) / sqrt(variance + epsilon)
+ // = (input - mean(input)) / sqrt(mean((input - mean(input))^2))
+ // = (input - mean(input)) / sqrt(mean(input^2) - mean(input)^2)
+ mean = reduceAverage(input, axes, keepDimensions=true)
+ centeredInput = sub(input, mean)
+ meanSquared = pow(mean, 2)
+ squareMeaned = reduceAverage(pow(input, 2), axes, keepDimensions=true)
+ variance = sub(squareMeaned, meanSquared)
+ standardDeviation = sqrt(add(variance, epsilon)) // epsilon = small constant (e.g. 1e-5) to avoid division by zero
+ return div(centeredInput, standardDeviation)
+endfunction
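Worked example (illustrative), treating epsilon as negligible:
input = [1, 3] -> mean = 2, variance = mean([1, 9]) - 2^2 = 1
output ≈ (input - 2) / sqrt(1) = [-1, 1]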
ONNX and NumPy
@@ -5465,16 +5587,27 @@ Operators
Normalization
Spatial Normalization
(independent batch&channel)
- function instanceNormalization(input, scale, bias)
- reshapedBias = reshape(bias, [batch size, C axis size, spatial dims...])
- axes = [2...input.rank-1] // Exclude axes {0,1} for N and C.
- mean = reduceAverage(input, axes, keepDimensions=true)
- variance = reduceAverage(pow(sub(input, mean), 2), axes, keepDimensions=true)
- return add(mul(scale, meanVarianceNormalization(axes)), reshapedBias)
+ function instanceNormalization(input, scale, bias, axes)
+ // Generic version
+ // Applies: DirectML
+ assert(isBroadcastCompatible(scale.dimensions, input.dimensions))
+ assert(isBroadcastCompatible(bias.dimensions, input.dimensions))
+
// scale * (input - mean) / sqrt(variance + epsilon) + bias
+ return add(mul(scale, meanVarianceNormalization(input, axes)), bias)
+endfunction
+
+function instanceNormalizationSpatialDimensions(input, scale, bias)
+ // Applies: ONNX
+ axes = [2...input.rank-1] // Exclude axes {0,1} for N and C.
+ // 1D scale and bias of length C are reshaped to [1, C, 1, ...] to broadcast against input.
+ reshapedScale = reshapeToAxes(scale, input.rank, [1])
+ reshapedBias = reshapeToAxes(bias, input.rank, [1])
+ return instanceNormalization(input, reshapedScale, reshapedBias, axes)
endfunction
-Mean and variance are computed across spatial dimensions DHW, independently per batch per channel NC:
+Mean and variance are computed across spatial dimensions DHW, independently per batch & channel (NC):
axes = [2,3, ..., inputRank-1] // Exclude axes {0,1}
mean = (x0 + x1 + … + xn-1) / n;
@@ -5517,9 +5650,12 @@ Operators
Normalization
Channel&Spatial Normalization
(independent leading batches) ref
- function LayerNormalization(X, scale, bias)
- return scale * (X - mean) / sqrt(Variance + epsilon) + reshape(Bias, [batch size, C axis size, spatial dims...])
-endfunction
+ function layerNormalization(input, scale, bias, firstAxis)
+ axes = [firstAxis...input.rank - 1]
+ // Scale and bias are expected to already be broadcast-compatible with input.
+ return add(mul(scale, meanVarianceNormalization(input, axes)), bias)
+ // scale * (input - mean) / sqrt(variance + epsilon) + bias
+endfunction
Mean and variance are computed across all dimensions from the given axis onward, independently for each leading batch:
@@ -5616,10 +5752,11 @@ Operators
function localResponseNormalizationSquare(input, axes, windowLength, scale, bias, exponent)
// Only handles square reduction windows.
- windowDimensions = repeat(windowLength, axes.size)
+ windowDimensions = repeat(axes.size, [windowLength])
leadingPadding = floor((windowLength - 1) / 2) // Center halfway around sliding window
trailingPadding = ceil((windowLength - 1) / 2) // Center halfway around sliding window
padding = repeat(axes.size * 2, [leadingPadding, trailingPadding])
+
return localResponseNormalization(input, axes, windowDimensions, padding, scale, bias, exponent)
endfunction
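Worked example (illustrative): windowLength = 4
leadingPadding = floor(3/2) = 1, trailingPadding = ceil(3/2) = 2 // even windows pad asymmetrically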
@@ -6163,7 +6300,7 @@ Operators
Elementwise Math (Deleted)
Image Scaler
- f(X) = add(mul(X, scale), unsqueeze(biasTensor, [0, 2, 3])) // reshape bias to [1,C,1,1]
+ f(X) = add(mul(X, scale), reshapeToAxes(biasTensor, X.rank, [1])) // reshape bias to [1,C,1,1]
f(x, scale, bias) = x * scale + bias
?
ImageScaler
@@ -6288,7 +6425,7 @@ Operators
// Read the surrounding samples from around the interpolated point.
for sampleIndex = 0..
Operators
numberOfElements = max(ceil((limit - start) / delta), 0)
- output.shape = [numberOfElements]
+ output.dimensions = [numberOfElements]
for i = 0..