From d97e5200afd97d7fce7aec7e5bf668c145fcfbb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Fran=C3=A7ois=20Reboud?= Date: Fri, 14 Jun 2024 09:30:20 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(layer=5Fseq):=20EmbeddingSeq?= =?UTF-8?q?=20(#122)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 1 + Sources/GrAIdient/Layer1D/Constant1D.swift | 2 +- Sources/GrAIdient/Layer2D/Constant2D.swift | 2 +- Sources/GrAIdient/Layer2D/VQ2D.swift | 2 +- Sources/GrAIdient/LayerSeq/ConstantSeq.swift | 2 +- Sources/GrAIdient/LayerSeq/EmbeddingSeq.swift | 767 ++++++++++++++++++ .../LayerSeq/FullyConnectedPatch.swift | 2 +- .../LayerSeq/FullyConnectedSeq.swift | 2 +- Sources/GrAIdient/LayerSeq/VQSeq.swift | 2 +- .../Metal/Kernel/EmbeddingSeqFloat.metal | 155 ++++ .../Metal/Kernel/EmbeddingSeqHalf.metal | 155 ++++ Sources/GrAIdient/Metal/MetalConfig.swift | 10 + Sources/GrAIdient/Utils/Serialization.swift | 1 + Tests/GrAIExamples/Base/Utils.swift | 6 + .../GrAIExamples/Base/python_lib/__init__.py | 10 + .../Base/python_lib/{llm => nlp}/__init__.py | 0 .../Base/python_lib/{llm => nlp}/generate.py | 109 ++- .../Base/python_lib/{llm => nlp}/model.py | 241 +++--- .../Base/python_lib/{llm => nlp}/tokenizer.py | 0 Tests/GrAIExamples/Base/python_lib/weight.py | 55 +- Tests/GrAIExamples/NLPExample.swift | 125 +++ .../Base/InputSeq/EmbeddingSeqMSE1DCase.swift | 189 +++++ Tests/GrAITests/NLPTests.swift | 453 +++++++++++ 23 files changed, 2146 insertions(+), 145 deletions(-) create mode 100644 Sources/GrAIdient/LayerSeq/EmbeddingSeq.swift create mode 100644 Sources/GrAIdient/Metal/Kernel/EmbeddingSeqFloat.metal create mode 100644 Sources/GrAIdient/Metal/Kernel/EmbeddingSeqHalf.metal rename Tests/GrAIExamples/Base/python_lib/{llm => nlp}/__init__.py (100%) rename Tests/GrAIExamples/Base/python_lib/{llm => nlp}/generate.py (52%) rename Tests/GrAIExamples/Base/python_lib/{llm => nlp}/model.py (65%) rename 
Tests/GrAIExamples/Base/python_lib/{llm => nlp}/tokenizer.py (100%) create mode 100644 Tests/GrAIExamples/NLPExample.swift create mode 100644 Tests/GrAITests/Base/InputSeq/EmbeddingSeqMSE1DCase.swift create mode 100644 Tests/GrAITests/NLPTests.swift diff --git a/CHANGELOG.md b/CHANGELOG.md index 54a29551..242cecbc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file. ## [unreleased] +✨ **layer_seq:** EmbeddingSeq ([122](https://github.com/owkin/GrAIdient/pull/122))\ 🚀 **perf:** use half in Metal kernels ([121](https://github.com/owkin/GrAIdient/pull/121))\ 🔨 **refactor:** handle float16 along float on GPU ([#120](https://github.com/owkin/GrAIdient/pull/120))\ 🚀 **perf:** copy & generate weights faster ([119](https://github.com/owkin/GrAIdient/pull/119))\ diff --git a/Sources/GrAIdient/Layer1D/Constant1D.swift b/Sources/GrAIdient/Layer1D/Constant1D.swift index 8976a21f..3d0fb69f 100644 --- a/Sources/GrAIdient/Layer1D/Constant1D.swift +++ b/Sources/GrAIdient/Layer1D/Constant1D.swift @@ -21,7 +21,7 @@ public class Constant1D: Layer1D, LayerUpdate var _wBuffers: IWeightBuffers! = nil /// - /// Buffer of gradients per sample for biases. + /// Buffer of gradients per sample. /// Shape ~ (batch, nbNeurons). /// var _wDeltaWeights: FloatBuffer! = nil diff --git a/Sources/GrAIdient/Layer2D/Constant2D.swift b/Sources/GrAIdient/Layer2D/Constant2D.swift index 96d80aee..8c5829cb 100644 --- a/Sources/GrAIdient/Layer2D/Constant2D.swift +++ b/Sources/GrAIdient/Layer2D/Constant2D.swift @@ -21,7 +21,7 @@ public class Constant2D: Layer2D, LayerResize, LayerUpdate var _wBuffers: IWeightBuffers! = nil /// - /// Buffer of gradients per sample for biases. + /// Buffer of gradients per sample. /// Shape ~ (batch, nbChannels). /// var _wDeltaWeights: FloatBuffer! 
= nil diff --git a/Sources/GrAIdient/Layer2D/VQ2D.swift b/Sources/GrAIdient/Layer2D/VQ2D.swift index 80449635..9dde168f 100644 --- a/Sources/GrAIdient/Layer2D/VQ2D.swift +++ b/Sources/GrAIdient/Layer2D/VQ2D.swift @@ -59,7 +59,7 @@ public class VQ2D: LayerOutput2D, LayerWeightInit var _wBuffers: IWeightBuffers! = nil /// - /// Buffer of gradients per sample for biases. + /// Buffer of gradients per sample. /// Shape ~ (batch, K, nbChannels). /// var _wDeltaWeights: FloatBuffer! = nil diff --git a/Sources/GrAIdient/LayerSeq/ConstantSeq.swift b/Sources/GrAIdient/LayerSeq/ConstantSeq.swift index f8796ecb..afc34e4d 100644 --- a/Sources/GrAIdient/LayerSeq/ConstantSeq.swift +++ b/Sources/GrAIdient/LayerSeq/ConstantSeq.swift @@ -505,7 +505,7 @@ public class Constant2Seq: LayerSeq, LayerUpdate var _wBuffers: IWeightBuffers! = nil /// - /// Buffer of gradients per sample for biases. + /// Buffer of gradients per sample. /// Shape ~ (batch, sequence, nbNeurons). /// var _wDeltaWeights: FloatBuffer! = nil diff --git a/Sources/GrAIdient/LayerSeq/EmbeddingSeq.swift b/Sources/GrAIdient/LayerSeq/EmbeddingSeq.swift new file mode 100644 index 00000000..59472a17 --- /dev/null +++ b/Sources/GrAIdient/LayerSeq/EmbeddingSeq.swift @@ -0,0 +1,767 @@ +// +// EmbeddingSeq.swift +// GrAIdient +// +// Created by Jean-François Reboud on 04/06/2024. +// + +import Foundation + +/// Input layer with a sequential shape neural structure and weights. +public class EmbeddingSeq: LayerSeq, LayerWeightInit +{ + /// Size of vocabulary. + public var vocabularySize: Int + + /// + /// Input buffer. + /// Shape ~ (batch, seq). + /// + public var ins: MetalBuffer! = nil + + /// + /// Grid of weights. + /// Shape ~ (vocabularySize, nbNeurons). + /// + var _wArrays: WeightGrids! = nil + + /// + /// Buffer of weights. + /// Shape ~ (vocabularySize, nbNeurons). + /// + var _wBuffers: IWeightBuffers! = nil + + /// + /// Buffer of gradients per sample. + /// Shape ~ (batch, vocabularySize, nbNeurons). 
+ /// + var _wDeltaWeights: FloatBuffer! = nil + + /// Whether to compute weights' gradients or not. + public var computeDeltaWeights: Bool = true + + /// Whether gradients of weights must be accumulated or not. + public var accumulateDeltaWeights: Bool = false + + /// Cache for weights before calling `initKernel` API. + var _weightsList = [Float]() + + /// Weights in the CPU execution context. + public var weightsCPU: [Float] + { + get { + if _wArrays == nil + { + return _weightsList + } + + var weightsTmp = [Float]() + for index in 0.., + inPlace: Bool) -> Layer + { + if idPrev > -1 + { + fatalError("EmbeddingSeq must be the first layer.") + } + + let context = ModelContext(name: "", curID: 0) + let params = GrAI.Model.Params(context: context) + params.context.curID = id + + let layer = EmbeddingSeq( + sequence: sequence, + vocabularySize: vocabularySize, + nbNeurons: nbNeurons, + params: params + ) + + if inPlace + { + layer._wArrays = _wArrays + layer._wBuffers = _wBuffers + } + else + { + if GrAI.Opti.GPU + { + layer.weightsGPU = weightsGPU + } + else + { + layer.weightsCPU = weightsCPU + } + } + return layer + } + + /// + /// Clean state resources in the CPU execution context. + /// + /// We first clean the neurons' state (forward and backward). + /// We do not clean weights and biases but must reset their delta (dependent on batch size) and + /// momentum state. + /// + public override func resetKernelCPU() + { + super.resetKernelCPU() + _wArrays?.reset() + ins = nil + } + + /// + /// Clean state resources in the GPU execution context. + /// + /// We first clean the neurons' state (forward and backward). + /// We do not clean weights and biases but must reset their delta (dependent on batch size) and + /// momentum state. + /// + public override func resetKernelGPU() + { + super.resetKernelGPU() + + ins = nil + _wDeltaWeights = nil + _wBuffers?.reset() + } + + /// + /// Initialize weights in the CPU execution context. 
+ /// + /// Their momentum and delta state are also reset. + /// + public func initWeightsCPU() + { + if _weightsList.count == 0 + { + _weightsList = generateWeightsList() + } + + _wArrays = WeightGrids(width: nbNeurons, height: vocabularySize) + + for index in 0..( + batchSize * sequence, deviceID: deviceID + ) + } + else if batchSize <= 0 || batchSize > ins.nbElems / sequence + { + throw LayerError.BatchSize + } + + var dataFlat = data.flatMap { $0.map { Int32($0)} } + let ins_s = ins as! MetalSharedBuffer + copyArrayToBuffer( + array: &dataFlat, + buffer: ins_s.buffer, + start: 0, + nbElems: batchSize * sequence + ) + } + + /// + /// Check and setup input in the GPU execution context. + /// + /// Throw an error if data size is not coherent. + /// + /// - Parameters: + /// - data: The input data. + /// - batchSize: The batch size of data. + /// - sequence: Length of the sequence. + /// + public func checkInputGPU( + _ data: [[Int]], + batchSize: Int, + sequence: Int) throws + { + if data.count != batchSize || data.first!.count != sequence + { + throw LayerError.DataSize + } + + if ins == nil + { + ins = MetalPrivateBuffer( + batchSize * sequence, deviceID: deviceID + ) + } + else if batchSize <= 0 || batchSize > ins.nbElems / sequence + { + throw LayerError.BatchSize + } + + // Wait for previous loop to end to avoid race condition. + _ = ins.download() + + var dataFlat = data.flatMap { $0.map { Int32($0)} } + let ins_s = ins as! MetalPrivateBuffer + copyArrayToBuffer( + array: &dataFlat, + buffer: ins_s.shared.buffer, + start: 0, + nbElems: batchSize * sequence + ) + ins.upload() + } + + /// + /// API to set data in the CPU execution context. + /// + /// Throw an error if data size is not coherent. + /// + /// - Parameters: + /// - data: The data to set. + /// - batchSize: The batch size of data. + /// - sequence: Length of the sequence. 
+ /// + public func setDataCPU( + _ data: [[Int]], + batchSize: Int, + sequence: Int) throws + { + try checkInputCPU( + data, + batchSize: batchSize, + sequence: sequence + ) + } + + /// + /// API to set data in the GPU execution context. + /// + /// Throw an error if data size is not coherent. + /// + /// - Parameters: + /// - data: The data to set. + /// - batchSize: The batch size of data. + /// - sequence: Length of the sequence. + /// + public func setDataGPU( + _ data: [[Int]], + batchSize: Int, + sequence: Int) throws + { + try checkInputGPU( + data, + batchSize: batchSize, + sequence: sequence + ) + } + + /// + /// Initialize state resources in the GPU execution context. + /// + /// We initialize the neurons' forward state. + /// We initialize the weights and biases' delta. + /// + public override func checkStateForwardGPU(batchSize: Int) throws + { + try super.checkStateForwardGPU(batchSize: batchSize) + + if computeDeltaWeights && + GrAI.Gradient.sample && _wDeltaWeights == nil + { + _wDeltaWeights = FloatBuffer(nbElems: + batchSize * vocabularySize * nbNeurons, deviceID: deviceID + ) + } + } + + /// + /// Apply the forward pass of the Gradient Checking in CPU execution context. + /// + /// Throw an error if batch size is greater than the first batch size. + /// + public override func forwardGCCPU() throws + { + try checkStateCPU(batchSize: batchSize) + + let newGC = 2 * nbLearnedGC + for seq in 0..).buffer + + for batch in 0..).buffer + + for elem in 0..).buffer + + if !accumulateDeltaWeights + { + for index in 0..= vocabularySize + { + fatalError("Index \(index) is out of range.") + } + for depth in 0.. [IWeightArrays] + { + return [_wArrays] + } + + /// Get the weights in the GPU execution context. 
+ public func collectWeightsGPU() -> [IWeightBuffers] + { + return [_wBuffers] + } +} diff --git a/Sources/GrAIdient/LayerSeq/FullyConnectedPatch.swift b/Sources/GrAIdient/LayerSeq/FullyConnectedPatch.swift index 69fd40bb..c9bf8ba5 100644 --- a/Sources/GrAIdient/LayerSeq/FullyConnectedPatch.swift +++ b/Sources/GrAIdient/LayerSeq/FullyConnectedPatch.swift @@ -47,7 +47,7 @@ public class FullyConnectedPatch: ActivationSeq, /// var _wDeltaWeights: FloatBuffer! = nil /// - /// Buffer of gradients per sample for biases. + /// Buffer of gradients per sample. /// Shape ~ (batch, nbNeurons). /// var _bDeltaWeights: FloatBuffer! = nil diff --git a/Sources/GrAIdient/LayerSeq/FullyConnectedSeq.swift b/Sources/GrAIdient/LayerSeq/FullyConnectedSeq.swift index c959b30b..e6d4c1cf 100644 --- a/Sources/GrAIdient/LayerSeq/FullyConnectedSeq.swift +++ b/Sources/GrAIdient/LayerSeq/FullyConnectedSeq.swift @@ -39,7 +39,7 @@ public class FullyConnectedSeq: ActivationSeq, /// var _wDeltaWeights: FloatBuffer! = nil /// - /// Buffer of gradients per sample for biases. + /// Buffer of gradients per sample. /// Shape ~ (batch, nbNeurons). /// var _bDeltaWeights: FloatBuffer! = nil diff --git a/Sources/GrAIdient/LayerSeq/VQSeq.swift b/Sources/GrAIdient/LayerSeq/VQSeq.swift index 669fbc43..ab116b38 100644 --- a/Sources/GrAIdient/LayerSeq/VQSeq.swift +++ b/Sources/GrAIdient/LayerSeq/VQSeq.swift @@ -43,7 +43,7 @@ public class VQSeq: LayerSeq, LayerWeightInit var _wBuffers: IWeightBuffers! = nil /// - /// Buffer of gradients per sample for biases. + /// Buffer of gradients per sample. /// Shape ~ (batch, K, nbNeurons). /// var _wDeltaWeights: FloatBuffer! 
= nil diff --git a/Sources/GrAIdient/Metal/Kernel/EmbeddingSeqFloat.metal b/Sources/GrAIdient/Metal/Kernel/EmbeddingSeqFloat.metal new file mode 100644 index 00000000..3892c780 --- /dev/null +++ b/Sources/GrAIdient/Metal/Kernel/EmbeddingSeqFloat.metal @@ -0,0 +1,155 @@ +// +// EmbeddingSeqFloat.metal +// GrAIdient +// +// Created by Jean-François Reboud on 10/06/2024. +// + +#include +using namespace metal; + +kernel void embeddingSeqForwardFloat( + const device int * ins, + const device float * weights, + constant uint * pNbNeurons, + constant uint * pNbBatch, + constant uint * pSequence, + device float * outs, + uint2 id [[ thread_position_in_grid ]]) +{ + uint nbNeurons; + uint nbBatch; + uint sequence; + + if (pNbNeurons && pNbBatch && pSequence && + weights && ins && outs) + { + nbNeurons = *pNbNeurons; + nbBatch = *pNbBatch; + sequence = *pSequence; + } + else + return ; + + uint elem = id[1]; + uint seq = id[0]; + + if (seq >= sequence || elem >= nbBatch) + { + return ; + } + + int index = ins[seq + sequence * elem]; + for (uint depth=0; depth= nbNeurons || embedding >= vocabularySize) + { + return ; + } + + float sum = 0.0; + for (uint elem=0; elem= nbNeurons || elem * embedding >= nbBatch * vocabularySize) + { + return ; + } + + float sum = 0.0; + for (uint seq=0; seq +using namespace metal; + +kernel void embeddingSeqForwardHalf( + const device int * ins, + const device half * weights, + constant uint * pNbNeurons, + constant uint * pNbBatch, + constant uint * pSequence, + device half * outs, + uint2 id [[ thread_position_in_grid ]]) +{ + uint nbNeurons; + uint nbBatch; + uint sequence; + + if (pNbNeurons && pNbBatch && pSequence && + weights && ins && outs) + { + nbNeurons = *pNbNeurons; + nbBatch = *pNbBatch; + sequence = *pSequence; + } + else + return ; + + uint elem = id[1]; + uint seq = id[0]; + + if (seq >= sequence || elem >= nbBatch) + { + return ; + } + + int index = ins[seq + sequence * elem]; + for (uint depth=0; depth= nbNeurons || embedding 
>= vocabularySize) + { + return ; + } + + half sum = 0.0; + for (uint elem=0; elem= nbNeurons || elem * embedding >= nbBatch * vocabularySize) + { + return ; + } + + half sum = 0.0; + for (uint seq=0; seq Generator[torch.Tensor, None, None]: """ Generate text based on the given prompt and model. @@ -17,7 +18,7 @@ def generate_with_cache( ---------- prompt: torch.Tensor The input prompt. - model: LLM + model: Transformer The model to use for generation. temp: float The temperature for sampling. If temp is 0, use max sampling. @@ -48,7 +49,7 @@ def sample(logits: torch.Tensor) -> torch.Tensor: def generate( prompt: str, - model: LLM, + model: Transformer, tokenizer: Tokenizer, temp: float, max_tokens: int @@ -97,26 +98,94 @@ def generate( return -if __name__ == "__main__": - model_path = Path("TO_MODIFY/mistral/weights/mistral-7B-v0.1") - state = torch.load(str(model_path / "consolidated.00.pth")) - tokenizer = Tokenizer(str(model_path / "tokenizer.model")) +def generate_main( + prompt: str, + model_path: str +) -> np.ndarray: + """ + Generate text based on the given prompt and model. + + Parameters + ---------- + prompt: torch.Tensor + The input prompt. + model_path: str + Path to the model on the disk. 
+ """ + state = torch.load(str(Path(model_path) / "consolidated.00.pth")) + tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model")) - with open(model_path / "params.json", "r") as f: + with open(Path(model_path) / "params.json", "r") as f: config = json.loads(f.read()) config.pop("sliding_window", None) config.pop("model_type", None) - quantization = config.pop("quantization", None) - model_args = ModelArgs(**config) + model_args = TransformerArgs(**config) - model = LLM(model_args) + model = Transformer(model_args) model.load_state_dict(state) model.to("mps") - generate( - "Hello, what is your name?", - model, - tokenizer, - 0.7, - 200 + prompt = torch.tensor( + tokenizer.encode(prompt), dtype=torch.long, device="mps" + ) + out, _ = model(prompt) + return out.detach().cpu().numpy().flatten() + """generate( + prompt=prompt, + model=model, + tokenizer=tokenizer, + temp=0.7, + max_tokens=200 + )""" + + +def encode( + prompt: str, + model_path: str +) -> List[int]: + """ + Encode text. + + Parameters + ---------- + prompt: torch.Tensor + The input prompt. + model_path: str + Path to the model on the disk. + """ + tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model")) + return tokenizer.encode(prompt) + + +def decode( + prompt: List[int], + model_path: str +) -> str: + """ + Decode text. + + Parameters + ---------- + prompt: torch.Tensor + The input prompt. + model_path: str + Path to the model on the disk. 
+ """ + tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model")) + return tokenizer.decode(prompt) + + +if __name__ == "__main__": + model_path = "" + prompt = encode( + prompt="Hello, what is your name?", + model_path=model_path + ) + prompt = decode( + prompt=prompt, + model_path=model_path + ) + generate_main( + prompt="Hello, what is your name?", + model_path=model_path ) diff --git a/Tests/GrAIExamples/Base/python_lib/llm/model.py b/Tests/GrAIExamples/Base/python_lib/nlp/model.py similarity index 65% rename from Tests/GrAIExamples/Base/python_lib/llm/model.py rename to Tests/GrAIExamples/Base/python_lib/nlp/model.py index 311243b2..498c5f98 100644 --- a/Tests/GrAIExamples/Base/python_lib/llm/model.py +++ b/Tests/GrAIExamples/Base/python_lib/nlp/model.py @@ -4,7 +4,31 @@ @dataclass -class ModelArgs: +class TransformerArgs: + """ + Transformer parameters. + + Parameters + ---------- + dim: int + Base hidden dimension. + n_layers: int + Number of Transformer blocks. + head_dim: + Hidden dimension of each attention head. + hidden_dim: + Hidden dimension of the feed forward blocks. + n_heads: int + Number of heads for the queries. + n_kv_heads: int + Number of heads for keys and values. + norm_eps: float + Used to avoid division by 0 during normalization. + vocab_size: int + Vocabulary size. + rope_theta: float + Coefficient used to initialize rotation matrix. + """ dim: int n_layers: int head_dim: int @@ -16,81 +40,6 @@ class ModelArgs: rope_theta: float = 10000 -def get_rotary_matrix1( - context_len: int, embedding_dim: int -) -> torch.Tensor: - """ - Generate the rotary matrix for RoPE. - - Parameters - ---------- - context_len: int - The context length. - embedding_dim: int - Embedding dimension. - - Returns - ------- - R: torch.Tensor - The rotary matrix of dimension - (context_len, embedding_dim, embedding_dim). 
- """ - R = torch.zeros( - (context_len, embedding_dim, embedding_dim), - requires_grad=False - ) - positions = torch.arange(1, context_len+1).unsqueeze(1) - # Create matrix theta (shape: context_len, embedding_dim // 2). - slice_i = torch.arange(0, embedding_dim // 2) - theta = 10000. ** (-2.0 * (slice_i.float()) / embedding_dim) - m_theta = positions * theta - # Create sin and cos values. - cos_values = torch.cos(m_theta) - sin_values = torch.sin(m_theta) - # Populate the rotary matrix R using 2D slicing. - R[:, 2*slice_i, 2*slice_i] = cos_values - R[:, 2*slice_i, 2*slice_i+1] = -sin_values - R[:, 2*slice_i+1, 2*slice_i] = sin_values - R[:, 2*slice_i+1, 2*slice_i+1] = cos_values - return R - - -def get_rotary_matrix2( - context_offset: int, embedding_dim: int -) -> torch.Tensor: - """ - Generate the rotary matrix for RoPE. - - Parameters - ---------- - context_offset: int - The context offset. - embedding_dim: int - Embedding dimension. - - Returns - ------- - R: torch.Tensor - The rotary matrix of dimension - (1, embedding_dim, embedding_dim). - """ - R = torch.zeros((1, embedding_dim, embedding_dim), requires_grad=False) - positions = torch.tensor([context_offset + 1]).unsqueeze(1) - # Create matrix theta (shape: 1, embedding_dim // 2). - slice_i = torch.arange(0, embedding_dim // 2) - theta = 10000. ** (-2.0 * (slice_i.float()) / embedding_dim) - m_theta = positions * theta - # Create sin and cos values. - cos_values = torch.cos(m_theta) - sin_values = torch.sin(m_theta) - # Populate the rotary matrix R using 2D slicing. - R[:, 2*slice_i, 2*slice_i] = cos_values - R[:, 2*slice_i, 2*slice_i+1] = -sin_values - R[:, 2*slice_i+1, 2*slice_i] = sin_values - R[:, 2*slice_i+1, 2*slice_i+1] = cos_values - return R - - class RMSNorm(torch.nn.Module): """ Root mean squared norm. @@ -135,11 +84,11 @@ class Attention(torch.nn.Module): Parameters ---------- - args: ModelArgs + args: TransformerArgs Model parameters. 
""" - def __init__(self, args: ModelArgs): + def __init__(self, args: TransformerArgs): super().__init__() self.args = args @@ -189,9 +138,57 @@ def create_additive_causal_mask( mask = mask.type(dtype) * -1e9 return mask + @staticmethod + def create_rotation_matrix( + positions: torch.Tensor, + embedding_dim: int, + rope_theta: float, + device: torch.device, + ) -> torch.Tensor: + """ + Generate the rotary matrix for RoPE. + + Parameters + ---------- + positions: torch.Tensor + Tensor containing the different indices of the sequential axis + to take into account for positional encoding. + embedding_dim: int + Embedding dimension. + rope_theta: float + RoPE theta. + device: torch.device + Device on which the matrix is to be loaded. + + Returns + ------- + R: torch.Tensor + The rotary matrix of dimension + (len(positions), embedding_dim, embedding_dim). + """ + R = torch.zeros( + (len(positions), embedding_dim, embedding_dim), + requires_grad=False, + device=device, + ) + + slice_i = torch.arange(0, embedding_dim // 2, device=device) + theta = rope_theta ** (-2.0 * (slice_i.float()) / embedding_dim) + m_theta = positions * theta + + cos_values = torch.cos(m_theta) + sin_values = torch.sin(m_theta) + + R[:, 2 * slice_i, 2 * slice_i] = cos_values + R[:, 2 * slice_i, 2 * slice_i + 1] = -sin_values + R[:, 2 * slice_i + 1, 2 * slice_i] = sin_values + R[:, 2 * slice_i + 1, 2 * slice_i + 1] = cos_values + return R + def forward( self, x: torch.Tensor, + rotation_matrix: torch.Tensor, mask: Optional[torch.Tensor] = None, cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: @@ -202,6 +199,8 @@ def forward( ---------- x: torch.Tensor The input tensor. + rotation_matrix: torch.Tensor + Rotation matrix used for positional encoding. mask: torch.Tensor Causal mask. 
cache: (key_cache, value_cache): (torch.Tensor, torch.Tensor) @@ -215,19 +214,12 @@ def forward( (keys, values): cache for keys and values """ B, L, D = x.shape - queries, keys, values = self.wq(x), self.wk(x), self.wv(x) # Prepare the queries, keys and values for the attention computation. - queries = queries.reshape( - B, L, self.n_heads, -1 - ).transpose(1, 2) - keys = keys.reshape( - B, L, self.n_kv_heads, -1 - ).transpose(1, 2) - values = values.reshape( - B, L, self.n_kv_heads, -1 - ).transpose(1, 2) + queries = queries.reshape(B, L, self.n_heads, -1).transpose(1, 2) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(1, 2) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(1, 2) def repeat(a): a = torch.concat([torch.unsqueeze(a, 2)] * self.repeats, dim=2) @@ -237,25 +229,16 @@ def repeat(a): if cache is not None: key_cache, value_cache = cache - R_matrix = get_rotary_matrix2( - key_cache.shape[2], self.args.head_dim - ) - R_matrix = R_matrix.to("mps") - queries = torch.einsum("bhlj,lij->bhli", [queries, R_matrix]) - keys = torch.einsum("bhlj,lij->bhli", [keys, R_matrix]) + queries = torch.einsum("bhlj,lij->bhli", [queries, rotation_matrix]) + keys = torch.einsum("bhlj,lij->bhli", [keys, rotation_matrix]) keys = torch.concat([key_cache, keys], dim=2) values = torch.concat([value_cache, values], dim=2) else: - R_matrix = get_rotary_matrix1( - keys.shape[2], self.args.head_dim - ) - R_matrix = R_matrix.to("mps") - - queries = torch.einsum("bhlj,lij->bhli", [queries, R_matrix]) - keys = torch.einsum("bhlj,lij->bhli", [keys, R_matrix]) + queries = torch.einsum("bhlj,lij->bhli", [queries, rotation_matrix]) + keys = torch.einsum("bhlj,lij->bhli", [keys, rotation_matrix]) scores = torch.matmul(queries, keys.transpose(2, 3)) * self.scale if mask is not None: @@ -264,7 +247,7 @@ def repeat(a): scores.type(torch.float32), dim=-1 ).type_as(scores) - output = torch.matmul(scores, values) # (B, n_local_heads, L, head_dim) + output = 
torch.matmul(scores, values) output = output.transpose(1, 2).contiguous().reshape(B, L, -1) return self.wo(output), (keys, values) @@ -276,11 +259,11 @@ class FeedForward(torch.nn.Module): Parameters ---------- - args: ModelArgs + args: TransformerArgs Model parameters. """ - def __init__(self, args: ModelArgs): + def __init__(self, args: TransformerArgs): super().__init__() self.w1 = torch.nn.Linear(args.dim, args.hidden_dim, bias=False) @@ -310,11 +293,11 @@ class TransformerBlock(torch.nn.Module): Parameters ---------- - args: ModelArgs + args: TransformerArgs Model parameters. """ - def __init__(self, args: ModelArgs): + def __init__(self, args: TransformerArgs): super().__init__() self.n_heads = args.n_heads self.dim = args.dim @@ -327,6 +310,7 @@ def __init__(self, args: ModelArgs): def forward( self, x: torch.Tensor, + rotation_matrix: torch.Tensor, mask: Optional[torch.Tensor] = None, cache: Optional[ Tuple[torch.Tensor, @@ -340,6 +324,8 @@ def forward( ---------- x: torch.Tensor The input tensor. + rotation_matrix: torch.Tensor + Rotation matrix used for positional encoding. mask: torch.Tensor Causal mask. cache: (key_cache, value_cache): (torch.Tensor, torch.Tensor) @@ -352,24 +338,29 @@ def forward( output: the output tensor (keys, values): cache for keys and values """ - r, cache = self.attention(self.attention_norm(x), mask, cache) + r, cache = self.attention( + self.attention_norm(x), + rotation_matrix=rotation_matrix, + mask=mask, + cache=cache, + ) h = x + r r = self.feed_forward(self.ffn_norm(h)) out = h + r return out, cache -class LLM(torch.nn.Module): +class Transformer(torch.nn.Module): """ - Large Language Model module. + Transformer model. Parameters ---------- - args: ModelArgs + args: TransformerArgs Model parameters. 
""" - def __init__(self, args: ModelArgs): + def __init__(self, args: TransformerArgs): super().__init__() self.args = args self.vocab_size = args.vocab_size @@ -406,16 +397,36 @@ def forward( """ h = self.tok_embeddings(x) - mask = None + """mask = None if h.shape[1] > 1: mask = Attention.create_additive_causal_mask(h.shape[1]) mask = mask.type(h.dtype) - mask = mask.to("mps") + mask = mask.to(h.device) + + positions = torch.arange( + 1, h.shape[1] + 1, device=h.device + ).unsqueeze(1) + + else: + key_cache = cache[0][0] + positions = torch.tensor( + [key_cache.shape[2] + 1], device=h.device + ).unsqueeze(1) + + rotation_matrix = Attention.create_rotation_matrix( + positions=positions, + embedding_dim=self.args.head_dim, + rope_theta=self.args.rope_theta, + device=h.device, + ) if cache is None: cache = [None] * len(self.layers) for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) + h, cache[e] = layer( + h, rotation_matrix=rotation_matrix, mask=mask, cache=cache[e] + ) - return self.output(self.norm(h)), cache + return self.output(self.norm(h)), cache""" + return h, cache diff --git a/Tests/GrAIExamples/Base/python_lib/llm/tokenizer.py b/Tests/GrAIExamples/Base/python_lib/nlp/tokenizer.py similarity index 100% rename from Tests/GrAIExamples/Base/python_lib/llm/tokenizer.py rename to Tests/GrAIExamples/Base/python_lib/nlp/tokenizer.py diff --git a/Tests/GrAIExamples/Base/python_lib/weight.py b/Tests/GrAIExamples/Base/python_lib/weight.py index 9b9902cf..ae0748a2 100644 --- a/Tests/GrAIExamples/Base/python_lib/weight.py +++ b/Tests/GrAIExamples/Base/python_lib/weight.py @@ -1,12 +1,13 @@ import torch import numpy as np -from typing import List, Tuple +from pathlib import Path +from typing import List, Tuple, Dict from python_lib.model import SimpleAutoEncoder def _flatten_weights( - weights: np.ndarray + weights: np.ndarray ) -> Tuple[np.ndarray, List[int]]: """ Flatten weights and biases. 
@@ -27,8 +28,38 @@ def _flatten_weights( return weights_list, dims_list +def _extract_weights( + state: Dict[str, torch.Tensor] +) -> Tuple[List[np.ndarray], List[List[int]]]: + """ + Get weights and biases. + + Parameters + ---------- + state: Dict[str, torch.Tensor] + The module state, containing the weights and biases. + + Returns + ------- + (_, _): List[np.ndarray], List[List[int]] + The flattened weights, their shape. + """ + layers_weights: List[np.ndarray] = [] + layers_dims: List[List[int]] = [] + for name, layer_weights in state.items(): + print(f"Extracting weights {name}.") + weights_list, dims_list = _flatten_weights( + layer_weights.data.cpu().float().numpy() + ) + + layers_weights.append(weights_list) + layers_dims.append(dims_list) + + return layers_weights, layers_dims + + def _extract_and_transpose_weights( - modules: [torch.nn.Module] + modules: List[torch.nn.Module] ) -> Tuple[List[np.ndarray], List[List[int]]]: """ Get weights and biases. @@ -94,3 +125,21 @@ def load_simple_auto_encoder_weights( torch.manual_seed(42) model = SimpleAutoEncoder() return _extract_and_transpose_weights(list(model.children())) + + +def load_llm_weights( + model_path: str +) -> Tuple[List[np.ndarray], List[List[int]]]: + """ + Get weights and biases for LLM. + + Returns + ------- + (_, _): List[np.ndarray], List[List[int]] + The flattened weights, their shape. + """ + state = torch.load( + str(Path(model_path) / "consolidated.00.pth"), + map_location="cpu" + ) + return _extract_weights(state) diff --git a/Tests/GrAIExamples/NLPExample.swift b/Tests/GrAIExamples/NLPExample.swift new file mode 100644 index 00000000..a98a709f --- /dev/null +++ b/Tests/GrAIExamples/NLPExample.swift @@ -0,0 +1,125 @@ +// +// NLPExample.swift +// GrAIExamples +// +// Created by Jean-François Reboud on 12/06/2024. +// + +import XCTest +import PythonKit +import GrAIdient + +/// Run generation from prompt. +final class NLPExample: XCTestCase +{ + /// Model path on the disk.
+ let _modelPath = "TO/UPDATE" + + /// Prompt. + let _prompt = "I" + + /// Initialize test. + override func setUp() + { + setPythonLib() + _ = MetalKernel.get + + GrAI.Opti.GPU = true + GrAI.Precision.float = true + } + + /// + /// Build LLM model. + /// + /// - Parameters: + /// - sequence: Length of the sequence. + /// - hiddenDim: Dimension of neurons in the main branch. + /// - vocabularySize: Vocabulary size. + /// - Returns: The model built. + /// + func _buildModel( + modelPath: String, + sequence: Int, + hiddenDim: Int, + vocabularySize: Int) -> Model + { + let context = ModelContext(name: "NLP", curID: 0) + let params = GrAI.Model.Params(context: context) + + _ = EmbeddingSeq( + sequence: sequence, + vocabularySize: vocabularySize, + nbNeurons: hiddenDim, params: params + ) + + // Retrieve base model in the context and initialize a + // real model (with `layerPrev` links updated). + let model = Model(model: context.model, modelsPrev: []) + + // Load weights from `PyTorch`. + let pythonLib = Python.import("python_lib") + let data = pythonLib.load_llm_weights(modelPath) + var weightsNumpy = [PythonObject](data.tuple2.0)! + + // Apply weights on the `GrAIdient` model's layers. + for num_layer in 0..<model.layers.count + { + // Load weights and biases. + if let layer = model.layers[num_layer] as? EmbeddingSeq + { + let weightsTmp: [Float] = Array<Float>( + numpy: weightsNumpy.removeFirst() + )! + + layer.weightsCPU = weightsTmp + } + } + return model + } + + /// Generate text from prompt. + func _testGenerate() throws + { + // Encode prompt. + let pythonLib = Python.import("python_lib") + let prompt = [Int](pythonLib.encode( + _prompt, + _modelPath + ))! + + // Compute reference. + let arrayRef = [Float](numpy: pythonLib.generate_main( + _prompt, + _modelPath + ))! + + // Load pre trained model. + let model = _buildModel( + modelPath: _modelPath, + sequence: prompt.count, + hiddenDim: 4096, + vocabularySize: 32000 + ) + + // Initialize for inference. + model.initKernel(phase: .Inference) + model.updateKernel(batchSize: 1) + + // Forward. + let firstLayer: EmbeddingSeq = model.layers.first as! EmbeddingSeq + try!
firstLayer.setDataGPU( + [prompt], batchSize: 1, sequence: prompt.count + ) + try! model.forward() + + // Get result. + let arrayOut = (model.layers.last as! LayerSeq).outs.download() + + // Compare difference. + for (elemOut, elemRef) in zip(arrayOut, arrayRef) + { + let diffPercent = abs(elemOut - elemRef) / abs(elemRef) * 100.0 + XCTAssert(diffPercent < 0.001) + } + } +} diff --git a/Tests/GrAITests/Base/InputSeq/EmbeddingSeqMSE1DCase.swift b/Tests/GrAITests/Base/InputSeq/EmbeddingSeqMSE1DCase.swift new file mode 100644 index 00000000..3a349b17 --- /dev/null +++ b/Tests/GrAITests/Base/InputSeq/EmbeddingSeqMSE1DCase.swift @@ -0,0 +1,189 @@ +// +// EmbeddingSeqMSE1DCase.swift +// GrAITests +// +// Created by Jean-François Reboud on 11/06/2024. +// + +import XCTest +import GrAIdient +import GrAITestsUtils + +/// +/// A class that will test a model with a structural hypothesis: +/// the model last layer is a MSE1D layer, the model first layer is an EmbeddingSeq. +/// +class EmbeddingSeqMSE1DCase: XCTestCase, Input1DCase, IOCase +{ + /// Batch size of data. + var batchSize: Int = -1 + /// Length of the sequence. + var sequence: Int = -1 + /// Vocabulary size. + var vocabularySize: Int = -1 + /// Optimizer parameters. + var optimizerParams = GrAI.Optimizer.Params() + + /// Systematic call before test begins. + override func setUp() + { + batchSize = 5 + sequence = 7 + vocabularySize = 120 + _ = MetalKernel.get + + GrAI.Opti.GPU = true + GrAI.Precision.float = true + + setOptimizerParams(params: &optimizerParams) + optimizerParams.nbLoops = 3 + } + + /// + /// A function to create/set ground truth to the model. + /// + /// - Parameters: + /// - groundTruth: The ground truth to set. + /// - model: The model. + /// - Returns: The ground truth. + /// + func setLoss(_ groundTruth: [[Double]]?, _ model: Model) -> [[Double]] + { + let lastLayer = model.layers.last as!
MSE1D + let gt: [[Double]] + if let groundTruthTmp = groundTruth + { + gt = groundTruthTmp + } + else + { + gt = buildData(dim1: getBatchSize(model), dim2: 1) + } + + if GrAI.Opti.GPU + { + try! lastLayer.lossDerivativeGPU( + gt, batchSize: gt.count, nbNeurons: 1 + ) + } + else + { + try! lastLayer.lossDerivativeCPU( + gt, batchSize: gt.count, nbNeurons: 1 + ) + } + return gt + } + + /// + /// A function to get loss of a model. + /// + /// - Parameters: + /// - groundTruth: The ground truth to set. + /// - model: The model. + /// - Returns: The loss value. + /// + func getLoss(_ groundTruth: [[Double]], _ model: Model) -> Double + { + let lastLayer = model.layers.last as! MSE1D + if GrAI.Opti.GPU + { + return Double(try! lastLayer.getLossGPU( + groundTruth, batchSize: groundTruth.count, nbNeurons: 1 + )) + } + else + { + return try! lastLayer.getLossCPU( + groundTruth, batchSize: groundTruth.count, nbNeurons: 1 + ) + } + } + + /// + /// A function to get the gradients of weights approximations. + /// + /// - Parameters: + /// - groundTruth: The ground truth. + /// - model: The model. + /// - Returns: The gradients of weights approximations. + /// + func getGradientsApprox( + _ groundTruth: [[Double]], + _ model: Model) -> [Double] + { + let lastLayer = model.layers.last as! MSE1D + return try! lastLayer.collectGradientsApprox( + groundTruth, batchSize: groundTruth.count, nbNeurons: 1 + ) + } + + /// + /// Create synthetic data. + /// + /// - Parameters: + /// - batchSize: Batch size of the data. + /// - sequence: Length of the sequence. + /// - vocabularySize: Vocabulary size. + /// - Returns: The created data. + /// + func buildData( + batchSize: Int, + sequence: Int, + vocabularySize: Int) -> [[Int]] + { + var data = [[Int]]() + for _ in 0..<batchSize + { + var data1 = [Int]() + for _ in 0..<sequence + { + data1.append(Int.random(in: 0..<vocabularySize)) + } + data.append(data1) + } + return data + } + + /// + /// A function to create/set data to the model. + /// + /// - Parameters: + /// - inputs: The data to set. + /// - model: The model. + /// - Returns: (The data, the batch size). + /// + func setData( + _ inputs: [[Int]]?, + _ model: Model) -> ([[Int]], Int) + { + let firstLayer = model.layers.first as!
EmbeddingSeq + let ins: [[Int]] + if let insTmp = inputs + { + ins = insTmp + } + else + { + ins = buildData( + batchSize: getBatchSize(model), + sequence: sequence, + vocabularySize: vocabularySize + ) + } + + if GrAI.Opti.GPU + { + try! firstLayer.setDataGPU( + ins, batchSize: ins.count, sequence: sequence + ) + } + else + { + try! firstLayer.setDataCPU( + ins, batchSize: ins.count, sequence: sequence + ) + } + return (ins, ins.count) + } +} diff --git a/Tests/GrAITests/NLPTests.swift b/Tests/GrAITests/NLPTests.swift new file mode 100644 index 00000000..ce8710dc --- /dev/null +++ b/Tests/GrAITests/NLPTests.swift @@ -0,0 +1,453 @@ +// +// NLPTests.swift +// GrAITests +// +// Created by Jean-François Reboud on 11/06/2024. +// + +import XCTest +import GrAIdient +import GrAITestsUtils + +// ----------------------------------------------------------------------------- +// Gradient Checking +// We expect to see errors ~ 1e-7 and less. +// ----------------------------------------------------------------------------- +class NLPGradTests: EmbeddingSeqMSE1DCase +{ + override func setUp() + { + super.setUp() + + optimizerParams.nbLoops = 2 + GrAI.Loop.gradientChecking = true + } + + private func _buildTrainer(_ model: String) -> GradTrainer + { + let trainer = GradTrainer( + name: "NLP", + params: optimizerParams + ) + trainer.build() + { + (context: ModelContext) in + _buildModel(model: model, context: context) + } + return trainer + } + + private func _buildModel(model: String, context: ModelContext) + { + let params = GrAI.Model.Params(context: context) + + let layer: LayerSeq = EmbeddingSeq( + sequence: sequence, + vocabularySize: vocabularySize, + nbNeurons: 5, params: params + ) + + var head: Layer1D = AvgPoolSeq(layerPrev: layer, params: params) + + head = try! 
FullyConnected( + layerPrev: head, nbNeurons: 1, + activation: SoftReLU.str, biases: true, params: params + ) + + _ = MSE1D(layerPrev: head, params: params) + } + + func testEmbeddingCPU() throws + { + GrAI.Opti.CPU = true + let trainer = _buildTrainer("Embedding") + run(trainer) + } + + func testEmbeddingGPU() throws + { + let trainer = _buildTrainer("Embedding") + run(trainer) + } + + func testEmbeddingSampleGPU() throws + { + GrAI.Gradient.sample = true + let trainer = _buildTrainer("Embedding") + run(trainer) + } +} + +// ----------------------------------------------------------------------------- +// Compare GPU gradients with CPU ones through time. +// We expect to see errors ~ 1e-7 and less. +// ----------------------------------------------------------------------------- +class NLPFlowTests: EmbeddingSeqMSE1DCase +{ + private func _buildTrainer(_ model: String) -> FlowTrainer + { + let trainer = FlowTrainer( + name: "NLP", + params: optimizerParams + ) + trainer.build() + { + (context: ModelContext) in + buildModel(model: model, context: context) + } + return trainer + } + + func buildModel(model: String, context: ModelContext) + { + let params = GrAI.Model.Params(context: context) + + let layer: LayerSeq = EmbeddingSeq( + sequence: sequence, + vocabularySize: vocabularySize, + nbNeurons: 5, params: params + ) + + var head: Layer1D = AvgPoolSeq(layerPrev: layer, params: params) + + head = try! FullyConnected( + layerPrev: head, nbNeurons: 1, + activation: LeakyReLU.str, biases: true, params: params + ) + + _ = MSE1D(layerPrev: head, params: params) + } + + func testEmbedding() throws + { + let trainer = _buildTrainer("Embedding") + run(trainer) + } + + func testEmbeddingSample() throws + { + GrAI.Gradient.sample = true + let trainer = _buildTrainer("Embedding") + run(trainer) + } +} + +// ----------------------------------------------------------------------------- +// Compare GPU gradients with Float precision versus Float16 precision. 
+// We expect to see errors ~ 1e-4 and less. +// ----------------------------------------------------------------------------- +class NLPFlowPrecisionTests: NLPFlowTests +{ + private func _buildTrainer(_ model: String) -> FlowPrecisionTrainer + { + let trainer = FlowPrecisionTrainer( + name: "NLP", + params: optimizerParams + ) + trainer.build() + { + (context: ModelContext) in + buildModel(model: model, context: context) + } + return trainer + } + + override func testEmbedding() throws + { + let trainer = _buildTrainer("Embedding") + run(trainer) + } + + override func testEmbeddingSample() throws + { + GrAI.Gradient.sample = true + let trainer = _buildTrainer("Embedding") + run(trainer) + } +} + +// ----------------------------------------------------------------------------- +// Compare GPU gradients with CPU ones through time. +// We expect to see errors ~ 1e-7 and less. +// ----------------------------------------------------------------------------- +class NLPFlowResetTests: NLPFlowTests +{ + override func setUp() + { + super.setUp() + + setOptimizerParams(params: &optimizerParams, + optimizerClass: .Adam) + } + + private func _buildTrainer(_ model: String) -> FlowResetTrainer + { + let trainer = FlowResetTrainer( + name: "NLP", + params: optimizerParams + ) + trainer.build() + { + (context: ModelContext) in + buildModel(model: model, context: context) + } + return trainer + } + + override func testEmbedding() throws + { + let trainer = _buildTrainer("Embedding") + run(trainer) + } + + override func testEmbeddingSample() throws + { + GrAI.Gradient.sample = true + let trainer = _buildTrainer("Embedding") + run(trainer) + } +} + +// ----------------------------------------------------------------------------- +// Compare GPU gradients with CPU ones through time. +// We expect to see errors ~ 1e-7 and less. 
+// ----------------------------------------------------------------------------- +class NLPFlowReverseTests: NLPFlowTests +{ + override func setUp() + { + super.setUp() + + setOptimizerParams(params: &optimizerParams, + optimizerClass: .Adam) + } + + private func _buildTrainer(_ model: String) -> FlowReverseTrainer + { + let trainer = FlowReverseTrainer( + name: "NLP", + params: optimizerParams + ) + trainer.build() + { + (context: ModelContext) in + buildModel(model: model, context: context) + } + return trainer + } + + override func testEmbedding() throws + { + let trainer = _buildTrainer("Embedding") + run(trainer) + } + + override func testEmbeddingSample() throws + { + GrAI.Gradient.sample = true + let trainer = _buildTrainer("Embedding") + run(trainer) + } +} + +// ----------------------------------------------------------------------------- +// Compare GPU gradients with CPU ones through time. +// We expect to see errors ~ 1e-7 and less. +// ----------------------------------------------------------------------------- +class NLPFlowAccumulateTests: EmbeddingSeqMSE1DCase +{ + private func _buildTrainer(_ model: String) -> FlowTrainer + { + let trainer = FlowAccumulateTrainer( + name: "NLP", + params: optimizerParams + ) + trainer.build() + { + (context: ModelContext) in + buildModel(model: model, context: context) + } + return trainer + } + + func buildModel(model: String, context: ModelContext) + { + let params = GrAI.Model.Params(context: context) + + let layer: LayerSeq = EmbeddingSeq( + sequence: sequence, + vocabularySize: vocabularySize, + nbNeurons: 5, params: params + ) + + var head: Layer1D = AvgPoolSeq(layerPrev: layer, params: params) + + head = try! 
FullyConnected( + layerPrev: head, nbNeurons: 1, + activation: LeakyReLU.str, biases: true, params: params + ) + + _ = MSE1D(layerPrev: head, params: params) + } + + func testEmbedding() throws + { + let trainer = _buildTrainer("Embedding") + run(trainer) + } + + func testEmbeddingSample() throws + { + GrAI.Gradient.sample = true + let trainer = _buildTrainer("Embedding") + run(trainer) + } +} + +// ----------------------------------------------------------------------------- +// Compare GPU Loss in inference mode with CPU one. +// We expect to see errors ~ 1e-3 and less. +// ----------------------------------------------------------------------------- +class NLPInferenceTests: NLPFlowTests +{ + private func _buildTrainer(_ model: String) -> InferenceTrainer + { + let trainer = InferenceTrainer( + name: "NLP", + params: optimizerParams + ) + trainer.build() + { + (context: ModelContext) in + buildModel(model: model, context: context) + } + return trainer + } + + override func testEmbedding() throws + { + let trainer = _buildTrainer("Embedding") + run(trainer) + } + + override func testEmbeddingSample() throws + { + GrAI.Gradient.sample = true + let trainer = _buildTrainer("Embedding") + run(trainer) + } +} + +// ----------------------------------------------------------------------------- +// Compare GPU/CPU Losses in inference mode with the one obtained from a +// loaded model. +// We expect to see errors ~ 1e-3 and less. 
+// ----------------------------------------------------------------------------- +class NLPLoadTests: NLPFlowTests +{ + private func _buildTrainer(_ model: String) -> LoadTrainer + { + let trainer = LoadTrainer( + name: "NLP", + params: optimizerParams + ) + trainer.build() + { + (context: ModelContext) in + buildModel(model: model, context: context) + } + return trainer + } + + override func testEmbedding() throws + { + let trainer = _buildTrainer("Embedding") + run(trainer) + } + + override func testEmbeddingSample() throws + { + GrAI.Gradient.sample = true + let trainer = _buildTrainer("Embedding") + run(trainer) + } +} + +// ----------------------------------------------------------------------------- +// Compare GPU/CPU Losses in inference mode with the one obtained from a +// transformed model. +// We expect to see errors ~ 1e-3 and less. +// ----------------------------------------------------------------------------- +class NLPTransformTests: NLPFlowTests +{ + /// + /// Run Transform tests. + /// + /// The goal is to compare the losses computed in the CPU execution + /// after transforming the model and do the same in the GPU execution context. + /// + /// - Parameters: + /// - trainer: The testing pipeline to run. + /// - nbRetry: The maximum number we can retry the test. + /// - diffThreshold: The threshold above which the relative difference is too high. 
+ /// + func run( + _ trainer: TransformTrainer, + nbRetry: Int = NB_RETRY, + diffThreshold: Double = 0.001) + { + retryNumeric( + nbRetry: nbRetry, + { + () throws in + try trainer.run( + transforms: [self.copy, self.copyInPlace], + setData: self.setData, + setLoss: self.setLoss, + getLoss: self.getLoss) + { + (diffCPU: Double, diffGPU: Double) in + if diffCPU > diffThreshold + { + throw TestError.Numeric + } + if diffGPU > diffThreshold + { + throw TestError.Numeric + } + } + }, + { + () in + XCTAssert(false) + } + ) + } + + private func _buildTrainer(_ model: String) -> TransformTrainer + { + let trainer = TransformTrainer( + name: "NLP", + params: optimizerParams + ) + trainer.build() + { + (context: ModelContext) in + buildModel(model: model, context: context) + } + return trainer + } + + override func testEmbedding() throws + { + let trainer = _buildTrainer("Embedding") + run(trainer) + } + + override func testEmbeddingSample() throws + { + GrAI.Gradient.sample = true + let trainer = _buildTrainer("Embedding") + run(trainer) + } +}