diff --git a/CHANGELOG.md b/CHANGELOG.md index 84566f60..da68e650 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file. ## [unreleased] +✨ **layer_seq:** ValueCausalSeq ([126](https://github.com/owkin/GrAIdient/pull/126))\ ✨ **layer_seq:** QueryCausalSeq ([125](https://github.com/owkin/GrAIdient/pull/125))\ ✨ **layer_seq:** RoPESeq ([124](https://github.com/owkin/GrAIdient/pull/124))\ ✨ **layer_seq:** RMSNormSeq ([123](https://github.com/owkin/GrAIdient/pull/123))\ diff --git a/Sources/GrAIdient/LayerSeq/QuerySeq.swift b/Sources/GrAIdient/LayerSeq/QuerySeq.swift index 012fae53..31148ce1 100644 --- a/Sources/GrAIdient/LayerSeq/QuerySeq.swift +++ b/Sources/GrAIdient/LayerSeq/QuerySeq.swift @@ -1236,20 +1236,20 @@ public class QueryCausalSeq: LayerMergeSeq let query = (_layersPrev[0] as! LayerSeq).neurons! let key = (_layersPrev[1] as! LayerSeq).neurons! + let size = (_layersPrev[0] as! LayerSeq).nbNeurons / _nbHeadsQuery + let nbBlocksHead = _nbHeadsQuery / _nbHeadsKey for batch in 0.., + inPlace: Bool) -> Layer + { + let context = ModelContext(name: "", curID: 0) + let params = GrAI.Model.Params(context: context) + params.context.curID = id + + var layersPrev = [LayerSeq]() + for idPrev in _idsPrev + { + layersPrev.append(mapping[idPrev] as! LayerSeq) + } + + let layer = try! ValueCausalSeq( + value: layersPrev[0], score: layersPrev[1], + nbHeadsValue: _nbHeadsValue, + nbHeadsScore: _nbHeadsScore, + params: params + ) + return layer + } + + /// + /// Apply the forward pass of the Gradient Checking in CPU execution context. + /// + /// Throw an error if batch size is greater than the first batch size. + /// + public override func forwardGCCPU() throws + { + try checkStateCPU(batchSize: batchSize) + + let (nbSameElems, layersIndex, nbElems) = getMergedGraph() + + var nbGC = nbSameElems + for nbElemsTmp in nbElems + { + nbGC += nbElemsTmp + } + + for seq in 0..= nbHeadsScore || j >= size || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headValue = headScore / nbBlocksHead; + + uint depthScore = j + headScore * size; + uint depthValue = j + headValue * size; + + float tmp = 0.0; + for (uint seqK=0; seqK<=seqQ; seqK++) + { + uint offsetValue = depthValue + + nbNeuronsPrevValue * seqK + sequence * nbNeuronsPrevValue * elem; + uint offsetScore = seqK + headScore * sequence + + nbNeuronsPrevScore * seqQ + sequence * nbNeuronsPrevScore * elem; + + tmp += value[offsetValue] * score[offsetScore]; + } + + uint offset = depthScore + nbNeurons * seqQ + sequence * nbNeurons * elem; + outs[offset] = tmp; +} + +kernel void valueCausalSeq4ForwardFloat( + const device float4 * value, + const device float * score, + constant uint & nbHeadsValue, + constant uint & nbHeadsScore, + constant uint & nbNeurons, + constant uint & nbNeuronsPrevValue, + constant uint & nbNeuronsPrevScore, + constant uint & nbBatch, + constant uint & sequence, + device float4 * outs, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeuronsPrevValue / nbHeadsValue; + uint nbBlocksHead = nbHeadsScore / nbHeadsValue; + + uint headScore = id[0] / (size / 4); + uint j = id[0] % (size / 4); + uint elem = id[1] / sequence; + uint seqQ = id[1] % sequence; + + if (headScore >= nbHeadsScore || j >= size || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headValue = headScore / nbBlocksHead; + + uint depthScore = j * 4 + headScore * size; + uint depthValue = j * 4 + headValue * size; + + float4 tmp = 0.0; + for (uint seqK=0; seqK<=seqQ; seqK++) + { + uint offsetValue = (depthValue + + nbNeuronsPrevValue * seqK + + sequence * nbNeuronsPrevValue * elem) / 4; + uint offsetScore = seqK + headScore * sequence + + nbNeuronsPrevScore * seqQ + sequence * nbNeuronsPrevScore * elem; + + tmp += value[offsetValue] * score[offsetScore]; + } + + uint offset = (depthScore + + nbNeurons * seqQ + sequence * nbNeurons * elem) / 4; + outs[offset] = tmp; +} + +kernel void valueCausalValueSeqBackwardFloat( + const device float * delta, + const device float * score, + constant uint & nbHeadsValue, + constant uint & nbHeadsScore, + constant uint & nbNeurons, + constant uint & nbNeuronsPrevValue, + constant uint & nbNeuronsPrevScore, + constant uint & nbBatch, + constant uint & sequence, + constant uint & dirty, + device float * value, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeuronsPrevValue / nbHeadsValue; + uint nbBlocksHead = nbHeadsScore / nbHeadsValue; + + uint headValue = id[0] / size; + uint j = id[0] % size; + uint elem = id[1] / sequence; + uint seqK = id[1] % sequence; + + if (headValue >= nbHeadsValue || j >= size || + elem >= nbBatch || seqK >= sequence) + { + return ; + } + + uint depthValue = j + headValue * size; + + float tmp = 0.0; + for (uint blockHead=0; blockHead= nbHeadsValue || j * 4 >= size || + elem >= nbBatch || seqK >= sequence) + { + return ; + } + + uint depthValue = j + headValue * size; + + float4 tmp = 0.0; + for (uint blockHead=0; blockHead= nbHeadsScore || seqK > seqQ || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headValue = headScore / nbBlocksHead; + + float tmp = 0.0; + for (uint j=0; j= nbHeadsScore || seqK > seqQ || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headValue = headScore / nbBlocksHead; + + float4 tmp = 0.0; + for (uint j=0; j= nbHeadsScore || j >= size || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headValue = headScore / nbBlocksHead; + + uint depthScore = j + headScore * size; + uint depthValue = j + headValue * size; + + half tmp = 0.0; + for (uint seqK=0; seqK<=seqQ; seqK++) + { + uint offsetValue = depthValue + + nbNeuronsPrevValue * seqK + sequence * nbNeuronsPrevValue * elem; + uint offsetScore = seqK + headScore * sequence + + nbNeuronsPrevScore * seqQ + sequence * nbNeuronsPrevScore * elem; + + tmp += value[offsetValue] * score[offsetScore]; + } + + uint offset = depthScore + nbNeurons * seqQ + sequence * nbNeurons * elem; + outs[offset] = tmp; +} + +kernel void valueCausalSeq4ForwardHalf( + const device half4 * value, + const device half * score, + constant uint & nbHeadsValue, + constant uint & nbHeadsScore, + constant uint & nbNeurons, + constant uint & nbNeuronsPrevValue, + constant uint & nbNeuronsPrevScore, + constant uint & nbBatch, + constant uint & sequence, + device half4 * outs, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeuronsPrevValue / nbHeadsValue; + uint nbBlocksHead = nbHeadsScore / nbHeadsValue; + + uint headScore = id[0] / (size / 4); + uint j = id[0] % (size / 4); + uint elem = id[1] / sequence; + uint seqQ = id[1] % sequence; + + if (headScore >= nbHeadsScore || j >= size || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headValue = headScore / nbBlocksHead; + + uint depthScore = j * 4 + headScore * size; + uint depthValue = j * 4 + headValue * size; + + half4 tmp = 0.0; + for (uint seqK=0; seqK<=seqQ; seqK++) + { + uint offsetValue = (depthValue + + nbNeuronsPrevValue * seqK + + sequence * nbNeuronsPrevValue * elem) / 4; + uint offsetScore = seqK + headScore * sequence + + nbNeuronsPrevScore * seqQ + sequence * nbNeuronsPrevScore * elem; + + tmp += value[offsetValue] * score[offsetScore]; + } + + uint offset = (depthScore + + nbNeurons * seqQ + sequence * nbNeurons * elem) / 4; + outs[offset] = tmp; +} + +kernel void valueCausalValueSeqBackwardHalf( + const device half * delta, + const device half * score, + constant uint & nbHeadsValue, + constant uint & nbHeadsScore, + constant uint & nbNeurons, + constant uint & nbNeuronsPrevValue, + constant uint & nbNeuronsPrevScore, + constant uint & nbBatch, + constant uint & sequence, + constant uint & dirty, + device half * value, + uint2 id [[ thread_position_in_grid ]]) +{ + uint size = nbNeuronsPrevValue / nbHeadsValue; + uint nbBlocksHead = nbHeadsScore / nbHeadsValue; + + uint headValue = id[0] / size; + uint j = id[0] % size; + uint elem = id[1] / sequence; + uint seqK = id[1] % sequence; + + if (headValue >= nbHeadsValue || j >= size || + elem >= nbBatch || seqK >= sequence) + { + return ; + } + + uint depthValue = j + headValue * size; + + half tmp = 0.0; + for (uint blockHead=0; blockHead= nbHeadsValue || j * 4 >= size || + elem >= nbBatch || seqK >= sequence) + { + return ; + } + + uint depthValue = j + headValue * size; + + half4 tmp = 0.0; + for (uint blockHead=0; blockHead= nbHeadsScore || seqK > seqQ || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headValue = headScore / nbBlocksHead; + + half tmp = 0.0; + for (uint j=0; j= nbHeadsScore || seqK > seqQ || + elem >= nbBatch || seqQ >= sequence) + { + return ; + } + + uint headValue = headScore / nbBlocksHead; + + half4 tmp = 0.0; + for (uint j=0; j 1 + { + print(diffPercent) + } XCTAssert(diffPercent < 1) } } diff --git a/Tests/GrAITests/Layer1DTests.swift b/Tests/GrAITests/Layer1DTests.swift index a2dd30d6..6d360574 100644 --- a/Tests/GrAITests/Layer1DTests.swift +++ b/Tests/GrAITests/Layer1DTests.swift @@ -593,7 +593,7 @@ class Layer1DFlowPrecisionTests: Layer1DFlowTests override func testActivation() throws { let trainer = _buildTrainer("Activation") - run(trainer) + run(trainer, diffThreshold: 0.002) } override func testSelectNeurons() throws diff --git a/Tests/GrAITests/LayerSeqTests.swift b/Tests/GrAITests/LayerSeqTests.swift index 35d0f408..8598d8e6 100644 --- a/Tests/GrAITests/LayerSeqTests.swift +++ b/Tests/GrAITests/LayerSeqTests.swift @@ -881,7 +881,7 @@ class LayerSeqFlowPrecisionTests: LayerSeqFlowTests override func testSoftmaxSeq() throws { let trainer = _buildTrainer("Softmax") - run(trainer, diffThreshold: 0.002) + run(trainer, diffThreshold: 0.005) } override func testValueSeq() throws @@ -1339,7 +1339,7 @@ class LayerSeq4FlowPrecisionTests: LayerSeq4FlowTests override func testLayerNormSeq() throws { let trainer = _buildTrainer("LayerNorm") - run(trainer, diffThreshold: 0.002) + run(trainer, diffThreshold: 0.005) } override func testQuerySeq() throws diff --git a/Tests/GrAITests/NLPTests.swift b/Tests/GrAITests/NLPTests.swift index 01372740..41f22b32 100644 --- a/Tests/GrAITests/NLPTests.swift +++ b/Tests/GrAITests/NLPTests.swift @@ -126,6 +126,48 @@ class NLPGradTests: EmbeddingSeqMSE1DCase params: params ) + case "ValueCausal1": + let otherLayer: LayerSeq = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 3 * sequence, + activation: nil, + biases: false, + params: params + ) + layer = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 3 * 3, + activation: nil, + biases: false, + params: params + ) + layer = try! ValueCausalSeq( + value: layer, score: otherLayer, + nbHeadsValue: 3, nbHeadsScore: 3, + params: params + ) + + case "ValueCausal2": + let otherLayer: LayerSeq = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 4 * sequence, + activation: nil, + biases: false, + params: params + ) + layer = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 2 * 3, + activation: nil, + biases: false, + params: params + ) + layer = try! ValueCausalSeq( + value: layer, score: otherLayer, + nbHeadsValue: 2, nbHeadsScore: 4, + params: params + ) + default: fatalError("Unreachable.") } @@ -211,6 +253,32 @@ class NLPGradTests: EmbeddingSeqMSE1DCase let trainer = _buildTrainer("QueryCausal2") run(trainer) } + + func testValueCausal1CPU() throws + { + GrAI.Opti.CPU = true + let trainer = _buildTrainer("ValueCausal1") + run(trainer) + } + + func testValueCausal1GPU() throws + { + let trainer = _buildTrainer("ValueCausal1") + run(trainer) + } + + func testValueCausal2CPU() throws + { + GrAI.Opti.CPU = true + let trainer = _buildTrainer("ValueCausal2") + run(trainer) + } + + func testValueCausal2GPU() throws + { + let trainer = _buildTrainer("ValueCausal2") + run(trainer) + } } // ----------------------------------------------------------------------------- @@ -322,6 +390,48 @@ class NLPFlowTests: EmbeddingSeqMSE1DCase params: params ) + case "ValueCausal1": + let otherLayer: LayerSeq = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 3 * sequence, + activation: nil, + biases: false, + params: params + ) + layer = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 3 * 3, + activation: nil, + biases: false, + params: params + ) + layer = try! ValueCausalSeq( + value: layer, score: otherLayer, + nbHeadsValue: 3, nbHeadsScore: 3, + params: params + ) + + case "ValueCausal2": + let otherLayer: LayerSeq = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 4 * sequence, + activation: nil, + biases: false, + params: params + ) + layer = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 2 * 3, + activation: nil, + biases: false, + params: params + ) + layer = try! ValueCausalSeq( + value: layer, score: otherLayer, + nbHeadsValue: 2, nbHeadsScore: 4, + params: params + ) + default: fatalError("Unreachable.") } @@ -372,6 +482,18 @@ class NLPFlowTests: EmbeddingSeqMSE1DCase let trainer = _buildTrainer("QueryCausal2") run(trainer) } + + func testValueCausal1() throws + { + let trainer = _buildTrainer("ValueCausal1") + run(trainer) + } + + func testValueCausal2() throws + { + let trainer = _buildTrainer("ValueCausal2") + run(trainer) + } } // ----------------------------------------------------------------------------- @@ -430,6 +552,18 @@ class NLPFlowPrecisionTests: NLPFlowTests let trainer = _buildTrainer("QueryCausal2") run(trainer, diffThreshold: 0.002) } + + override func testValueCausal1() throws + { + let trainer = _buildTrainer("ValueCausal1") + run(trainer) + } + + override func testValueCausal2() throws + { + let trainer = _buildTrainer("ValueCausal2") + run(trainer) + } } // ----------------------------------------------------------------------------- @@ -516,6 +650,48 @@ class NLP4FlowTests: EmbeddingSeqMSE1DCase params: params ) + case "ValueCausal1": + let otherLayer: LayerSeq = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 3 * sequence, + activation: nil, + biases: false, + params: params + ) + layer = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 3 * 4, + activation: nil, + biases: false, + params: params + ) + layer = try! ValueCausalSeq( + value: layer, score: otherLayer, + nbHeadsValue: 3, nbHeadsScore: 3, + params: params + ) + + case "ValueCausal2": + let otherLayer: LayerSeq = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 4 * sequence, + activation: nil, + biases: false, + params: params + ) + layer = FullyConnectedSeq( + layerPrev: layer, + nbNeurons: 2 * 4, + activation: nil, + biases: false, + params: params + ) + layer = try! ValueCausalSeq( + value: layer, score: otherLayer, + nbHeadsValue: 2, nbHeadsScore: 4, + params: params + ) + default: fatalError("Unreachable.") } @@ -541,6 +717,18 @@ class NLP4FlowTests: EmbeddingSeqMSE1DCase let trainer = _buildTrainer("QueryCausal2") run(trainer) } + + func testValueCausal1() throws + { + let trainer = _buildTrainer("ValueCausal1") + run(trainer) + } + + func testValueCausal2() throws + { + let trainer = _buildTrainer("ValueCausal2") + run(trainer) + } } // ----------------------------------------------------------------------------- @@ -574,6 +762,18 @@ class NLP4FlowPrecisionTests: NLP4FlowTests let trainer = _buildTrainer("QueryCausal2") run(trainer, diffThreshold: 0.002) } + + override func testValueCausal1() throws + { + let trainer = _buildTrainer("ValueCausal1") + run(trainer, diffThreshold: 0.002) + } + + override func testValueCausal2() throws + { + let trainer = _buildTrainer("ValueCausal2") + run(trainer, diffThreshold: 0.002) + } } // ----------------------------------------------------------------------------- @@ -640,6 +840,18 @@ class NLPFlowResetTests: NLPFlowTests let trainer = _buildTrainer("QueryCausal2") run(trainer) } + + override func testValueCausal1() throws + { + let trainer = _buildTrainer("ValueCausal1") + run(trainer) + } + + override func testValueCausal2() throws + { + let trainer = _buildTrainer("ValueCausal2") + run(trainer) + } } // ----------------------------------------------------------------------------- @@ -706,6 +918,18 @@ class NLPFlowReverseTests: NLPFlowTests let trainer = _buildTrainer("QueryCausal2") run(trainer) } + + override func testValueCausal1() throws + { + let trainer = _buildTrainer("ValueCausal1") + run(trainer) + } + + override func testValueCausal2() throws + { + let trainer = _buildTrainer("ValueCausal2") + run(trainer) + } } // ----------------------------------------------------------------------------- @@ -840,6 +1064,18 @@ class NLPInferenceTests: NLPFlowTests let trainer = _buildTrainer("QueryCausal2") run(trainer) } + + override func testValueCausal1() throws + { + let trainer = _buildTrainer("ValueCausal1") + run(trainer) + } + + override func testValueCausal2() throws + { + let trainer = _buildTrainer("ValueCausal2") + run(trainer) + } } // ----------------------------------------------------------------------------- @@ -899,6 +1135,18 @@ class NLPLoadTests: NLPFlowTests let trainer = _buildTrainer("QueryCausal2") run(trainer) } + + override func testValueCausal1() throws + { + let trainer = _buildTrainer("ValueCausal1") + run(trainer) + } + + override func testValueCausal2() throws + { + let trainer = _buildTrainer("ValueCausal2") + run(trainer) + } } // ----------------------------------------------------------------------------- @@ -1002,4 +1250,16 @@ class NLPTransformTests: NLPFlowTests let trainer = _buildTrainer("QueryCausal2") run(trainer) } + + override func testValueCausal1() throws + { + let trainer = _buildTrainer("ValueCausal1") + run(trainer) + } + + override func testValueCausal2() throws + { + let trainer = _buildTrainer("ValueCausal2") + run(trainer) + } }