Skip to content

Commit

Permalink
🚀 perf: Convolution2D (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
jean-francoisreboud authored Feb 28, 2024
1 parent 3d3191d commit 192f994
Show file tree
Hide file tree
Showing 7 changed files with 1,077 additions and 33 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ All notable changes to this project will be documented in this file.

## [unreleased]

🪜 **feat:** LayerCAM2D -> VQGrad2D, LayerCAMSeq -> VQGradSeq ([#114](https://github.com/owkin/GrAIdient/pull/114))\
🚀 **perf:** Convolution2D ([#118](https://github.com/owkin/GrAIdient/pull/118))\
🪜 **feat:** LayerCAM2D -> VQGrad2D, LayerCAMSeq -> VQGradSeq ([#117](https://github.com/owkin/GrAIdient/pull/117))\
⚙️ **core:** GELU vs GELUApprox ([#113](https://github.com/owkin/GrAIdient/pull/113))\
🚀 **perf:** QuerySelf & ValueSelf ([112](https://github.com/owkin/GrAIdient/pull/112))\
🚀 **perf:** benchmark ViT base model ([111](https://github.com/owkin/GrAIdient/pull/111))\
Expand Down
105 changes: 80 additions & 25 deletions Sources/GrAIdient/Layer2D/Convolution2D.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1373,8 +1373,21 @@ public class Convolution2D: BN2D, LayerWeightInit
UInt32(weightHeight)]
let pNbBatch: [UInt32] = [UInt32(batchSize)]

let kernel: String
let coeff: Int
if forwardKernel == "convForward" && nbChannels % 16 == 0
{
kernel = "conv16Forward"
coeff = 16
}
else
{
kernel = forwardKernel
coeff = 1
}

let command = MetalKernel.get.createCommand(
forwardKernel, deviceID: deviceID
kernel, deviceID: deviceID
)
command.setBuffer(layerPrev.outs.metal, atIndex: 0)
command.setBuffer(_wBuffers.w.metal, atIndex: 1)
Expand All @@ -1390,7 +1403,7 @@ public class Convolution2D: BN2D, LayerWeightInit
command.setBuffer(outs.metal, atIndex: 11)

command.dispatchThreads(
width: nbChannels * width,
width: (nbChannels / coeff) * width,
height: batchSize * height
)
command.enqueue()
Expand Down Expand Up @@ -1556,8 +1569,21 @@ public class Convolution2D: BN2D, LayerWeightInit
let pNbBatch: [UInt32] = [UInt32(batchSize)]
let pDirty: [UInt32] = layerPrev.dirty ? [1] : [0]

let kernel: String
let coeff: Int
if backwardKernel == "convBackward" && nbChannelsPrev % 16 == 0
{
kernel = "conv16Backward"
coeff = 16
}
else
{
kernel = backwardKernel
coeff = 1
}

let command = MetalKernel.get.createCommand(
backwardKernel, deviceID: deviceID
kernel, deviceID: deviceID
)
command.setBuffer(delta.metal, atIndex: 0)
command.setBuffer(_wBuffers.w.metal, atIndex: 1)
Expand All @@ -1573,7 +1599,7 @@ public class Convolution2D: BN2D, LayerWeightInit
command.setBuffer(layerPrev.delta.metal, atIndex: 11)

command.dispatchThreads(
width: nbChannelsPrev * layerPrev.width,
width: (nbChannelsPrev / coeff) * layerPrev.width,
height: batchSize * layerPrev.height
)
command.enqueue()
Expand Down Expand Up @@ -1609,27 +1635,56 @@ public class Convolution2D: BN2D, LayerWeightInit
var command: MetalCommand
if GrAI.Gradient.batch
{
command = MetalKernel.get.createCommand(
batchDerWeightsKernel, deviceID: deviceID
)
command.setBuffer(layerPrev.outs.metal, atIndex: 0)
command.setBuffer(delta.metal, atIndex: 1)
command.setBytes(pStart, atIndex: 2)
command.setBytes(pStride, atIndex: 3)
command.setBytes(pNbChannels, atIndex: 4)
command.setBytes(pNbChannelsPrev, atIndex: 5)
command.setBytes(pDimensions, atIndex: 6)
command.setBytes(pDimensionsPrev, atIndex: 7)
command.setBytes(pDimWeights, atIndex: 8)
command.setBytes(pNbBatch, atIndex: 9)
command.setBytes(pAccumulate, atIndex: 10)
command.setBuffer(_wBuffers.g.metal, atIndex: 11)

command.dispatchThreads(
width: nbChannels * weightWidth,
height: nbChannelsPrev * weightHeight
)
command.enqueue()
if batchDerWeightsKernel == "convBatchDerWeights" &&
_stride == 1 &&
layerPrev.width == width &&
layerPrev.height == height &&
weightWidth == 3 && weightHeight == 3 &&
height % 2 == 0 && width % 4 == 0
{
command = MetalKernel.get.createCommand(
"conv34BatchDerWeights", deviceID: deviceID
)
command.setBuffer(layerPrev.outs.metal, atIndex: 0)
command.setBuffer(delta.metal, atIndex: 1)
command.setBytes(pNbChannels, atIndex: 2)
command.setBytes(pNbChannelsPrev, atIndex: 3)
command.setBytes(pDimensions, atIndex: 4)
command.setBytes(pDimensionsPrev, atIndex: 5)
command.setBytes(pNbBatch, atIndex: 6)
command.setBytes(pAccumulate, atIndex: 7)
command.setBuffer(_wBuffers.g.metal, atIndex: 8)

command.dispatchThreads(
width: nbChannels,
height: nbChannelsPrev
)
command.enqueue()
}
else
{
command = MetalKernel.get.createCommand(
batchDerWeightsKernel, deviceID: deviceID
)
command.setBuffer(layerPrev.outs.metal, atIndex: 0)
command.setBuffer(delta.metal, atIndex: 1)
command.setBytes(pStart, atIndex: 2)
command.setBytes(pStride, atIndex: 3)
command.setBytes(pNbChannels, atIndex: 4)
command.setBytes(pNbChannelsPrev, atIndex: 5)
command.setBytes(pDimensions, atIndex: 6)
command.setBytes(pDimensionsPrev, atIndex: 7)
command.setBytes(pDimWeights, atIndex: 8)
command.setBytes(pNbBatch, atIndex: 9)
command.setBytes(pAccumulate, atIndex: 10)
command.setBuffer(_wBuffers.g.metal, atIndex: 11)

command.dispatchThreads(
width: nbChannels * weightWidth,
height: nbChannelsPrev * weightHeight
)
command.enqueue()
}

if _updateBiases
{
Expand Down
Loading

0 comments on commit 192f994

Please sign in to comment.