From 6dd84dd01fa7bd7b944e0ae39e51b16d2256c761 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Fran=C3=A7ois=20Reboud?= Date: Fri, 28 Jun 2024 11:19:59 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(layer=5Fseq):=20QueryCausalSeq?= =?UTF-8?q?=20(#125)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 1 + Sources/GrAIdient/LayerSeq/QuerySeq.swift | 746 ++++++++++++++++++ .../Metal/Kernel/LayerSeqFloat.metal | 124 --- .../GrAIdient/Metal/Kernel/LayerSeqHalf.metal | 124 --- Sources/GrAIdient/Metal/Kernel/NLPFloat.metal | 497 ++++++++++++ Sources/GrAIdient/Metal/Kernel/NLPHalf.metal | 497 ++++++++++++ Sources/GrAIdient/Metal/MetalConfig.swift | 24 +- Sources/GrAIdient/Utils/Serialization.swift | 1 + .../GrAIExamples/Base/python_lib/nlp/model.py | 6 +- Tests/GrAIExamples/NLPExample.swift | 57 +- .../Base/InputSeq/EmbeddingSeqMSE1DCase.swift | 4 +- Tests/GrAITests/NLPTests.swift | 358 +++++++++ 12 files changed, 2173 insertions(+), 266 deletions(-) create mode 100644 Sources/GrAIdient/Metal/Kernel/NLPFloat.metal create mode 100644 Sources/GrAIdient/Metal/Kernel/NLPHalf.metal diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f501fe0..84566f60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file. 
## [unreleased] +✨ **layer_seq:** QueryCausalSeq ([125](https://github.com/owkin/GrAIdient/pull/125))\ ✨ **layer_seq:** RoPESeq ([124](https://github.com/owkin/GrAIdient/pull/124))\ ✨ **layer_seq:** RMSNormSeq ([123](https://github.com/owkin/GrAIdient/pull/123))\ ✨ **layer_seq:** EmbeddingSeq ([122](https://github.com/owkin/GrAIdient/pull/122))\ diff --git a/Sources/GrAIdient/LayerSeq/QuerySeq.swift b/Sources/GrAIdient/LayerSeq/QuerySeq.swift index 180403cb..012fae53 100644 --- a/Sources/GrAIdient/LayerSeq/QuerySeq.swift +++ b/Sources/GrAIdient/LayerSeq/QuerySeq.swift @@ -996,3 +996,749 @@ public class QuerySelfSeq: LayerSeq } } } + +/// +/// Layer with a sequential shape neural structure. +/// +/// This layer computes the causal attention scores between a query layer and a key layer. +/// +public class QueryCausalSeq: LayerMergeSeq +{ + /// Number of heads (groups) of neurons for query. + let _nbHeadsQuery: Int + /// Number of heads (groups) of neurons for key. + let _nbHeadsKey: Int + + private enum Keys: String, CodingKey + { + case nbHeadsQuery + case nbHeadsKey + } + + /// + /// Create a layer with a sequential shape neural structure. + /// + /// - Parameters: + /// - query: Previous layer containing the query to look for. + /// - key: Previous layer containing the keys of reference. + /// - nbHeadsQuery: Number of heads (groups) of neurons for query. + /// - nbHeadsKey: Number of heads (groups) of neurons for key. + /// - params: Contextual parameters linking to the model. + /// + public init(query: LayerSeq, key: LayerSeq, + nbHeadsQuery: Int, nbHeadsKey: Int, + params: GrAI.Model.Params) throws + { + if query.nbNeurons % nbHeadsQuery != 0 + { + throw LayerError.Init(message: + "`nbNeurons` (\(query.nbNeurons)) " + + "should be a multiple of `nbHeadsQuery` (\(nbHeadsQuery))." + ) + } + if key.nbNeurons % nbHeadsKey != 0 + { + throw LayerError.Init(message: + "`nbNeurons` (\(key.nbNeurons)) " + + "should be a multiple of `nbHeadsKey` (\(nbHeadsKey))." 
+ ) + } + if nbHeadsQuery % nbHeadsKey != 0 + { + throw LayerError.Init(message: + "`nbHeadsQuery` should be a multiple of `nbHeadsKey`" + ) + } + if query.nbNeurons / nbHeadsQuery != key.nbNeurons / nbHeadsKey + { + throw LayerError.Init(message: + "`query` and `key` should have same hidden dimension." + ) + } + if query.sequence != key.sequence + { + throw LayerError.Init(message: "Layer structure error.") + } + + _nbHeadsQuery = nbHeadsQuery + _nbHeadsKey = nbHeadsKey + + super.init(layersPrev: [query, key], + sequence: query.sequence, + nbNeurons: query.sequence * nbHeadsQuery, + params: params) + } + + /// + /// Decode from the disk. + /// + /// Throw an error if reading from the decoder fails, or + /// if the data read is corrupted or otherwise invalid. + /// + /// - Parameter decoder: The decoder to read data from. + /// + public required init(from decoder: Decoder) throws + { + let values = try decoder.container(keyedBy: Keys.self) + _nbHeadsQuery = try values.decode(Int.self, forKey: Keys.nbHeadsQuery) + _nbHeadsKey = try values.decode(Int.self, forKey: Keys.nbHeadsKey) + try super.init(from: decoder) + } + + /// + /// Encode to the disk. + /// + /// If the value fails to encode anything, `encoder` will encode an empty + /// keyed container in its place. + /// + /// Throw an error if any values are invalid for the given + /// encoder's format. + /// + /// - Parameter encoder: The encoder to write data to. + /// + public override func encode(to encoder: Encoder) throws + { + var container = encoder.container(keyedBy: Keys.self) + try container.encode(_nbHeadsQuery, forKey: Keys.nbHeadsQuery) + try container.encode(_nbHeadsKey, forKey: Keys.nbHeadsKey) + try super.encode(to: encoder) + } + + /// + /// Create a layer with same values as this. + /// + /// - Parameters: + /// - mapping: Dictionary allowing to find the layer associated to some id. + /// This dictionary is particularly useful when the different layers cannot access + /// their `layerPrev`.
+ /// - inPlace: Whether hard resources should be copied as is. + /// + /// - Returns: A new layer. When `inPlace` is false, `initKernel` is + /// necessary in order to recreate hard resources. + /// + public override func copy( + mapping: Dictionary, + inPlace: Bool) -> Layer + { + let context = ModelContext(name: "", curID: 0) + let params = GrAI.Model.Params(context: context) + params.context.curID = id + + var layersPrev = [LayerSeq]() + for idPrev in _idsPrev + { + layersPrev.append(mapping[idPrev] as! LayerSeq) + } + + let layer = try! QueryCausalSeq( + query: layersPrev[0], key: layersPrev[1], + nbHeadsQuery: _nbHeadsQuery, + nbHeadsKey: _nbHeadsKey, + params: params + ) + return layer + } + + /// + /// Initialize state resources in the CPU execution context. + /// + /// We initialize the neurons' state (forward and backward). + /// + public override func checkStateCPU(batchSize: Int) throws + { + if neurons == nil + { + try super.checkStateCPU(batchSize: batchSize) + _encodeCausalityCPU() + } + else + { + try super.checkStateCPU(batchSize: batchSize) + } + } + + /// Update causality scores in the CPU execution context. 
+ private func _encodeCausalityCPU() + { + for elem in 0..= nbBlocks || seq >= sequence) - { - return ; - } - - float position = (float)seqPositions[seq]; - float theta = pow( - 10000.0, - -2.0 * (float)block / (float)size - ); - float mTheta = position * theta; - float cosVal = cos(mTheta); - float sinVal = sin(mTheta); - - uint offset = 2 * block + seq * size; - rotationMatrix[offset] = cosVal; - rotationMatrix[1 + offset] = sinVal; -} - -kernel void RoPESeqForwardFloat( - const device float * outsPrev, - const device float * rotationMatrix, - constant uint & nbHeads, - constant uint & nbNeurons, - constant uint & nbBatch, - constant uint & sequence, - device float * outs, - uint2 id [[ thread_position_in_grid ]]) -{ - uint size = nbNeurons / nbHeads; - uint nbBlocks = size / 2; - - uint head = id[0] / nbBlocks; - uint block = id[0] % nbBlocks; - uint elem = id[1] / sequence; - uint seq = id[1] % sequence; - - if (head >= nbHeads || block >= nbBlocks || - elem >= nbBatch || seq >= sequence) - { - return ; - } - - uint offset1 = 2 * block + seq * size; - uint offset2 = 2 * block + head * size + - nbNeurons * seq + sequence * nbNeurons * elem; - - float cosVal = rotationMatrix[offset1]; - float sinVal = rotationMatrix[1 + offset1]; - - float in1 = outsPrev[offset2]; - float in2 = outsPrev[1 + offset2]; - - float out1 = in1 * cosVal - in2 * sinVal; - float out2 = in1 * sinVal + in2 * cosVal; - - outs[offset2] = out1; - outs[1 + offset2] = out2; -} - -kernel void RoPESeqSeqBackwardFloat( - const device float * delta, - const device float * rotationMatrix, - constant uint & nbHeads, - constant uint & nbNeurons, - constant uint & nbBatch, - constant uint & sequence, - constant uint & dirty, - device float * deltaPrev, - uint2 id [[ thread_position_in_grid ]]) -{ - uint size = nbNeurons / nbHeads; - uint nbBlocks = size / 2; - - uint head = id[0] / nbBlocks; - uint block = id[0] % nbBlocks; - uint elem = id[1] / sequence; - uint seq = id[1] % sequence; - - if (head >= 
nbHeads || block >= nbBlocks || - elem >= nbBatch || seq >= sequence) - { - return ; - } - - uint offset1 = 2 * block + seq * size; - uint offset2 = 2 * block + head * size + - nbNeurons * seq + sequence * nbNeurons * elem; - - float cosVal = rotationMatrix[offset1]; - float sinVal = rotationMatrix[1 + offset1]; - - float out1 = delta[offset2]; - float out2 = delta[1 + offset2]; - - float in1 = out1 * cosVal + out2 * sinVal; - float in2 = -out1 * sinVal + out2 * cosVal; - - if (dirty) - { - deltaPrev[offset2] = in1; - deltaPrev[1 + offset2] = in2; - } - else - { - deltaPrev[offset2] += in1; - deltaPrev[1 + offset2] += in2; - } -} diff --git a/Sources/GrAIdient/Metal/Kernel/LayerSeqHalf.metal b/Sources/GrAIdient/Metal/Kernel/LayerSeqHalf.metal index 80f86c7d..21a2a7be 100644 --- a/Sources/GrAIdient/Metal/Kernel/LayerSeqHalf.metal +++ b/Sources/GrAIdient/Metal/Kernel/LayerSeqHalf.metal @@ -2743,127 +2743,3 @@ kernel void layerCAMSeqForwardHalf( uint offset = seq + sequence * elem; outs[offset] = sum; } - -kernel void createRoPESeqMatrixHalf( - constant int * seqPositions, - constant uint & nbHeads, - constant uint & nbNeurons, - constant uint & sequence, - device half * rotationMatrix, - uint2 id [[ thread_position_in_grid ]]) -{ - uint size = nbNeurons / nbHeads; - uint nbBlocks = size / 2; - - uint block = id[0]; - uint seq = id[1]; - - if (block >= nbBlocks || seq >= sequence) - { - return ; - } - - float position = (float)seqPositions[seq]; - float theta = pow( - 10000.0, - -2.0 * (float)block / (float)size - ); - float mTheta = position * theta; - float cosVal = cos(mTheta); - float sinVal = sin(mTheta); - - uint offset = 2 * block + seq * size; - rotationMatrix[offset] = cosVal; - rotationMatrix[1 + offset] = sinVal; -} - -kernel void RoPESeqForwardHalf( - const device half * outsPrev, - const device half * rotationMatrix, - constant uint & nbHeads, - constant uint & nbNeurons, - constant uint & nbBatch, - constant uint & sequence, - device half * outs, - uint2 
id [[ thread_position_in_grid ]]) -{ - uint size = nbNeurons / nbHeads; - uint nbBlocks = size / 2; - - uint head = id[0] / nbBlocks; - uint block = id[0] % nbBlocks; - uint elem = id[1] / sequence; - uint seq = id[1] % sequence; - - if (head >= nbHeads || block >= nbBlocks || - elem >= nbBatch || seq >= sequence) - { - return ; - } - - uint offset1 = 2 * block + seq * size; - uint offset2 = 2 * block + head * size + - nbNeurons * seq + sequence * nbNeurons * elem; - - half cosVal = rotationMatrix[offset1]; - half sinVal = rotationMatrix[1 + offset1]; - - half in1 = outsPrev[offset2]; - half in2 = outsPrev[1 + offset2]; - - half out1 = in1 * cosVal - in2 * sinVal; - half out2 = in1 * sinVal + in2 * cosVal; - - outs[offset2] = out1; - outs[1 + offset2] = out2; -} - -kernel void RoPESeqSeqBackwardHalf( - const device half * delta, - const device half * rotationMatrix, - constant uint & nbHeads, - constant uint & nbNeurons, - constant uint & nbBatch, - constant uint & sequence, - constant uint & dirty, - device half * deltaPrev, - uint2 id [[ thread_position_in_grid ]]) -{ - uint size = nbNeurons / nbHeads; - uint nbBlocks = size / 2; - - uint head = id[0] / nbBlocks; - uint block = id[0] % nbBlocks; - uint elem = id[1] / sequence; - uint seq = id[1] % sequence; - - if (head >= nbHeads || block >= nbBlocks || - elem >= nbBatch || seq >= sequence) - { - return ; - } - - uint offset1 = 2 * block + seq * size; - uint offset2 = 2 * block + head * size + - nbNeurons * seq + sequence * nbNeurons * elem; - - half cosVal = rotationMatrix[offset1]; - half sinVal = rotationMatrix[1 + offset1]; - - half out1 = delta[offset2]; - half out2 = delta[1 + offset2]; - - half in1 = out1 * cosVal + out2 * sinVal; - half in2 = -out1 * sinVal + out2 * cosVal; - - if (dirty) - { - deltaPrev[offset2] = in1; - deltaPrev[1 + offset2] = in2; - } - else - { - deltaPrev[offset2] += in1; - deltaPrev[1 + offset2] += in2; - } -} diff --git a/Sources/GrAIdient/Metal/Kernel/NLPFloat.metal 
b/Sources/GrAIdient/Metal/Kernel/NLPFloat.metal new file mode 100644 index 00000000..89ad05c7 --- /dev/null +++ b/Sources/GrAIdient/Metal/Kernel/NLPFloat.metal @@ -0,0 +1,497 @@ +// +// NLPFloat.metal +// GrAIdient +// +// Created by Jean-François Reboud on 25/06/2024. +// + +#include +using namespace metal; + +kernel void createRoPESeqMatrixFloat( + constant int * seqPositions, + constant uint & nbHeads, + constant uint & nbNeurons, + constant uint & sequence, + device float * rotationMatrix, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeurons / nbHeads; + uint nbBlocks = size / 2; + + uint block = id[0]; + uint seq = id[1]; + + if (block >= nbBlocks || seq >= sequence) + { + return ; + } + + float position = (float)seqPositions[seq]; + float theta = pow( + 10000.0, + -2.0 * (float)block / (float)size + ); + float mTheta = position * theta; + float cosVal = cos(mTheta); + float sinVal = sin(mTheta); + + uint offset = 2 * block + seq * size; + rotationMatrix[offset] = cosVal; + rotationMatrix[1 + offset] = sinVal; +} + +kernel void RoPESeqForwardFloat( + const device float * outsPrev, + const device float * rotationMatrix, + constant uint & nbHeads, + constant uint & nbNeurons, + constant uint & nbBatch, + constant uint & sequence, + device float * outs, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeurons / nbHeads; + uint nbBlocks = size / 2; + + uint head = id[0] / nbBlocks; + uint block = id[0] % nbBlocks; + uint elem = id[1] / sequence; + uint seq = id[1] % sequence; + + if (head >= nbHeads || block >= nbBlocks || + elem >= nbBatch || seq >= sequence) + { + return ; + } + + uint offset1 = 2 * block + seq * size; + uint offset2 = 2 * block + head * size + + nbNeurons * seq + sequence * nbNeurons * elem; + + float cosVal = rotationMatrix[offset1]; + float sinVal = rotationMatrix[1 + offset1]; + + float in1 = outsPrev[offset2]; + float in2 = outsPrev[1 + offset2]; + + float out1 = in1 * cosVal - in2 * sinVal; + float out2 = in1 * 
sinVal + in2 * cosVal; + + outs[offset2] = out1; + outs[1 + offset2] = out2; +} + +kernel void RoPESeqSeqBackwardFloat( + const device float * delta, + const device float * rotationMatrix, + constant uint & nbHeads, + constant uint & nbNeurons, + constant uint & nbBatch, + constant uint & sequence, + constant uint & dirty, + device float * deltaPrev, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeurons / nbHeads; + uint nbBlocks = size / 2; + + uint head = id[0] / nbBlocks; + uint block = id[0] % nbBlocks; + uint elem = id[1] / sequence; + uint seq = id[1] % sequence; + + if (head >= nbHeads || block >= nbBlocks || + elem >= nbBatch || seq >= sequence) + { + return ; + } + + uint offset1 = 2 * block + seq * size; + uint offset2 = 2 * block + head * size + + nbNeurons * seq + sequence * nbNeurons * elem; + + float cosVal = rotationMatrix[offset1]; + float sinVal = rotationMatrix[1 + offset1]; + + float out1 = delta[offset2]; + float out2 = delta[1 + offset2]; + + float in1 = out1 * cosVal + out2 * sinVal; + float in2 = -out1 * sinVal + out2 * cosVal; + + if (dirty) + { + deltaPrev[offset2] = in1; + deltaPrev[1 + offset2] = in2; + } + else + { + deltaPrev[offset2] += in1; + deltaPrev[1 + offset2] += in2; + } +} + +kernel void encodeCausalityFloat( + constant uint & nbHeadsQuery, + constant uint & nbNeurons, + constant uint & nbBatch, + constant uint & sequence, + device float * outs, + uint2 id [[ thread_position_in_grid ]]) +{ + uint headQuery = id[0] / sequence; + uint seqK = id[0] % sequence; + uint elem = id[1] / sequence; + uint seqQ = id[1] % sequence; + + if (headQuery >= nbHeadsQuery || + seqK >= sequence || seqK <= seqQ || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint offset = seqK + headQuery * sequence + + nbNeurons * seqQ + sequence * nbNeurons * elem; + outs[offset] = -1e9; +} + +kernel void queryCausalSeqForwardFloat( + const device float * query, + const device float * key, + constant uint & nbHeadsQuery, + constant 
uint & nbHeadsKey, + constant uint & nbNeurons, + constant uint & nbNeuronsPrevQuery, + constant uint & nbNeuronsPrevKey, + constant uint & nbBatch, + constant uint & sequence, + device float * outs, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeuronsPrevQuery / nbHeadsQuery; + + uint headQuery = id[0] / sequence; + uint seqK = id[0] % sequence; + uint elem = id[1] / sequence; + uint seqQ = id[1] % sequence; + + if (headQuery >= nbHeadsQuery || seqK > seqQ || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headKey = nbHeadsQuery == nbHeadsKey ? + headQuery : headQuery / nbHeadsKey; + float tmp = 0.0; + + for (uint j=0; j= nbHeadsQuery || seqK > seqQ || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headKey = nbHeadsQuery == nbHeadsKey ? + headQuery : headQuery / nbHeadsKey; + float4 tmp = 0.0; + + for (uint j=0; j= nbHeadsQuery || j >= size || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headKey = nbHeadsQuery == nbHeadsKey ? 
+ headQuery : headQuery / nbHeadsKey; + uint depthPrevKey = j + headKey * size; + uint depthPrevQuery = j + headQuery * size; + + float tmp = 0.0; + for (uint seqK=0; seqK<=seqQ; seqK++) + { + uint offset = seqK + headQuery * sequence + + nbNeurons * seqQ + sequence * nbNeurons * elem; + uint offsetKey = depthPrevKey + + nbNeuronsPrevKey * seqK + sequence * nbNeuronsPrevKey * elem; + + tmp += delta[offset] * key[offsetKey]; + } + tmp /= sqrt((float)size); + + uint offsetQuery = depthPrevQuery + + nbNeuronsPrevQuery * seqQ + sequence * nbNeuronsPrevQuery * elem; + + if (dirty) + { + query[offsetQuery] = tmp; + } + else + { + query[offsetQuery] += tmp; + } +} + +kernel void queryCausalQuerySeq4BackwardFloat( + const device float * delta, + const device float4 * key, + constant uint & nbHeadsQuery, + constant uint & nbHeadsKey, + constant uint & nbNeurons, + constant uint & nbNeuronsPrevQuery, + constant uint & nbNeuronsPrevKey, + constant uint & nbBatch, + constant uint & sequence, + constant uint & dirty, + device float4 * query, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeuronsPrevQuery / nbHeadsQuery; + + uint headQuery = id[0] / (size / 4); + uint j = id[0] % (size / 4); + uint elem = id[1] / sequence; + uint seqQ = id[1] % sequence; + + if (headQuery >= nbHeadsQuery || j * 4 >= size || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headKey = nbHeadsQuery == nbHeadsKey ? 
+ headQuery : headQuery / nbHeadsKey; + uint depthPrevKey = j * 4 + headKey * size; + uint depthPrevQuery = j * 4 + headQuery * size; + + float4 tmp = 0.0; + for (uint seqK=0; seqK<=seqQ; seqK++) + { + uint offset = seqK + headQuery * sequence + + nbNeurons * seqQ + sequence * nbNeurons * elem; + uint offsetKey = (depthPrevKey + + nbNeuronsPrevKey * seqK + sequence * nbNeuronsPrevKey * elem) / 4; + + tmp += delta[offset] * key[offsetKey]; + } + tmp /= sqrt((float)size); + + uint offsetQuery = (depthPrevQuery + + nbNeuronsPrevQuery * seqQ + sequence * nbNeuronsPrevQuery * elem) / 4; + + if (dirty) + { + query[offsetQuery] = tmp; + } + else + { + query[offsetQuery] += tmp; + } +} + +kernel void queryCausalKeySeqBackwardFloat( + const device float * delta, + const device float * query, + constant uint & nbHeadsQuery, + constant uint & nbHeadsKey, + constant uint & nbNeurons, + constant uint & nbNeuronsPrevQuery, + constant uint & nbNeuronsPrevKey, + constant uint & nbBatch, + constant uint & sequence, + constant uint & dirty, + device float * key, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeuronsPrevKey / nbHeadsKey; + + uint headKey = id[0] / size; + uint j = id[0] % size; + uint elem = id[1] / sequence; + uint seqK = id[1] % sequence; + + if (headKey >= nbHeadsKey || j >= size || + elem >= nbBatch || seqK >= sequence) + { + return ; + } + + uint nbBlocksHead = nbHeadsQuery == nbHeadsKey ? + 1 : nbHeadsQuery / nbHeadsKey; + uint depthPrevKey = j + headKey * size; + + float tmp = 0.0; + for (uint blockHead=0; blockHead= nbHeadsKey || j * 4 >= size || + elem >= nbBatch || seqK >= sequence) + { + return ; + } + + uint nbBlocksHead = nbHeadsQuery == nbHeadsKey ? 
+ 1 : nbHeadsQuery / nbHeadsKey; + uint depthPrevKey = j * 4 + headKey * size; + + float4 tmp = 0.0; + for (uint blockHead=0; blockHead +using namespace metal; + +kernel void createRoPESeqMatrixHalf( + constant int * seqPositions, + constant uint & nbHeads, + constant uint & nbNeurons, + constant uint & sequence, + device half * rotationMatrix, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeurons / nbHeads; + uint nbBlocks = size / 2; + + uint block = id[0]; + uint seq = id[1]; + + if (block >= nbBlocks || seq >= sequence) + { + return ; + } + + float position = (float)seqPositions[seq]; + float theta = pow( + 10000.0, + -2.0 * (float)block / (float)size + ); + float mTheta = position * theta; + float cosVal = cos(mTheta); + float sinVal = sin(mTheta); + + uint offset = 2 * block + seq * size; + rotationMatrix[offset] = cosVal; + rotationMatrix[1 + offset] = sinVal; +} + +kernel void RoPESeqForwardHalf( + const device half * outsPrev, + const device half * rotationMatrix, + constant uint & nbHeads, + constant uint & nbNeurons, + constant uint & nbBatch, + constant uint & sequence, + device half * outs, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeurons / nbHeads; + uint nbBlocks = size / 2; + + uint head = id[0] / nbBlocks; + uint block = id[0] % nbBlocks; + uint elem = id[1] / sequence; + uint seq = id[1] % sequence; + + if (head >= nbHeads || block >= nbBlocks || + elem >= nbBatch || seq >= sequence) + { + return ; + } + + uint offset1 = 2 * block + seq * size; + uint offset2 = 2 * block + head * size + + nbNeurons * seq + sequence * nbNeurons * elem; + + half cosVal = rotationMatrix[offset1]; + half sinVal = rotationMatrix[1 + offset1]; + + half in1 = outsPrev[offset2]; + half in2 = outsPrev[1 + offset2]; + + half out1 = in1 * cosVal - in2 * sinVal; + half out2 = in1 * sinVal + in2 * cosVal; + + outs[offset2] = out1; + outs[1 + offset2] = out2; +} + +kernel void RoPESeqSeqBackwardHalf( + const device half * delta, + const device 
half * rotationMatrix, + constant uint & nbHeads, + constant uint & nbNeurons, + constant uint & nbBatch, + constant uint & sequence, + constant uint & dirty, + device half * deltaPrev, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeurons / nbHeads; + uint nbBlocks = size / 2; + + uint head = id[0] / nbBlocks; + uint block = id[0] % nbBlocks; + uint elem = id[1] / sequence; + uint seq = id[1] % sequence; + + if (head >= nbHeads || block >= nbBlocks || + elem >= nbBatch || seq >= sequence) + { + return ; + } + + uint offset1 = 2 * block + seq * size; + uint offset2 = 2 * block + head * size + + nbNeurons * seq + sequence * nbNeurons * elem; + + half cosVal = rotationMatrix[offset1]; + half sinVal = rotationMatrix[1 + offset1]; + + half out1 = delta[offset2]; + half out2 = delta[1 + offset2]; + + half in1 = out1 * cosVal + out2 * sinVal; + half in2 = -out1 * sinVal + out2 * cosVal; + + if (dirty) + { + deltaPrev[offset2] = in1; + deltaPrev[1 + offset2] = in2; + } + else + { + deltaPrev[offset2] += in1; + deltaPrev[1 + offset2] += in2; + } +} + +kernel void encodeCausalityHalf( + constant uint & nbHeadsQuery, + constant uint & nbNeurons, + constant uint & nbBatch, + constant uint & sequence, + device half * outs, + uint2 id [[ thread_position_in_grid ]]) +{ + uint headQuery = id[0] / sequence; + uint seqK = id[0] % sequence; + uint elem = id[1] / sequence; + uint seqQ = id[1] % sequence; + + if (headQuery >= nbHeadsQuery || + seqK >= sequence || seqK <= seqQ || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint offset = seqK + headQuery * sequence + + nbNeurons * seqQ + sequence * nbNeurons * elem; + outs[offset] = -1e4; +} + +kernel void queryCausalSeqForwardHalf( + const device half * query, + const device half * key, + constant uint & nbHeadsQuery, + constant uint & nbHeadsKey, + constant uint & nbNeurons, + constant uint & nbNeuronsPrevQuery, + constant uint & nbNeuronsPrevKey, + constant uint & nbBatch, + constant uint & sequence, 
+ device half * outs, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeuronsPrevQuery / nbHeadsQuery; + + uint headQuery = id[0] / sequence; + uint seqK = id[0] % sequence; + uint elem = id[1] / sequence; + uint seqQ = id[1] % sequence; + + if (headQuery >= nbHeadsQuery || seqK > seqQ || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headKey = nbHeadsQuery == nbHeadsKey ? + headQuery : headQuery / nbHeadsKey; + half tmp = 0.0; + + for (uint j=0; j= nbHeadsQuery || seqK > seqQ || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headKey = nbHeadsQuery == nbHeadsKey ? + headQuery : headQuery / nbHeadsKey; + half4 tmp = 0.0; + + for (uint j=0; j= nbHeadsQuery || j >= size || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headKey = nbHeadsQuery == nbHeadsKey ? + headQuery : headQuery / nbHeadsKey; + uint depthPrevKey = j + headKey * size; + uint depthPrevQuery = j + headQuery * size; + + half tmp = 0.0; + for (uint seqK=0; seqK<=seqQ; seqK++) + { + uint offset = seqK + headQuery * sequence + + nbNeurons * seqQ + sequence * nbNeurons * elem; + uint offsetKey = depthPrevKey + + nbNeuronsPrevKey * seqK + sequence * nbNeuronsPrevKey * elem; + + tmp += delta[offset] * key[offsetKey]; + } + tmp /= sqrt((half)size); + + uint offsetQuery = depthPrevQuery + + nbNeuronsPrevQuery * seqQ + sequence * nbNeuronsPrevQuery * elem; + + if (dirty) + { + query[offsetQuery] = tmp; + } + else + { + query[offsetQuery] += tmp; + } +} + +kernel void queryCausalQuerySeq4BackwardHalf( + const device half * delta, + const device half4 * key, + constant uint & nbHeadsQuery, + constant uint & nbHeadsKey, + constant uint & nbNeurons, + constant uint & nbNeuronsPrevQuery, + constant uint & nbNeuronsPrevKey, + constant uint & nbBatch, + constant uint & sequence, + constant uint & dirty, + device half4 * query, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeuronsPrevQuery / nbHeadsQuery; + + uint headQuery = id[0] / 
(size / 4); + uint j = id[0] % (size / 4); + uint elem = id[1] / sequence; + uint seqQ = id[1] % sequence; + + if (headQuery >= nbHeadsQuery || j * 4 >= size || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headKey = nbHeadsQuery == nbHeadsKey ? + headQuery : headQuery / nbHeadsKey; + uint depthPrevKey = j * 4 + headKey * size; + uint depthPrevQuery = j * 4 + headQuery * size; + + half4 tmp = 0.0; + for (uint seqK=0; seqK<=seqQ; seqK++) + { + uint offset = seqK + headQuery * sequence + + nbNeurons * seqQ + sequence * nbNeurons * elem; + uint offsetKey = (depthPrevKey + + nbNeuronsPrevKey * seqK + sequence * nbNeuronsPrevKey * elem) / 4; + + tmp += delta[offset] * key[offsetKey]; + } + tmp /= sqrt((half)size); + + uint offsetQuery = (depthPrevQuery + + nbNeuronsPrevQuery * seqQ + sequence * nbNeuronsPrevQuery * elem) / 4; + + if (dirty) + { + query[offsetQuery] = tmp; + } + else + { + query[offsetQuery] += tmp; + } +} + +kernel void queryCausalKeySeqBackwardHalf( + const device half * delta, + const device half * query, + constant uint & nbHeadsQuery, + constant uint & nbHeadsKey, + constant uint & nbNeurons, + constant uint & nbNeuronsPrevQuery, + constant uint & nbNeuronsPrevKey, + constant uint & nbBatch, + constant uint & sequence, + constant uint & dirty, + device half * key, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeuronsPrevKey / nbHeadsKey; + + uint headKey = id[0] / size; + uint j = id[0] % size; + uint elem = id[1] / sequence; + uint seqK = id[1] % sequence; + + if (headKey >= nbHeadsKey || j >= size || + elem >= nbBatch || seqK >= sequence) + { + return ; + } + + uint nbBlocksHead = nbHeadsQuery == nbHeadsKey ? + 1 : nbHeadsQuery / nbHeadsKey; + uint depthPrevKey = j + headKey * size; + + half tmp = 0.0; + for (uint blockHead=0; blockHead= nbHeadsKey || j * 4 >= size || + elem >= nbBatch || seqK >= sequence) + { + return ; + } + + uint nbBlocksHead = nbHeadsQuery == nbHeadsKey ? 
+ 1 : nbHeadsQuery / nbHeadsKey; + uint depthPrevKey = j * 4 + headKey * size; + + half4 tmp = 0.0; + for (uint blockHead=0; blockHeadbhli", [queries, rotation_matrix]) keys = torch.einsum("bhlj,lij->bhli", [keys, rotation_matrix]) - """scores = torch.matmul(queries, keys.transpose(2, 3)) * self.scale + scores = torch.matmul(queries, keys.transpose(2, 3)) * self.scale if mask is not None: scores += mask scores = torch.softmax( scores.type(torch.float32), dim=-1 ).type_as(scores) - output = torch.matmul(scores, values) + """output = torch.matmul(scores, values) output = output.transpose(1, 2).contiguous().reshape(B, L, -1) return self.wo(output), (keys, values)""" - return queries.transpose(1, 2).contiguous().reshape(B, L, -1), (keys, values) + return scores.transpose(1, 2).contiguous().reshape(B, L, -1), (keys, values) class FeedForward(torch.nn.Module): diff --git a/Tests/GrAIExamples/NLPExample.swift b/Tests/GrAIExamples/NLPExample.swift index 8e24a925..26decf00 100644 --- a/Tests/GrAIExamples/NLPExample.swift +++ b/Tests/GrAIExamples/NLPExample.swift @@ -16,7 +16,7 @@ final class NLPExample: XCTestCase let _modelPath = "TO/UPDATE" /// Prompt. - let _prompt = "I" + let _prompt = "Hello" /// Initialize test. override func setUp() @@ -34,7 +34,9 @@ final class NLPExample: XCTestCase /// - Parameters: /// - sequence: Length of the sequence. /// - hiddenDim: Dimension of neurons in the main branch. - /// - nbHeads: Number of heads (groups) of neurons. + /// - headDim: Dimension of neurons in the transformer branches. + /// - nbHeads: Number of heads (groups) of neurons for queries. + /// - nbHeadsKV: Number of heads (groups) of neurons for keys and values. /// - vocabularySize: Vocabulary size. /// - Returns: The model built. 
/// @@ -42,7 +44,9 @@ final class NLPExample: XCTestCase modelPath: String, sequence: Int, hiddenDim: Int, - nbHeads: Int, + headDim: Int, + nbHeadsQuery: Int, + nbHeadsKV: Int, vocabularySize: Int) -> Model { let context = ModelContext(name: "NLP", curID: 0) @@ -54,18 +58,42 @@ final class NLPExample: XCTestCase nbNeurons: hiddenDim, params: params ) - layer = FullyConnectedSeq( + var query: LayerSeq = FullyConnectedSeq( layerPrev: layer, - nbNeurons: hiddenDim, + nbNeurons: nbHeadsQuery * headDim, activation: nil, biases: false, params: params ) + query = try! RoPESeq( + layerPrev: query, + seqPositions: [Int](1...sequence), + nbHeads: nbHeadsQuery, + params: params + ) - layer = try! RoPESeq( + var key: LayerSeq = FullyConnectedSeq( layerPrev: layer, + nbNeurons: nbHeadsKV * headDim, + activation: nil, + biases: false, + params: params + ) + key = try! RoPESeq( + layerPrev: key, seqPositions: [Int](1...sequence), - nbHeads: nbHeads, + nbHeads: nbHeadsKV, + params: params + ) + + layer = try! QueryCausalSeq( + query: query, key: key, + nbHeadsQuery: nbHeadsQuery, nbHeadsKey: nbHeadsKV, + params: params + ) + layer = try! SoftmaxSeq( + layerPrev: layer, + nbHeads: nbHeadsQuery, params: params ) @@ -146,7 +174,9 @@ final class NLPExample: XCTestCase modelPath: _modelPath, sequence: prompt.count, hiddenDim: 4096, - nbHeads: 32, + headDim: 128, + nbHeadsQuery: 32, + nbHeadsKV: 8, vocabularySize: 32000 ) @@ -167,8 +197,15 @@ final class NLPExample: XCTestCase // Compare difference. 
for (elemOut, elemRef) in zip(arrayOut, arrayRef) { - let diffPercent = abs(elemOut - elemRef) / elemRef * 100.0 - XCTAssert(diffPercent < 1) + if elemRef == 0.0 + { + XCTAssert(elemOut == 0.0) + } + else + { + let diffPercent = abs(elemOut - elemRef) / elemRef * 100.0 + XCTAssert(diffPercent < 1) + } } } } diff --git a/Tests/GrAITests/Base/InputSeq/EmbeddingSeqMSE1DCase.swift b/Tests/GrAITests/Base/InputSeq/EmbeddingSeqMSE1DCase.swift index 3a349b17..e1d62089 100644 --- a/Tests/GrAITests/Base/InputSeq/EmbeddingSeqMSE1DCase.swift +++ b/Tests/GrAITests/Base/InputSeq/EmbeddingSeqMSE1DCase.swift @@ -28,8 +28,8 @@ class EmbeddingSeqMSE1DCase: XCTestCase, Input1DCase, IOCase override func setUp() { batchSize = 5 - sequence = 7 - vocabularySize = 120 + sequence = 5 + vocabularySize = 7 _ = MetalKernel.get GrAI.Opti.GPU = true diff --git a/Tests/GrAITests/NLPTests.swift b/Tests/GrAITests/NLPTests.swift index 0ad3ca97..01372740 100644 --- a/Tests/GrAITests/NLPTests.swift +++ b/Tests/GrAITests/NLPTests.swift @@ -73,6 +73,58 @@ class NLPGradTests: EmbeddingSeqMSE1DCase nbHeads: 3, params: params ) + + case "QueryCausal1": + let otherLayer: LayerSeq = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 3 * 3, + activation: nil, + biases: false, + params: params + ) + layer = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 3 * 3, + activation: nil, + biases: false, + params: params + ) + layer = try! QueryCausalSeq( + query: layer, key: otherLayer, + nbHeadsQuery: 3, nbHeadsKey: 3, + params: params + ) + layer = try! SoftmaxSeq( + layerPrev: layer, + nbHeads: 3, + params: params + ) + + case "QueryCausal2": + let otherLayer: LayerSeq = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 2 * 3, + activation: nil, + biases: false, + params: params + ) + layer = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 4 * 3, + activation: nil, + biases: false, + params: params + ) + layer = try! 
QueryCausalSeq( + query: layer, key: otherLayer, + nbHeadsQuery: 4, nbHeadsKey: 2, + params: params + ) + layer = try! SoftmaxSeq( + layerPrev: layer, + nbHeads: 4, + params: params + ) default: fatalError("Unreachable.") @@ -133,6 +185,32 @@ class NLPGradTests: EmbeddingSeqMSE1DCase let trainer = _buildTrainer("RoPE") run(trainer) } + + func testQueryCausal1CPU() throws + { + GrAI.Opti.CPU = true + let trainer = _buildTrainer("QueryCausal1") + run(trainer) + } + + func testQueryCausal1GPU() throws + { + let trainer = _buildTrainer("QueryCausal1") + run(trainer) + } + + func testQueryCausal2CPU() throws + { + GrAI.Opti.CPU = true + let trainer = _buildTrainer("QueryCausal2") + run(trainer) + } + + func testQueryCausal2GPU() throws + { + let trainer = _buildTrainer("QueryCausal2") + run(trainer) + } } // ----------------------------------------------------------------------------- @@ -192,6 +270,58 @@ class NLPFlowTests: EmbeddingSeqMSE1DCase params: params ) + case "QueryCausal1": + let otherLayer: LayerSeq = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 3 * 3, + activation: nil, + biases: false, + params: params + ) + layer = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 3 * 3, + activation: nil, + biases: false, + params: params + ) + layer = try! QueryCausalSeq( + query: layer, key: otherLayer, + nbHeadsQuery: 3, nbHeadsKey: 3, + params: params + ) + layer = try! SoftmaxSeq( + layerPrev: layer, + nbHeads: 3, + params: params + ) + + case "QueryCausal2": + let otherLayer: LayerSeq = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 2 * 3, + activation: nil, + biases: false, + params: params + ) + layer = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 4 * 3, + activation: nil, + biases: false, + params: params + ) + layer = try! QueryCausalSeq( + query: layer, key: otherLayer, + nbHeadsQuery: 4, nbHeadsKey: 2, + params: params + ) + layer = try! 
SoftmaxSeq( + layerPrev: layer, + nbHeads: 4, + params: params + ) + default: fatalError("Unreachable.") } @@ -230,6 +360,18 @@ class NLPFlowTests: EmbeddingSeqMSE1DCase let trainer = _buildTrainer("RoPE") run(trainer) } + + func testQueryCausal1() throws + { + let trainer = _buildTrainer("QueryCausal1") + run(trainer) + } + + func testQueryCausal2() throws + { + let trainer = _buildTrainer("QueryCausal2") + run(trainer) + } } // ----------------------------------------------------------------------------- @@ -276,6 +418,162 @@ class NLPFlowPrecisionTests: NLPFlowTests let trainer = _buildTrainer("RoPE") run(trainer) } + + override func testQueryCausal1() throws + { + let trainer = _buildTrainer("QueryCausal1") + run(trainer, diffThreshold: 0.002) + } + + override func testQueryCausal2() throws + { + let trainer = _buildTrainer("QueryCausal2") + run(trainer, diffThreshold: 0.002) + } +} + +// ----------------------------------------------------------------------------- +// Compare GPU gradients with CPU ones through time. +// We expect to see errors ~ 1e-7 and less. 
+// ----------------------------------------------------------------------------- +class NLP4FlowTests: EmbeddingSeqMSE1DCase +{ + private func _buildTrainer(_ model: String) -> FlowTrainer + { + let trainer = FlowTrainer( + name: "NLP", + params: optimizerParams + ) + trainer.build() + { + (context: ModelContext) in + buildModel(model: model, context: context) + } + return trainer + } + + func buildModel(model: String, context: ModelContext) + { + let params = GrAI.Model.Params(context: context) + + var layer: LayerSeq = EmbeddingSeq( + sequence: sequence, + vocabularySize: vocabularySize, + nbNeurons: 4, params: params + ) + + switch model + { + case "QueryCausal1": + let otherLayer: LayerSeq = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 3 * 4, + activation: nil, + biases: false, + params: params + ) + layer = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 3 * 4, + activation: nil, + biases: false, + params: params + ) + layer = try! QueryCausalSeq( + query: layer, key: otherLayer, + nbHeadsQuery: 3, nbHeadsKey: 3, + params: params + ) + layer = try! SoftmaxSeq( + layerPrev: layer, + nbHeads: 3, + params: params + ) + + case "QueryCausal2": + let otherLayer: LayerSeq = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 2 * 4, + activation: nil, + biases: false, + params: params + ) + layer = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 4 * 4, + activation: nil, + biases: false, + params: params + ) + layer = try! QueryCausalSeq( + query: layer, key: otherLayer, + nbHeadsQuery: 4, nbHeadsKey: 2, + params: params + ) + layer = try! SoftmaxSeq( + layerPrev: layer, + nbHeads: 4, + params: params + ) + + default: + fatalError("Unreachable.") + } + + var head: Layer1D = AvgPoolSeq(layerPrev: layer, params: params) + + head = try! 
FullyConnected( + layerPrev: head, nbNeurons: 1, + activation: LeakyReLU.str, biases: true, params: params + ) + + _ = MSE1D(layerPrev: head, params: params) + } + + func testQueryCausal1() throws + { + let trainer = _buildTrainer("QueryCausal1") + run(trainer) + } + + func testQueryCausal2() throws + { + let trainer = _buildTrainer("QueryCausal2") + run(trainer) + } +} + +// ----------------------------------------------------------------------------- +// Compare GPU gradients with Float precision versus Float16 precision. +// We expect to see errors ~ 1e-4 and less. +// ----------------------------------------------------------------------------- +class NLP4FlowPrecisionTests: NLP4FlowTests +{ + private func _buildTrainer(_ model: String) -> FlowPrecisionTrainer + { + let trainer = FlowPrecisionTrainer( + name: "NLP", + params: optimizerParams + ) + trainer.build() + { + (context: ModelContext) in + buildModel(model: model, context: context) + } + return trainer + } + + override func testQueryCausal1() throws + { + let trainer = _buildTrainer("QueryCausal1") + run(trainer, diffThreshold: 0.002) + } + + override func testQueryCausal2() throws + { + let trainer = _buildTrainer("QueryCausal2") + run(trainer, diffThreshold: 0.002) + } } // ----------------------------------------------------------------------------- @@ -330,6 +628,18 @@ class NLPFlowResetTests: NLPFlowTests let trainer = _buildTrainer("RoPE") run(trainer) } + + override func testQueryCausal1() throws + { + let trainer = _buildTrainer("QueryCausal1") + run(trainer) + } + + override func testQueryCausal2() throws + { + let trainer = _buildTrainer("QueryCausal2") + run(trainer) + } } // ----------------------------------------------------------------------------- @@ -384,6 +694,18 @@ class NLPFlowReverseTests: NLPFlowTests let trainer = _buildTrainer("RoPE") run(trainer) } + + override func testQueryCausal1() throws + { + let trainer = _buildTrainer("QueryCausal1") + run(trainer) + } + + override func 
testQueryCausal2() throws + { + let trainer = _buildTrainer("QueryCausal2") + run(trainer) + } } // ----------------------------------------------------------------------------- @@ -506,6 +828,18 @@ class NLPInferenceTests: NLPFlowTests let trainer = _buildTrainer("RoPE") run(trainer) } + + override func testQueryCausal1() throws + { + let trainer = _buildTrainer("QueryCausal1") + run(trainer) + } + + override func testQueryCausal2() throws + { + let trainer = _buildTrainer("QueryCausal2") + run(trainer) + } } // ----------------------------------------------------------------------------- @@ -553,6 +887,18 @@ class NLPLoadTests: NLPFlowTests let trainer = _buildTrainer("RoPE") run(trainer) } + + override func testQueryCausal1() throws + { + let trainer = _buildTrainer("QueryCausal1") + run(trainer) + } + + override func testQueryCausal2() throws + { + let trainer = _buildTrainer("QueryCausal2") + run(trainer) + } } // ----------------------------------------------------------------------------- @@ -644,4 +990,16 @@ class NLPTransformTests: NLPFlowTests let trainer = _buildTrainer("RoPE") run(trainer) } + + override func testQueryCausal1() throws + { + let trainer = _buildTrainer("QueryCausal1") + run(trainer) + } + + override func testQueryCausal2() throws + { + let trainer = _buildTrainer("QueryCausal2") + run(trainer) + } }